implement sumMapFiltered

This commit is contained in:
Léo Ercolanelli 2019-01-22 17:47:43 +01:00
parent be4eed19ae
commit 8ad1a55f3b
5 changed files with 92 additions and 7 deletions

View File

@ -12,10 +12,10 @@ namespace DB
namespace namespace
{ {
AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, const DataTypes & arguments, const Array & params) using SumMapArgs = std::pair<const DataTypePtr &, DataTypes>;
{
assertNoParameters(name, params);
SumMapArgs parseArguments(const std::string & name, const DataTypes & arguments)
{
if (arguments.size() < 2) if (arguments.size() < 2)
throw Exception("Aggregate function " + name + " requires at least two arguments of Array type.", throw Exception("Aggregate function " + name + " requires at least two arguments of Array type.",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
@ -25,6 +25,7 @@ AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, con
throw Exception("First argument for function " + name + " must be an array.", throw Exception("First argument for function " + name + " must be an array.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
const DataTypePtr & keys_type = array_type->getNestedType(); const DataTypePtr & keys_type = array_type->getNestedType();
DataTypes values_types; DataTypes values_types;
@ -37,6 +38,15 @@ AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, con
values_types.push_back(array_type->getNestedType()); values_types.push_back(array_type->getNestedType());
} }
return {keys_type, std::move(values_types)};
}
AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, const DataTypes & arguments, const Array & params)
{
assertNoParameters(name, params);
auto [keys_type, values_types] = parseArguments(name, arguments);
AggregateFunctionPtr res(createWithNumericBasedType<AggregateFunctionSumMap>(*keys_type, keys_type, values_types)); AggregateFunctionPtr res(createWithNumericBasedType<AggregateFunctionSumMap>(*keys_type, keys_type, values_types));
if (!res) if (!res)
res.reset(createWithDecimalType<AggregateFunctionSumMap>(*keys_type, keys_type, values_types)); res.reset(createWithDecimalType<AggregateFunctionSumMap>(*keys_type, keys_type, values_types));
@ -46,11 +56,30 @@ AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, con
return res; return res;
} }
/** Factory for the parametric aggregate function sumMapFiltered(keys_to_keep)(keys, values...).
  *
  * name      - function name, used only in error messages.
  * arguments - argument types, validated and split into key/value types by parseArguments.
  * params    - must contain exactly one element, the Array of keys to keep.
  *
  * Throws NUMBER_OF_ARGUMENTS_DOESNT_MATCH if the number of parameters is not one,
  * and ILLEGAL_TYPE_OF_ARGUMENT if no implementation exists for the key type.
  */
AggregateFunctionPtr createAggregateFunctionSumMapFiltered(const std::string & name, const DataTypes & arguments, const Array & params)
{
    if (params.size() != 1)
        throw Exception("Aggregate function " + name + " requires exactly one parameter of Array type.",
            ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

    /// NOTE(review): safeGet presumably throws when the parameter is not an Array - confirm against Field.
    Array keys_to_keep = params.front().safeGet<Array>();

    auto [keys_type, values_types] = parseArguments(name, arguments);

    /// Try key types in the same order as sumMap: numeric-based first, then decimal.
    AggregateFunctionPtr res(createWithNumericBasedType<AggregateFunctionSumMapFiltered>(*keys_type, keys_type, values_types, keys_to_keep));
    if (!res)
        res.reset(createWithDecimalType<AggregateFunctionSumMapFiltered>(*keys_type, keys_type, values_types, keys_to_keep));
    if (!res)
        throw Exception("Illegal type of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

    return res;
}
} }
void registerAggregateFunctionSumMap(AggregateFunctionFactory & factory) void registerAggregateFunctionSumMap(AggregateFunctionFactory & factory)
{ {
factory.registerFunction("sumMap", createAggregateFunctionSumMap); factory.registerFunction("sumMap", createAggregateFunctionSumMap);
factory.registerFunction("sumMapFiltered", createAggregateFunctionSumMapFiltered);
} }
} }

View File

@ -50,9 +50,9 @@ struct AggregateFunctionSumMapData
* ([1,2,3,4,5,6,7,8,9,10],[10,10,45,20,35,20,15,30,20,20]) * ([1,2,3,4,5,6,7,8,9,10],[10,10,45,20,35,20,15,30,20,20])
*/ */
template <typename T> template <typename T, typename Derived>
class AggregateFunctionSumMap final : public IAggregateFunctionDataHelper< class AggregateFunctionSumMapBase : public IAggregateFunctionDataHelper<
AggregateFunctionSumMapData<NearestFieldType<T>>, AggregateFunctionSumMap<T>> AggregateFunctionSumMapData<NearestFieldType<T>>, Derived>
{ {
private: private:
using ColVecType = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>; using ColVecType = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
@ -61,7 +61,7 @@ private:
DataTypes values_types; DataTypes values_types;
public: public:
AggregateFunctionSumMap(const DataTypePtr & keys_type, const DataTypes & values_types) AggregateFunctionSumMapBase(const DataTypePtr & keys_type, const DataTypes & values_types)
: keys_type(keys_type), values_types(values_types) {} : keys_type(keys_type), values_types(values_types) {}
String getName() const override { return "sumMap"; } String getName() const override { return "sumMap"; }
@ -109,6 +109,11 @@ public:
array_column.getData().get(values_vec_offset + i, value); array_column.getData().get(values_vec_offset + i, value);
const auto & key = keys_vec.getData()[keys_vec_offset + i]; const auto & key = keys_vec.getData()[keys_vec_offset + i];
if (!keepKey(key))
{
continue;
}
IteratorType it; IteratorType it;
if constexpr (IsDecimalNumber<T>) if constexpr (IsDecimalNumber<T>)
{ {
@ -253,6 +258,47 @@ public:
} }
const char * getHeaderFilePath() const override { return __FILE__; } const char * getHeaderFilePath() const override { return __FILE__; }
virtual bool keepKey(const T & key) const = 0;
};
/** Plain sumMap: the unfiltered specialization of AggregateFunctionSumMapBase.
  * Every key is accepted, so all (key, value) pairs take part in the aggregation.
  */
template <typename T>
class AggregateFunctionSumMap final : public AggregateFunctionSumMapBase<T, AggregateFunctionSumMap<T>>
{
public:
    /// values_types is taken by const reference, matching the base class and
    /// AggregateFunctionSumMapFiltered; the constructor only forwards it.
    AggregateFunctionSumMap(const DataTypePtr & keys_type, const DataTypes & values_types)
        : AggregateFunctionSumMapBase<T, AggregateFunctionSumMap<T>>{keys_type, values_types}
    {}

    String getName() const override { return "sumMap"; }

    /// No filtering: keep every key.
    bool keepKey(const T &) const override { return true; }
};
/** sumMapFiltered: a sumMap variant that aggregates only the keys listed
  * in the aggregate function parameter; all other keys are skipped.
  */
template <typename T>
class AggregateFunctionSumMapFiltered final : public AggregateFunctionSumMapBase<T, AggregateFunctionSumMapFiltered<T>>
{
private:
    /// Accepted keys, stored sorted so that keepKey can use binary search.
    std::vector<T> keys_to_keep;

public:
    /// Converts each Field of the parameter array to the key type T and sorts the result.
    AggregateFunctionSumMapFiltered(const DataTypePtr & keys_type, const DataTypes & values_types, const Array & keys_to_keep)
        : AggregateFunctionSumMapBase<T, AggregateFunctionSumMapFiltered<T>>{keys_type, values_types}
    {
        /// The parameter shadows the member, so the member is reached through `this`.
        auto & accepted = this->keys_to_keep;
        accepted.reserve(keys_to_keep.size());

        for (const Field & field : keys_to_keep)
            accepted.emplace_back(field.safeGet<NearestFieldType<T>>());

        std::sort(accepted.begin(), accepted.end());
    }

    String getName() const override { return "sumMapFiltered"; }

    /// True when `key` is one of the keys passed as the function parameter.
    bool keepKey(const T & key) const override
    {
        return std::binary_search(keys_to_keep.begin(), keys_to_keep.end(), key);
    }
};
} }

View File

@ -8,6 +8,8 @@
2000-01-01 00:01:00 ([4,5,6,7,8],[10,10,20,10,10]) 2000-01-01 00:01:00 ([4,5,6,7,8],[10,10,20,10,10])
2000-01-01 00:00:00 [1,2,3,4,5] [10,10,20,10,10] 2000-01-01 00:00:00 [1,2,3,4,5] [10,10,20,10,10]
2000-01-01 00:01:00 [4,5,6,7,8] [10,10,20,10,10] 2000-01-01 00:01:00 [4,5,6,7,8] [10,10,20,10,10]
([1],[10])
([1,4,8],[10,20,10])
([1],[1]) ([1],[1])
([1],[1]) ([1],[1])
(['a'],[1]) (['a'],[1])

View File

@ -12,6 +12,9 @@ SELECT sumMapMerge(s) FROM (SELECT sumMapState(statusMap.status, statusMap.reque
SELECT timeslot, sumMap(statusMap.status, statusMap.requests) FROM test.sum_map GROUP BY timeslot ORDER BY timeslot; SELECT timeslot, sumMap(statusMap.status, statusMap.requests) FROM test.sum_map GROUP BY timeslot ORDER BY timeslot;
SELECT timeslot, sumMap(statusMap.status, statusMap.requests).1, sumMap(statusMap.status, statusMap.requests).2 FROM test.sum_map GROUP BY timeslot ORDER BY timeslot; SELECT timeslot, sumMap(statusMap.status, statusMap.requests).1, sumMap(statusMap.status, statusMap.requests).2 FROM test.sum_map GROUP BY timeslot ORDER BY timeslot;
SELECT sumMapFiltered([1])(statusMap.status, statusMap.requests) FROM test.sum_map;
SELECT sumMapFiltered([1, 4, 8])(statusMap.status, statusMap.requests) FROM test.sum_map;
DROP TABLE test.sum_map; DROP TABLE test.sum_map;
select sumMap(val, cnt) from ( SELECT [ CAST(1, 'UInt64') ] as val, [1] as cnt ); select sumMap(val, cnt) from ( SELECT [ CAST(1, 'UInt64') ] as val, [1] as cnt );

View File

@ -155,3 +155,8 @@ Solution: Write in the GROUP BY query SearchPhrase HAVING uniqUpTo(4)(UserID) >=
``` ```
[Original article](https://clickhouse.yandex/docs/en/query_language/agg_functions/parametric_functions/) <!--hide--> [Original article](https://clickhouse.yandex/docs/en/query_language/agg_functions/parametric_functions/) <!--hide-->
## sumMapFiltered(keys_to_keep)(keys, values)
Same behavior as [sumMap](reference.md#sumMap), except that an array of keys is passed as a parameter and only those keys are included in the result. This can be especially useful when working with a high cardinality of keys.