mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-09-20 08:40:50 +00:00
implement sumMapFiltered
This commit is contained in:
parent
be4eed19ae
commit
8ad1a55f3b
@ -12,10 +12,10 @@ namespace DB
|
|||||||
namespace
|
namespace
|
||||||
{
|
{
|
||||||
|
|
||||||
AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, const DataTypes & arguments, const Array & params)
|
using SumMapArgs = std::pair<const DataTypePtr &, DataTypes>;
|
||||||
{
|
|
||||||
assertNoParameters(name, params);
|
|
||||||
|
|
||||||
|
SumMapArgs parseArguments(const std::string & name, const DataTypes & arguments)
|
||||||
|
{
|
||||||
if (arguments.size() < 2)
|
if (arguments.size() < 2)
|
||||||
throw Exception("Aggregate function " + name + " requires at least two arguments of Array type.",
|
throw Exception("Aggregate function " + name + " requires at least two arguments of Array type.",
|
||||||
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||||
@ -25,6 +25,7 @@ AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, con
|
|||||||
throw Exception("First argument for function " + name + " must be an array.",
|
throw Exception("First argument for function " + name + " must be an array.",
|
||||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||||
|
|
||||||
|
|
||||||
const DataTypePtr & keys_type = array_type->getNestedType();
|
const DataTypePtr & keys_type = array_type->getNestedType();
|
||||||
|
|
||||||
DataTypes values_types;
|
DataTypes values_types;
|
||||||
@ -37,6 +38,15 @@ AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, con
|
|||||||
values_types.push_back(array_type->getNestedType());
|
values_types.push_back(array_type->getNestedType());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return {keys_type, std::move(values_types)};
|
||||||
|
}
|
||||||
|
|
||||||
|
AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, const DataTypes & arguments, const Array & params)
|
||||||
|
{
|
||||||
|
assertNoParameters(name, params);
|
||||||
|
|
||||||
|
auto [keys_type, values_types] = parseArguments(name, arguments);
|
||||||
|
|
||||||
AggregateFunctionPtr res(createWithNumericBasedType<AggregateFunctionSumMap>(*keys_type, keys_type, values_types));
|
AggregateFunctionPtr res(createWithNumericBasedType<AggregateFunctionSumMap>(*keys_type, keys_type, values_types));
|
||||||
if (!res)
|
if (!res)
|
||||||
res.reset(createWithDecimalType<AggregateFunctionSumMap>(*keys_type, keys_type, values_types));
|
res.reset(createWithDecimalType<AggregateFunctionSumMap>(*keys_type, keys_type, values_types));
|
||||||
@ -46,11 +56,30 @@ AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, con
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Factory for the parametric aggregate function sumMapFiltered(keys_to_keep)(keys, values...).
  *
  * name      - function name, used in error messages.
  * arguments - argument types; validated and split into key/value types by parseArguments.
  * params    - must contain exactly one element, the Array of keys to keep.
  *
  * Throws NUMBER_OF_ARGUMENTS_DOESNT_MATCH if the number of parameters is not one,
  * and ILLEGAL_TYPE_OF_ARGUMENT if neither type-dispatching helper produced an implementation
  * for the key type.
  */
AggregateFunctionPtr createAggregateFunctionSumMapFiltered(const std::string & name, const DataTypes & arguments, const Array & params)
{
    if (params.size() != 1)
        throw Exception("Aggregate function " + name + " requires exactly one parameter of Array type.",
            ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

    /// NOTE(review): safeGet presumably rejects a parameter that is not an Array - confirm against Field.
    Array keys_to_keep = params.front().safeGet<Array>();

    auto [keys_type, values_types] = parseArguments(name, arguments);

    /// Try the numeric-based key types first, then fall back to decimal key types.
    AggregateFunctionPtr res(createWithNumericBasedType<AggregateFunctionSumMapFiltered>(*keys_type, keys_type, values_types, keys_to_keep));
    if (!res)
        res.reset(createWithDecimalType<AggregateFunctionSumMapFiltered>(*keys_type, keys_type, values_types, keys_to_keep));
    if (!res)
        throw Exception("Illegal type of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

    return res;
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void registerAggregateFunctionSumMap(AggregateFunctionFactory & factory)
|
void registerAggregateFunctionSumMap(AggregateFunctionFactory & factory)
|
||||||
{
|
{
|
||||||
factory.registerFunction("sumMap", createAggregateFunctionSumMap);
|
factory.registerFunction("sumMap", createAggregateFunctionSumMap);
|
||||||
|
factory.registerFunction("sumMapFiltered", createAggregateFunctionSumMapFiltered);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -50,9 +50,9 @@ struct AggregateFunctionSumMapData
|
|||||||
* ([1,2,3,4,5,6,7,8,9,10],[10,10,45,20,35,20,15,30,20,20])
|
* ([1,2,3,4,5,6,7,8,9,10],[10,10,45,20,35,20,15,30,20,20])
|
||||||
*/
|
*/
|
||||||
|
|
||||||
template <typename T>
|
template <typename T, typename Derived>
|
||||||
class AggregateFunctionSumMap final : public IAggregateFunctionDataHelper<
|
class AggregateFunctionSumMapBase : public IAggregateFunctionDataHelper<
|
||||||
AggregateFunctionSumMapData<NearestFieldType<T>>, AggregateFunctionSumMap<T>>
|
AggregateFunctionSumMapData<NearestFieldType<T>>, Derived>
|
||||||
{
|
{
|
||||||
private:
|
private:
|
||||||
using ColVecType = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
|
using ColVecType = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
|
||||||
@ -61,7 +61,7 @@ private:
|
|||||||
DataTypes values_types;
|
DataTypes values_types;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
AggregateFunctionSumMap(const DataTypePtr & keys_type, const DataTypes & values_types)
|
AggregateFunctionSumMapBase(const DataTypePtr & keys_type, const DataTypes & values_types)
|
||||||
: keys_type(keys_type), values_types(values_types) {}
|
: keys_type(keys_type), values_types(values_types) {}
|
||||||
|
|
||||||
String getName() const override { return "sumMap"; }
|
String getName() const override { return "sumMap"; }
|
||||||
@ -109,6 +109,11 @@ public:
|
|||||||
array_column.getData().get(values_vec_offset + i, value);
|
array_column.getData().get(values_vec_offset + i, value);
|
||||||
const auto & key = keys_vec.getData()[keys_vec_offset + i];
|
const auto & key = keys_vec.getData()[keys_vec_offset + i];
|
||||||
|
|
||||||
|
if (!keepKey(key))
|
||||||
|
{
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
IteratorType it;
|
IteratorType it;
|
||||||
if constexpr (IsDecimalNumber<T>)
|
if constexpr (IsDecimalNumber<T>)
|
||||||
{
|
{
|
||||||
@ -253,6 +258,47 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
const char * getHeaderFilePath() const override { return __FILE__; }
|
const char * getHeaderFilePath() const override { return __FILE__; }
|
||||||
|
|
||||||
|
virtual bool keepKey(const T & key) const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
class AggregateFunctionSumMap final : public AggregateFunctionSumMapBase<T, AggregateFunctionSumMap<T>>
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
AggregateFunctionSumMap(const DataTypePtr & keys_type, DataTypes & values_types)
|
||||||
|
: AggregateFunctionSumMapBase<T, AggregateFunctionSumMap<T>>{keys_type, values_types}
|
||||||
|
{}
|
||||||
|
|
||||||
|
String getName() const override { return "sumMap"; }
|
||||||
|
|
||||||
|
bool keepKey(const T &) const override { return true; }
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
class AggregateFunctionSumMapFiltered final : public AggregateFunctionSumMapBase<T, AggregateFunctionSumMapFiltered<T>>
|
||||||
|
{
|
||||||
|
private:
|
||||||
|
std::vector<T> keys_to_keep;
|
||||||
|
|
||||||
|
public:
|
||||||
|
AggregateFunctionSumMapFiltered(const DataTypePtr & keys_type, const DataTypes & values_types, const Array & keys_to_keep)
|
||||||
|
: AggregateFunctionSumMapBase<T, AggregateFunctionSumMapFiltered<T>>{keys_type, values_types}
|
||||||
|
{
|
||||||
|
this->keys_to_keep.reserve(keys_to_keep.size());
|
||||||
|
for (const Field & f : keys_to_keep)
|
||||||
|
{
|
||||||
|
this->keys_to_keep.emplace_back(f.safeGet<NearestFieldType<T>>());
|
||||||
|
}
|
||||||
|
std::sort(begin(this->keys_to_keep), end(this->keys_to_keep));
|
||||||
|
}
|
||||||
|
|
||||||
|
String getName() const override { return "sumMapFiltered"; }
|
||||||
|
|
||||||
|
bool keepKey(const T & key) const override
|
||||||
|
{
|
||||||
|
return std::binary_search(begin(keys_to_keep), end(keys_to_keep), key);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -8,6 +8,8 @@
|
|||||||
2000-01-01 00:01:00 ([4,5,6,7,8],[10,10,20,10,10])
|
2000-01-01 00:01:00 ([4,5,6,7,8],[10,10,20,10,10])
|
||||||
2000-01-01 00:00:00 [1,2,3,4,5] [10,10,20,10,10]
|
2000-01-01 00:00:00 [1,2,3,4,5] [10,10,20,10,10]
|
||||||
2000-01-01 00:01:00 [4,5,6,7,8] [10,10,20,10,10]
|
2000-01-01 00:01:00 [4,5,6,7,8] [10,10,20,10,10]
|
||||||
|
([1],[10])
|
||||||
|
([1,4,8],[10,20,10])
|
||||||
([1],[1])
|
([1],[1])
|
||||||
([1],[1])
|
([1],[1])
|
||||||
(['a'],[1])
|
(['a'],[1])
|
||||||
|
@ -12,6 +12,9 @@ SELECT sumMapMerge(s) FROM (SELECT sumMapState(statusMap.status, statusMap.reque
|
|||||||
SELECT timeslot, sumMap(statusMap.status, statusMap.requests) FROM test.sum_map GROUP BY timeslot ORDER BY timeslot;
|
SELECT timeslot, sumMap(statusMap.status, statusMap.requests) FROM test.sum_map GROUP BY timeslot ORDER BY timeslot;
|
||||||
SELECT timeslot, sumMap(statusMap.status, statusMap.requests).1, sumMap(statusMap.status, statusMap.requests).2 FROM test.sum_map GROUP BY timeslot ORDER BY timeslot;
|
SELECT timeslot, sumMap(statusMap.status, statusMap.requests).1, sumMap(statusMap.status, statusMap.requests).2 FROM test.sum_map GROUP BY timeslot ORDER BY timeslot;
|
||||||
|
|
||||||
|
SELECT sumMapFiltered([1])(statusMap.status, statusMap.requests) FROM test.sum_map;
|
||||||
|
SELECT sumMapFiltered([1, 4, 8])(statusMap.status, statusMap.requests) FROM test.sum_map;
|
||||||
|
|
||||||
DROP TABLE test.sum_map;
|
DROP TABLE test.sum_map;
|
||||||
|
|
||||||
select sumMap(val, cnt) from ( SELECT [ CAST(1, 'UInt64') ] as val, [1] as cnt );
|
select sumMap(val, cnt) from ( SELECT [ CAST(1, 'UInt64') ] as val, [1] as cnt );
|
||||||
|
@ -155,3 +155,8 @@ Solution: Write in the GROUP BY query SearchPhrase HAVING uniqUpTo(4)(UserID) >=
|
|||||||
```
|
```
|
||||||
|
|
||||||
[Original article](https://clickhouse.yandex/docs/en/query_language/agg_functions/parametric_functions/) <!--hide-->
|
[Original article](https://clickhouse.yandex/docs/en/query_language/agg_functions/parametric_functions/) <!--hide-->
|
||||||
|
|
||||||
|
|
||||||
|
## sumMapFiltered(keys_to_keep)(keys, values)
|
||||||
|
|
||||||
|
Same behavior as [sumMap](reference.md#sumMap), except that an array of keys is passed as a parameter and only those keys are included in the result. This can be especially useful when working with a high cardinality of keys.
|
||||||
|
Loading…
Reference in New Issue
Block a user