Merge pull request #24539 from ildus/map_combinator

add Map combinator for the Map type
This commit is contained in:
Kruglov Pavel 2021-11-03 11:25:23 +03:00 committed by GitHub
commit a22ab40468
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 506 additions and 9 deletions

View File

@ -25,6 +25,12 @@ Example 2: `uniqArray(arr)` Counts the number of unique elements in all a
-If and -Array can be combined. However, Array must come first, then If. Examples: `uniqArrayIf(arr, cond)`, `quantilesTimingArrayIf(level1, level2)(arr, cond)`. Due to this order, the cond argument won't be an array.
## -Map {#agg-functions-combinator-map}
The -Map suffix can be appended to any aggregate function. This will create an aggregate function which gets Map type as an argument, and aggregates values of each key of the map separately using the specified aggregate function. The result is also of a Map type.
Examples: `sumMap(map(1,1))`, `avgMap(map('a', 1))`.
## -SimpleState {#agg-functions-combinator-simplestate}
If you apply this combinator, the aggregate function returns the same value but with a different type. This is a [SimpleAggregateFunction(...)](../../sql-reference/data-types/simpleaggregatefunction.md) that can be stored in a table to work with [AggregatingMergeTree](../../engines/table-engines/mergetree-family/aggregatingmergetree.md) tables.

View File

@ -0,0 +1,135 @@
#include "AggregateFunctionMap.h"
#include "AggregateFunctions/AggregateFunctionCombinatorFactory.h"
#include "Functions/FunctionHelpers.h"
namespace DB
{
/// Error codes used by the exceptions thrown in this file.
/// They are only declared here (extern); the definitions are not part of this file.
namespace ErrorCodes
{
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
/// Combinator implementing the `-Map` suffix for aggregate functions.
///
/// With a single Map argument it builds an AggregateFunctionMap that applies the nested
/// function to the values of every key separately. With array arguments (either several
/// arrays or a single tuple of arrays) it redirects `sum`/`min`/`max` to the corresponding
/// `*MappedArrays` function, which keeps the behavior these names had before the combinator.
class AggregateFunctionCombinatorMap final : public IAggregateFunctionCombinator
{
public:
    String getName() const override { return "Map"; }

    /// Computes the argument types that the nested aggregate function will receive.
    ///
    /// - Map(K, V)                      -> {V}
    /// - Tuple(Array, Array, ...)       -> {nested type of the second array}
    /// - Array, Array, ...              -> {nested type of the second array}
    ///
    /// Throws NUMBER_OF_ARGUMENTS_DOESNT_MATCH for an empty argument list or for extra
    /// arguments after a Map, and ILLEGAL_TYPE_OF_ARGUMENT for any other combination.
    DataTypes transformArguments(const DataTypes & arguments) const override
    {
        if (arguments.empty())
            throw Exception(
                ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
                "Incorrect number of arguments for aggregate function with " + getName() + " suffix");

        const auto * map_type = checkAndGetDataType<DataTypeMap>(arguments[0].get());
        if (map_type)
        {
            if (arguments.size() > 1)
                throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, getName() + " combinator takes only one map argument");

            return DataTypes({map_type->getValueType()});
        }

        // The remaining branches exist only so that the call can later be redirected
        // to the *MappedArrays functions (see transformAggregateFunction).
        // The pointer is taken by const reference to avoid a shared_ptr copy per element.
        auto check_func = [](const DataTypePtr & t) { return t->getTypeId() == TypeIndex::Array; };

        const auto * tup_type = checkAndGetDataType<DataTypeTuple>(arguments[0].get());
        if (tup_type)
        {
            // A single tuple whose elements are all arrays: (keys, values, ...).
            const auto & types = tup_type->getElements();
            bool arrays_match = arguments.size() == 1 && types.size() >= 2 && std::all_of(types.begin(), types.end(), check_func);
            if (arrays_match)
            {
                const auto * val_array_type = assert_cast<const DataTypeArray *>(types[1].get());
                return DataTypes({val_array_type->getNestedType()});
            }
        }
        else
        {
            // Two or more separate array arguments: keys, values, ...
            bool arrays_match = arguments.size() >= 2 && std::all_of(arguments.begin(), arguments.end(), check_func);
            if (arrays_match)
            {
                const auto * val_array_type = assert_cast<const DataTypeArray *>(arguments[1].get());
                return DataTypes({val_array_type->getNestedType()});
            }
        }

        throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Aggregate function " + getName() + " requires map as argument");
    }

    /// Creates the resulting aggregate function for the given nested function and arguments.
    ///
    /// For a Map argument the key type selects the AggregateFunctionMap instantiation;
    /// unsupported key types raise ILLEGAL_TYPE_OF_ARGUMENT. For non-Map arguments only
    /// nested `sum`, `min` and `max` are accepted and are forwarded to `<name>MappedArrays`.
    AggregateFunctionPtr transformAggregateFunction(
        const AggregateFunctionPtr & nested_function,
        const AggregateFunctionProperties &,
        const DataTypes & arguments,
        const Array & params) const override
    {
        // Do not index into an empty argument list; report it the same way transformArguments does.
        if (arguments.empty())
            throw Exception(
                ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
                "Incorrect number of arguments for aggregate function with " + getName() + " suffix");

        const auto * map_type = checkAndGetDataType<DataTypeMap>(arguments[0].get());
        if (map_type)
        {
            const auto & key_type = map_type->getKeyType();

            // Several type ids share one instantiation (e.g. Enum8 with Int8, Date with UInt16).
            switch (key_type->getTypeId())
            {
                case TypeIndex::Enum8:
                case TypeIndex::Int8:
                    return std::make_shared<AggregateFunctionMap<Int8>>(nested_function, arguments);
                case TypeIndex::Enum16:
                case TypeIndex::Int16:
                    return std::make_shared<AggregateFunctionMap<Int16>>(nested_function, arguments);
                case TypeIndex::Int32:
                    return std::make_shared<AggregateFunctionMap<Int32>>(nested_function, arguments);
                case TypeIndex::Int64:
                    return std::make_shared<AggregateFunctionMap<Int64>>(nested_function, arguments);
                case TypeIndex::Int128:
                    return std::make_shared<AggregateFunctionMap<Int128>>(nested_function, arguments);
                case TypeIndex::Int256:
                    return std::make_shared<AggregateFunctionMap<Int256>>(nested_function, arguments);
                case TypeIndex::UInt8:
                    return std::make_shared<AggregateFunctionMap<UInt8>>(nested_function, arguments);
                case TypeIndex::Date:
                case TypeIndex::UInt16:
                    return std::make_shared<AggregateFunctionMap<UInt16>>(nested_function, arguments);
                case TypeIndex::DateTime:
                case TypeIndex::UInt32:
                    return std::make_shared<AggregateFunctionMap<UInt32>>(nested_function, arguments);
                case TypeIndex::UInt64:
                    return std::make_shared<AggregateFunctionMap<UInt64>>(nested_function, arguments);
                case TypeIndex::UInt128:
                    return std::make_shared<AggregateFunctionMap<UInt128>>(nested_function, arguments);
                case TypeIndex::UInt256:
                    return std::make_shared<AggregateFunctionMap<UInt256>>(nested_function, arguments);
                case TypeIndex::UUID:
                    return std::make_shared<AggregateFunctionMap<UUID>>(nested_function, arguments);
                case TypeIndex::FixedString:
                case TypeIndex::String:
                    return std::make_shared<AggregateFunctionMap<String>>(nested_function, arguments);
                default:
                    throw Exception(
                        ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
                        "Map key type " + key_type->getName() + " is not supported by combinator " + getName());
            }
        }
        else
        {
            // In case of a tuple of arrays or just arrays (checked in transformArguments),
            // redirect to sum/min/max-MappedArrays to keep the old behavior of these names.
            auto nested_func_name = nested_function->getName();
            if (nested_func_name == "sum" || nested_func_name == "min" || nested_func_name == "max")
            {
                AggregateFunctionProperties out_properties;
                auto & aggr_func_factory = AggregateFunctionFactory::instance();
                return aggr_func_factory.get(nested_func_name + "MappedArrays", arguments, params, out_properties);
            }
            else
                throw Exception(
                    ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Aggregation '" + nested_func_name + "Map' is not implemented for mapped arrays");
        }
    }
};
/// Registers the `-Map` combinator in the given combinator factory.
void registerAggregateFunctionCombinatorMap(AggregateFunctionCombinatorFactory & factory)
{
    auto map_combinator = std::make_shared<AggregateFunctionCombinatorMap>();
    factory.registerCombinator(map_combinator);
}
}

View File

@ -0,0 +1,247 @@
#pragma once
#include <unordered_map>
#include <AggregateFunctions/AggregateFunctionCombinatorFactory.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <Columns/ColumnFixedString.h>
#include <Columns/ColumnMap.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnTuple.h>
#include <Columns/ColumnVector.h>
#include <Core/ColumnWithTypeAndName.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeMap.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include "base/types.h"
#include <Common/Arena.h>
#include "AggregateFunctions/AggregateFunctionFactory.h"
namespace DB
{
/// Error codes used by the exceptions thrown in this header.
/// They are only declared here (extern); the definitions are not part of this file.
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
/// Aggregation state of the -Map combinator: one nested aggregate state per map key.
/// This is the generic variant; String keys use a dedicated specialization.
template <typename KeyType>
struct AggregateFunctionMapCombinatorData
{
    /// Type used to look a key up in `merged_maps`; here it is the key type itself.
    using SearchType = KeyType;

    /// Maps each seen key to the place holding the nested function's state for that key.
    std::unordered_map<KeyType, AggregateDataPtr> merged_maps;

    /// Writes a key in binary form.
    static void writeKey(KeyType key, WriteBuffer & buf)
    {
        writeBinary(key, buf);
    }

    /// Reads a key previously written by writeKey.
    static void readKey(KeyType & key, ReadBuffer & buf)
    {
        readBinary(key, buf);
    }
};
template <>
struct AggregateFunctionMapCombinatorData<String>
{
struct StringHash
{
using hash_type = std::hash<std::string_view>;
using is_transparent = void;
size_t operator()(std::string_view str) const { return hash_type{}(str); }
};
#ifdef __cpp_lib_generic_unordered_lookup
using SearchType = std::string_view;
#else
using SearchType = std::string;
#endif
std::unordered_map<String, AggregateDataPtr, StringHash, std::equal_to<>> merged_maps;
static void writeKey(String key, WriteBuffer & buf)
{
writeVarUInt(key.size(), buf);
writeString(key, buf);
}
static void readKey(String & key, ReadBuffer & buf)
{
UInt64 size;
readVarUInt(size, buf);
key.resize(size);
buf.readStrict(key.data(), size);
}
};
/// Aggregate function produced by the -Map combinator.
///
/// Takes a single Map argument and keeps, for every distinct key, a separate state of the
/// nested aggregate function; each value of the map is fed to the state of its key.
/// The result is a Map from key to the nested function's result, with keys in sorted order.
///
/// KeyType is the in-memory representation of the map key (String is used for both
/// String and FixedString key columns).
template <typename KeyType>
class AggregateFunctionMap final
    : public IAggregateFunctionDataHelper<AggregateFunctionMapCombinatorData<KeyType>, AggregateFunctionMap<KeyType>>
{
private:
    DataTypePtr key_type;
    AggregateFunctionPtr nested_func;

    using Data = AggregateFunctionMapCombinatorData<KeyType>;
    using Base = IAggregateFunctionDataHelper<Data, AggregateFunctionMap<KeyType>>;

    /// Allocates and constructs a fresh nested state in the arena.
    AggregateDataPtr createNestedPlace(Arena * arena) const
    {
        AggregateDataPtr nested_place = arena->alignedAlloc(nested_func->sizeOfData(), nested_func->alignOfData());
        nested_func->create(nested_place);
        return nested_place;
    }

public:
    /// @param nested  aggregate function applied to the values of each key
    /// @param types   argument types; must be exactly one Map type
    /// Throws NUMBER_OF_ARGUMENTS_DOESNT_MATCH or ILLEGAL_TYPE_OF_ARGUMENT otherwise.
    AggregateFunctionMap(AggregateFunctionPtr nested, const DataTypes & types) : Base(types, nested->getParameters()), nested_func(nested)
    {
        if (types.empty())
            throw Exception(
                ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function " + getName() + " requires at least one argument");

        if (types.size() > 1)
            throw Exception(
                ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function " + getName() + " requires only one map argument");

        const auto * map_type = checkAndGetDataType<DataTypeMap>(types[0].get());
        if (!map_type)
            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Aggregate function " + getName() + " requires map as argument");

        key_type = map_type->getKeyType();
    }

    String getName() const override { return nested_func->getName() + "Map"; }

    /// Map(<key type of the argument>, <return type of the nested function>).
    DataTypePtr getReturnType() const override { return std::make_shared<DataTypeMap>(DataTypes{key_type, nested_func->getReturnType()}); }

    /// Feeds every (key, value) pair of the map in row `row_num` to the nested state of its key,
    /// creating that state on first sight of the key.
    void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override
    {
        const auto & map_column = assert_cast<const ColumnMap &>(*columns[0]);
        const auto & map_nested_tuple = map_column.getNestedData();
        const IColumn::Offsets & map_array_offsets = map_column.getNestedColumn().getOffsets();

        // NOTE(review): for row_num == 0 this reads offsets[-1]; presumably the offsets
        // container guarantees that element reads as 0 - confirm against its definition.
        const size_t offset = map_array_offsets[row_num - 1];
        const size_t size = (map_array_offsets[row_num] - offset);

        const auto & key_column = map_nested_tuple.getColumn(0);
        const auto & val_column = map_nested_tuple.getColumn(1);

        auto & merged_maps = this->data(place).merged_maps;

        for (size_t i = 0; i < size; ++i)
        {
            typename Data::SearchType key;

            if constexpr (std::is_same<KeyType, String>::value)
            {
                StringRef key_ref;
                if (key_type->getTypeId() == TypeIndex::FixedString)
                    key_ref = assert_cast<const ColumnFixedString &>(key_column).getDataAt(offset + i);
                else
                    key_ref = assert_cast<const ColumnString &>(key_column).getDataAt(offset + i);

#ifdef __cpp_lib_generic_unordered_lookup
                key = static_cast<std::string_view>(key_ref);
#else
                key = key_ref.toString();
#endif
            }
            else
            {
                key = assert_cast<const ColumnVector<KeyType> &>(key_column).getData()[offset + i];
            }

            AggregateDataPtr nested_place;
            auto it = merged_maps.find(key);

            if (it == merged_maps.end())
            {
                // create a new place for each key
                nested_place = createNestedPlace(arena);
                merged_maps.emplace(key, nested_place);
            }
            else
                nested_place = it->second;

            const IColumn * nested_columns[1] = {&val_column};
            nested_func->add(nested_place, nested_columns, offset + i, arena);
        }
    }

    /// Merges the per-key states of `rhs` into `place`.
    void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override
    {
        auto & merged_maps = this->data(place).merged_maps;
        const auto & rhs_maps = this->data(rhs).merged_maps;

        for (const auto & elem : rhs_maps)
        {
            AggregateDataPtr nested_place;
            auto it = merged_maps.find(elem.first);

            if (it == merged_maps.end())
            {
                // The rhs nested state must not be adopted by pointer: it belongs to rhs and
                // would be shared between two states (and dangle once rhs is released).
                // Create an own state for the key and merge the rhs state into it instead.
                nested_place = createNestedPlace(arena);
                merged_maps.emplace(elem.first, nested_place);
            }
            else
                nested_place = it->second;

            nested_func->merge(nested_place, elem.second, arena);
        }
    }

    /// Destroys every nested state owned by this state and then the state itself,
    /// so nested functions with non-trivial states release their resources.
    void destroy(AggregateDataPtr place) const noexcept override
    {
        Data & state = this->data(place);

        for (const auto & elem : state.merged_maps)
            nested_func->destroy(elem.second);

        state.~Data();
    }

    /// Format: number of keys (VarUInt), then for each key: the key followed by the nested state.
    void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
    {
        auto & merged_maps = this->data(place).merged_maps;
        writeVarUInt(merged_maps.size(), buf);

        for (const auto & elem : merged_maps)
        {
            this->data(place).writeKey(elem.first, buf);
            nested_func->serialize(elem.second, buf);
        }
    }

    /// Reads the format written by serialize, creating one nested state per key.
    void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena * arena) const override
    {
        auto & merged_maps = this->data(place).merged_maps;
        UInt64 size;

        readVarUInt(size, buf);
        for (UInt64 i = 0; i < size; ++i)
        {
            KeyType key;
            AggregateDataPtr nested_place;

            this->data(place).readKey(key, buf);
            nested_place = createNestedPlace(arena);
            // Register the state before deserializing into it, so it stays reachable
            // from the map even if reading the nested state throws.
            merged_maps.emplace(key, nested_place);
            nested_func->deserialize(nested_place, buf, arena);
        }
    }

    /// Appends one Map row to `to`: keys in ascending order, each with its nested result.
    void insertResultInto(AggregateDataPtr place, IColumn & to, Arena * arena) const override
    {
        auto & map_column = assert_cast<ColumnMap &>(to);
        auto & nested_column = map_column.getNestedColumn();
        auto & nested_data_column = map_column.getNestedData();

        auto & key_column = nested_data_column.getColumn(0);
        auto & val_column = nested_data_column.getColumn(1);

        auto & merged_maps = this->data(place).merged_maps;

        // sort the keys
        std::vector<KeyType> keys;
        keys.reserve(merged_maps.size());
        for (const auto & elem : merged_maps)
            keys.push_back(elem.first);
        std::sort(keys.begin(), keys.end());

        // insert using sorted keys to result column
        for (const auto & key : keys)
        {
            key_column.insert(key);
            // at() is a pure lookup: every key comes from the map, nothing may be inserted here.
            nested_func->insertResultInto(merged_maps.at(key), val_column, arena);
        }

        // The new row ends at the current end of the flattened values.
        IColumn::Offsets & res_offsets = nested_column.getOffsets();
        res_offsets.push_back(val_column.size());
    }

    bool allocatesMemoryInArena() const override { return true; }

    AggregateFunctionPtr getNestedFunction() const override { return nested_func; }
};
}

View File

@ -145,9 +145,20 @@ struct MaxMapDispatchOnTupleArgument
void registerAggregateFunctionSumMap(AggregateFunctionFactory & factory)
{
factory.registerFunction("sumMap", createAggregateFunctionMap<
// these functions used to be called *Map, but those names are now occupied
// by the Map combinator, which redirects calls here when it is called with
// array or tuple arguments.
factory.registerFunction("sumMappedArrays", createAggregateFunctionMap<
SumMapVariants<false, false>::DispatchOnTupleArgument>);
factory.registerFunction("minMappedArrays",
createAggregateFunctionMap<MinMapDispatchOnTupleArgument>);
factory.registerFunction("maxMappedArrays",
createAggregateFunctionMap<MaxMapDispatchOnTupleArgument>);
// these functions could be renamed to *MappedArrays too, but it would
// break backward compatibility
factory.registerFunction("sumMapWithOverflow", createAggregateFunctionMap<
SumMapVariants<false, true>::DispatchOnTupleArgument>);
@ -157,12 +168,6 @@ void registerAggregateFunctionSumMap(AggregateFunctionFactory & factory)
factory.registerFunction("sumMapFilteredWithOverflow",
createAggregateFunctionMap<
SumMapVariants<true, true>::DispatchOnTupleArgument>);
factory.registerFunction("minMap",
createAggregateFunctionMap<MinMapDispatchOnTupleArgument>);
factory.registerFunction("maxMap",
createAggregateFunctionMap<MaxMapDispatchOnTupleArgument>);
}
}

View File

@ -377,7 +377,17 @@ public:
assertNoParameters(getName(), params_);
}
String getName() const override { return "sumMap"; }
/// Name of this aggregate function; the overflow variant reports its own name.
String getName() const override
{
    return overflow ? "sumMapWithOverflow" : "sumMap";
}
bool keepKey(const T &) const { return true; }
};

View File

@ -65,6 +65,7 @@ void registerAggregateFunctionCombinatorNull(AggregateFunctionCombinatorFactory
void registerAggregateFunctionCombinatorOrFill(AggregateFunctionCombinatorFactory &);
void registerAggregateFunctionCombinatorResample(AggregateFunctionCombinatorFactory &);
void registerAggregateFunctionCombinatorDistinct(AggregateFunctionCombinatorFactory &);
void registerAggregateFunctionCombinatorMap(AggregateFunctionCombinatorFactory & factory);
void registerWindowFunctions(AggregateFunctionFactory & factory);
@ -134,6 +135,7 @@ void registerAggregateFunctions()
registerAggregateFunctionCombinatorOrFill(factory);
registerAggregateFunctionCombinatorResample(factory);
registerAggregateFunctionCombinatorDistinct(factory);
registerAggregateFunctionCombinatorMap(factory);
}
}

View File

@ -31,7 +31,8 @@ void DataTypeCustomSimpleAggregateFunction::checkSupportedFunctions(const Aggreg
/// TODO Make it sane.
static const std::vector<String> supported_functions{"any", "anyLast", "min",
"max", "sum", "sumWithOverflow", "groupBitAnd", "groupBitOr", "groupBitXor",
"sumMap", "minMap", "maxMap", "groupArrayArray", "groupUniqArrayArray"};
"sumMap", "minMap", "maxMap", "groupArrayArray", "groupUniqArrayArray",
"sumMappedArrays", "minMappedArrays", "maxMappedArrays"};
// check function
if (std::find(std::begin(supported_functions), std::end(supported_functions), function->getName()) == std::end(supported_functions))

View File

@ -0,0 +1,34 @@
1 {1:10,2:10,3:10}
1 {3:10,4:10,5:10}
2 {4:10,5:10,6:10}
2 {6:10,7:10,8:10}
3 {1:10,2:10,3:10}
4 {3:10,4:10,5:10}
5 {4:10,5:10,6:10}
5 {6:10,7:10,8:10}
Map(UInt16, UInt64) {1:20,2:20,3:40,4:40,5:40,6:40,7:20,8:20}
Map(UInt16, UInt32) {1:20,2:20,3:40,4:40,5:40,6:40,7:20,8:20}
Map(UInt16, UInt64) {1:20,2:20,3:40,4:40,5:40,6:40,7:20,8:20}
{1:10,2:10,3:10,4:10,5:10,6:10,7:10,8:10}
{1:10,2:10,3:10,4:10,5:10,6:10,7:10,8:10}
Map(UInt16, Float64) {1:10,2:10,3:10,4:10,5:10,6:10,7:10,8:10}
{1:2,2:2,3:4,4:4,5:4,6:4,7:2,8:2}
1 {1:10,2:10,3:20,4:10,5:10}
2 {4:10,5:10,6:20,7:10,8:10}
3 {1:10,2:10,3:10}
4 {3:10,4:10,5:10}
5 {4:10,5:10,6:20,7:10,8:10}
{'01234567-89ab-cdef-0123-456789abcdef':1}
{'1':'1'}
{'1':'1'}
{1:1}
{'1970-01-02':1}
{'1970-01-01 03:00:01':1}
{'a':1}
{'1':'2'}
{1:1}
{1:1}
{1:1}
{1:1}
{1:1,2:2,3:6,4:8,5:10,6:12,7:7,8:8}
{1:1,2:2,3:6,4:8,5:10,6:12,7:7,8:8}

View File

@ -0,0 +1,57 @@
-- Test for the -Map aggregate function combinator applied to Map columns.
SET send_logs_level = 'fatal';
SET allow_experimental_map_type = 1;
-- Fixture: a Log table with a Map(UInt16, UInt32) column; several rows share keys and share `a`.
DROP TABLE IF EXISTS map_comb;
CREATE TABLE map_comb(a int, statusMap Map(UInt16, UInt32)) ENGINE = Log;
INSERT INTO map_comb VALUES (1, map(1, 10, 2, 10, 3, 10)),(1, map(3, 10, 4, 10, 5, 10)),(2, map(4, 10, 5, 10, 6, 10)),(2, map(6, 10, 7, 10, 8, 10)),(3, map(1, 10, 2, 10, 3, 10)),(4, map(3, 10, 4, 10, 5, 10)),(5, map(4, 10, 5, 10, 6, 10)),(5, map(6, 10, 7, 10, 8, 10));
SELECT * FROM map_comb ORDER BY a;
-- Different nested functions over the whole table; toTypeName shows the resulting Map type.
SELECT toTypeName(res), sumMap(statusMap) as res FROM map_comb;
SELECT toTypeName(res), sumWithOverflowMap(statusMap) as res FROM map_comb;
-- -State / -Merge on top of the -Map combinator.
SELECT toTypeName(res), sumMapMerge(s) as res FROM (SELECT sumMapState(statusMap) AS s FROM map_comb);
SELECT minMap(statusMap) FROM map_comb;
SELECT maxMap(statusMap) FROM map_comb;
SELECT toTypeName(res), avgMap(statusMap) as res FROM map_comb;
SELECT countMap(statusMap) FROM map_comb;
-- Aggregation per group.
SELECT a, sumMap(statusMap) FROM map_comb GROUP BY a ORDER BY a;
DROP TABLE map_comb;
-- check different types
select minMap(val) from values ('val Map(UUID, Int32)',
(map('01234567-89ab-cdef-0123-456789abcdef', 1)),
(map('01234567-89ab-cdef-0123-456789abcdef', 2)));
select minMap(val) from values ('val Map(String, String)', (map('1', '1')), (map('1', '2')));
select minMap(val) from values ('val Map(FixedString(1), FixedString(1))', (map('1', '1')), (map('1', '2')));
select minMap(val) from values ('val Map(UInt64, UInt64)', (map(1, 1)), (map(1, 2)));
select minMap(val) from values ('val Map(Date, Int16)', (map(1, 1)), (map(1, 2)));
select minMap(val) from values ('val Map(DateTime(\'Europe/Moscow\'), Int32)', (map(1, 1)), (map(1, 2)));
select minMap(val) from values ('val Map(Enum16(\'a\'=1), Int16)', (map('a', 1)), (map('a', 2)));
select maxMap(val) from values ('val Map(String, String)', (map('1', '1')), (map('1', '2')));
select minMap(val) from values ('val Map(Int128, Int128)', (map(1, 1)), (map(1, 2)));
select minMap(val) from values ('val Map(Int256, Int256)', (map(1, 1)), (map(1, 2)));
select minMap(val) from values ('val Map(UInt128, UInt128)', (map(1, 1)), (map(1, 2)));
select minMap(val) from values ('val Map(UInt256, UInt256)', (map(1, 1)), (map(1, 2)));
-- A Map argument followed by extra arguments is expected to fail with server error 42.
select sumMap(map(1,2), 1, 2); -- { serverError 42 }
select sumMap(map(1,2), map(1,3)); -- { serverError 42 }
-- array and tuple arguments
-- Each of these calls is expected to fail with server error 43.
select avgMap([1,1,1], [2,2,2]); -- { serverError 43 }
select minMap((1,1)); -- { serverError 43 }
select minMap(([1,1,1],1)); -- { serverError 43 }
select minMap([1,1,1],1); -- { serverError 43 }
select minMap([1,1,1]); -- { serverError 43 }
select minMap(([1,1,1])); -- { serverError 43 }
-- Map with Decimal values.
DROP TABLE IF EXISTS sum_map_decimal;
CREATE TABLE sum_map_decimal(statusMap Map(UInt16,Decimal32(5))) ENGINE = Log;
INSERT INTO sum_map_decimal VALUES (map(1,'1.0',2,'2.0',3,'3.0')), (map(3,'3.0',4,'4.0',5,'5.0')), (map(4,'4.0',5,'5.0',6,'6.0')), (map(6,'6.0',7,'7.0',8,'8.0'));
SELECT sumMap(statusMap) FROM sum_map_decimal;
SELECT sumWithOverflowMap(statusMap) FROM sum_map_decimal;
DROP TABLE sum_map_decimal;