mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-21 15:12:02 +00:00
implement sumMapFiltered
This commit is contained in:
parent
be4eed19ae
commit
8ad1a55f3b
@ -12,10 +12,10 @@ namespace DB
|
||||
namespace
|
||||
{
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, const DataTypes & arguments, const Array & params)
|
||||
{
|
||||
assertNoParameters(name, params);
|
||||
using SumMapArgs = std::pair<const DataTypePtr &, DataTypes>;
|
||||
|
||||
SumMapArgs parseArguments(const std::string & name, const DataTypes & arguments)
|
||||
{
|
||||
if (arguments.size() < 2)
|
||||
throw Exception("Aggregate function " + name + " requires at least two arguments of Array type.",
|
||||
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||
@ -25,6 +25,7 @@ AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, con
|
||||
throw Exception("First argument for function " + name + " must be an array.",
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
|
||||
const DataTypePtr & keys_type = array_type->getNestedType();
|
||||
|
||||
DataTypes values_types;
|
||||
@ -37,6 +38,15 @@ AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, con
|
||||
values_types.push_back(array_type->getNestedType());
|
||||
}
|
||||
|
||||
return {keys_type, std::move(values_types)};
|
||||
}
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, const DataTypes & arguments, const Array & params)
|
||||
{
|
||||
assertNoParameters(name, params);
|
||||
|
||||
auto [keys_type, values_types] = parseArguments(name, arguments);
|
||||
|
||||
AggregateFunctionPtr res(createWithNumericBasedType<AggregateFunctionSumMap>(*keys_type, keys_type, values_types));
|
||||
if (!res)
|
||||
res.reset(createWithDecimalType<AggregateFunctionSumMap>(*keys_type, keys_type, values_types));
|
||||
@ -46,11 +56,30 @@ AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, con
|
||||
return res;
|
||||
}
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionSumMapFiltered(const std::string & name, const DataTypes & arguments, const Array & params)
|
||||
{
|
||||
if (params.size() != 1)
|
||||
throw Exception("Aggregate function " + name + "requires exactly one parameter of Array type.",
|
||||
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||
|
||||
Array keys_to_keep = params.front().safeGet<Array>();
|
||||
|
||||
auto [keys_type, values_types] = parseArguments(name, arguments);
|
||||
|
||||
AggregateFunctionPtr res(createWithNumericBasedType<AggregateFunctionSumMapFiltered>(*keys_type, keys_type, values_types, keys_to_keep));
|
||||
if (!res)
|
||||
res.reset(createWithDecimalType<AggregateFunctionSumMapFiltered>(*keys_type, keys_type, values_types, keys_to_keep));
|
||||
if (!res)
|
||||
throw Exception("Illegal type of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
void registerAggregateFunctionSumMap(AggregateFunctionFactory & factory)
|
||||
{
|
||||
factory.registerFunction("sumMap", createAggregateFunctionSumMap);
|
||||
factory.registerFunction("sumMapFiltered", createAggregateFunctionSumMapFiltered);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -50,9 +50,9 @@ struct AggregateFunctionSumMapData
|
||||
* ([1,2,3,4,5,6,7,8,9,10],[10,10,45,20,35,20,15,30,20,20])
|
||||
*/
|
||||
|
||||
template <typename T>
|
||||
class AggregateFunctionSumMap final : public IAggregateFunctionDataHelper<
|
||||
AggregateFunctionSumMapData<NearestFieldType<T>>, AggregateFunctionSumMap<T>>
|
||||
template <typename T, typename Derived>
|
||||
class AggregateFunctionSumMapBase : public IAggregateFunctionDataHelper<
|
||||
AggregateFunctionSumMapData<NearestFieldType<T>>, Derived>
|
||||
{
|
||||
private:
|
||||
using ColVecType = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
|
||||
@ -61,7 +61,7 @@ private:
|
||||
DataTypes values_types;
|
||||
|
||||
public:
|
||||
AggregateFunctionSumMap(const DataTypePtr & keys_type, const DataTypes & values_types)
|
||||
AggregateFunctionSumMapBase(const DataTypePtr & keys_type, const DataTypes & values_types)
|
||||
: keys_type(keys_type), values_types(values_types) {}
|
||||
|
||||
String getName() const override { return "sumMap"; }
|
||||
@ -109,6 +109,11 @@ public:
|
||||
array_column.getData().get(values_vec_offset + i, value);
|
||||
const auto & key = keys_vec.getData()[keys_vec_offset + i];
|
||||
|
||||
if (!keepKey(key))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
IteratorType it;
|
||||
if constexpr (IsDecimalNumber<T>)
|
||||
{
|
||||
@ -253,6 +258,47 @@ public:
|
||||
}
|
||||
|
||||
const char * getHeaderFilePath() const override { return __FILE__; }
|
||||
|
||||
virtual bool keepKey(const T & key) const = 0;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class AggregateFunctionSumMap final : public AggregateFunctionSumMapBase<T, AggregateFunctionSumMap<T>>
|
||||
{
|
||||
public:
|
||||
AggregateFunctionSumMap(const DataTypePtr & keys_type, DataTypes & values_types)
|
||||
: AggregateFunctionSumMapBase<T, AggregateFunctionSumMap<T>>{keys_type, values_types}
|
||||
{}
|
||||
|
||||
String getName() const override { return "sumMap"; }
|
||||
|
||||
bool keepKey(const T &) const override { return true; }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class AggregateFunctionSumMapFiltered final : public AggregateFunctionSumMapBase<T, AggregateFunctionSumMapFiltered<T>>
|
||||
{
|
||||
private:
|
||||
std::vector<T> keys_to_keep;
|
||||
|
||||
public:
|
||||
AggregateFunctionSumMapFiltered(const DataTypePtr & keys_type, const DataTypes & values_types, const Array & keys_to_keep)
|
||||
: AggregateFunctionSumMapBase<T, AggregateFunctionSumMapFiltered<T>>{keys_type, values_types}
|
||||
{
|
||||
this->keys_to_keep.reserve(keys_to_keep.size());
|
||||
for (const Field & f : keys_to_keep)
|
||||
{
|
||||
this->keys_to_keep.emplace_back(f.safeGet<NearestFieldType<T>>());
|
||||
}
|
||||
std::sort(begin(this->keys_to_keep), end(this->keys_to_keep));
|
||||
}
|
||||
|
||||
String getName() const override { return "sumMapFiltered"; }
|
||||
|
||||
bool keepKey(const T & key) const override
|
||||
{
|
||||
return std::binary_search(begin(keys_to_keep), end(keys_to_keep), key);
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
|
@ -8,6 +8,8 @@
|
||||
2000-01-01 00:01:00 ([4,5,6,7,8],[10,10,20,10,10])
|
||||
2000-01-01 00:00:00 [1,2,3,4,5] [10,10,20,10,10]
|
||||
2000-01-01 00:01:00 [4,5,6,7,8] [10,10,20,10,10]
|
||||
([1],[10])
|
||||
([1, 4, 8], [10, 20, 10])
|
||||
([1],[1])
|
||||
([1],[1])
|
||||
(['a'],[1])
|
||||
|
@ -12,6 +12,9 @@ SELECT sumMapMerge(s) FROM (SELECT sumMapState(statusMap.status, statusMap.reque
|
||||
SELECT timeslot, sumMap(statusMap.status, statusMap.requests) FROM test.sum_map GROUP BY timeslot ORDER BY timeslot;
|
||||
SELECT timeslot, sumMap(statusMap.status, statusMap.requests).1, sumMap(statusMap.status, statusMap.requests).2 FROM test.sum_map GROUP BY timeslot ORDER BY timeslot;
|
||||
|
||||
SELECT sumMapFiltered([1])(statusMap.status, statusMap.requests) FROM test.sum_map;
|
||||
SELECT sumMapFiltered([1, 4, 8])(statusMap.status, statusMap.requests) FROM test.sum_map;
|
||||
|
||||
DROP TABLE test.sum_map;
|
||||
|
||||
select sumMap(val, cnt) from ( SELECT [ CAST(1, 'UInt64') ] as val, [1] as cnt );
|
||||
|
@ -155,3 +155,8 @@ Solution: Write in the GROUP BY query SearchPhrase HAVING uniqUpTo(4)(UserID) >=
|
||||
```
|
||||
|
||||
[Original article](https://clickhouse.yandex/docs/en/query_language/agg_functions/parametric_functions/) <!--hide-->
|
||||
|
||||
|
||||
## sumMapFiltered(keys_to_keep)(keys, values)
|
||||
|
||||
Same behavior as [sumMap](reference.md#sumMap) except that an array of keys is passed as a parameter. This can be especially useful when working with a high cardinality of keys.
|
||||
|
Loading…
Reference in New Issue
Block a user