mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-09-20 00:30:49 +00:00
Merge pull request #4129 from ercolanelli-leo/sumMapFiltered
implement sumMapFiltered
This commit is contained in:
commit
fbefc99fb1
@ -12,10 +12,10 @@ namespace DB
|
||||
namespace
|
||||
{
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, const DataTypes & arguments, const Array & params)
|
||||
{
|
||||
assertNoParameters(name, params);
|
||||
using SumMapArgs = std::pair<DataTypePtr, DataTypes>;
|
||||
|
||||
SumMapArgs parseArguments(const std::string & name, const DataTypes & arguments)
|
||||
{
|
||||
if (arguments.size() < 2)
|
||||
throw Exception("Aggregate function " + name + " requires at least two arguments of Array type.",
|
||||
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||
@ -25,9 +25,11 @@ AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, con
|
||||
throw Exception("First argument for function " + name + " must be an array.",
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
const DataTypePtr & keys_type = array_type->getNestedType();
|
||||
|
||||
DataTypePtr keys_type = array_type->getNestedType();
|
||||
|
||||
DataTypes values_types;
|
||||
values_types.reserve(arguments.size() - 1);
|
||||
for (size_t i = 1; i < arguments.size(); ++i)
|
||||
{
|
||||
array_type = checkAndGetDataType<DataTypeArray>(arguments[i].get());
|
||||
@ -37,6 +39,15 @@ AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, con
|
||||
values_types.push_back(array_type->getNestedType());
|
||||
}
|
||||
|
||||
return {std::move(keys_type), std::move(values_types)};
|
||||
}
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, const DataTypes & arguments, const Array & params)
|
||||
{
|
||||
assertNoParameters(name, params);
|
||||
|
||||
auto [keys_type, values_types] = parseArguments(name, arguments);
|
||||
|
||||
AggregateFunctionPtr res(createWithNumericBasedType<AggregateFunctionSumMap>(*keys_type, keys_type, values_types));
|
||||
if (!res)
|
||||
res.reset(createWithDecimalType<AggregateFunctionSumMap>(*keys_type, keys_type, values_types));
|
||||
@ -46,11 +57,33 @@ AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, con
|
||||
return res;
|
||||
}
|
||||
|
||||
/** Factory for the parametric aggregate function `sumMapFiltered(keys_to_keep)(keys, values...)`.
  *
  * name      - function name, used only in error messages.
  * arguments - argument types, validated and split by parseArguments().
  * params    - must hold exactly one element, and that element must be an Array.
  *
  * Throws NUMBER_OF_ARGUMENTS_DOESNT_MATCH when the parameter count is not one and
  * ILLEGAL_TYPE_OF_ARGUMENT when the parameter is not an Array or when neither the
  * numeric-based nor the decimal creator produces a function for the key type.
  */
AggregateFunctionPtr createAggregateFunctionSumMapFiltered(const std::string & name, const DataTypes & arguments, const Array & params)
{
    if (params.size() != 1)
        throw Exception("Aggregate function " + name + " requires exactly one parameter of Array type.",
            ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

    Array allowed_keys;
    const bool parameter_is_array = params.front().tryGet<Array>(allowed_keys);
    if (!parameter_is_array)
        throw Exception("Aggregate function " + name + " requires an Array as parameter.",
            ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

    auto [key_type, value_types] = parseArguments(name, arguments);

    /// Try the numeric-based creator first.
    AggregateFunctionPtr function(
        createWithNumericBasedType<AggregateFunctionSumMapFiltered>(*key_type, key_type, value_types, allowed_keys));
    if (function)
        return function;

    /// Fall back to the decimal creator when the first one returned nothing.
    function.reset(
        createWithDecimalType<AggregateFunctionSumMapFiltered>(*key_type, key_type, value_types, allowed_keys));
    if (function)
        return function;

    throw Exception("Illegal type of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
|
||||
}
|
||||
|
||||
/// Registers the two creators under the names `sumMap` and `sumMapFiltered` in the given factory.
void registerAggregateFunctionSumMap(AggregateFunctionFactory & factory)
{
    factory.registerFunction("sumMap", createAggregateFunctionSumMap);
    factory.registerFunction("sumMapFiltered", createAggregateFunctionSumMapFiltered);
}
|
||||
|
||||
}
|
||||
|
@ -50,9 +50,9 @@ struct AggregateFunctionSumMapData
|
||||
* ([1,2,3,4,5,6,7,8,9,10],[10,10,45,20,35,20,15,30,20,20])
|
||||
*/
|
||||
|
||||
template <typename T>
|
||||
class AggregateFunctionSumMap final : public IAggregateFunctionDataHelper<
|
||||
AggregateFunctionSumMapData<NearestFieldType<T>>, AggregateFunctionSumMap<T>>
|
||||
template <typename T, typename Derived>
|
||||
class AggregateFunctionSumMapBase : public IAggregateFunctionDataHelper<
|
||||
AggregateFunctionSumMapData<NearestFieldType<T>>, Derived>
|
||||
{
|
||||
private:
|
||||
using ColVecType = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
|
||||
@ -61,7 +61,7 @@ private:
|
||||
DataTypes values_types;
|
||||
|
||||
public:
|
||||
AggregateFunctionSumMap(const DataTypePtr & keys_type, const DataTypes & values_types)
|
||||
AggregateFunctionSumMapBase(const DataTypePtr & keys_type, const DataTypes & values_types)
|
||||
: keys_type(keys_type), values_types(values_types) {}
|
||||
|
||||
String getName() const override { return "sumMap"; }
|
||||
@ -109,6 +109,11 @@ public:
|
||||
array_column.getData().get(values_vec_offset + i, value);
|
||||
const auto & key = keys_vec.getData()[keys_vec_offset + i];
|
||||
|
||||
if (!keepKey(key))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
IteratorType it;
|
||||
if constexpr (IsDecimalNumber<T>)
|
||||
{
|
||||
@ -253,6 +258,43 @@ public:
|
||||
}
|
||||
|
||||
const char * getHeaderFilePath() const override { return __FILE__; }
|
||||
|
||||
/// Static (CRTP) dispatch: casts this object to `Derived` and asks its `keepKey`
/// whether `key` should be kept. No virtual call is involved.
/// NOTE: `Derived` must declare its own `keepKey`; otherwise this call resolves
/// back to this very function and recurses forever.
bool keepKey(const T & key) const { return static_cast<const Derived &>(*this).keepKey(key); }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class AggregateFunctionSumMap final : public AggregateFunctionSumMapBase<T, AggregateFunctionSumMap<T>>
|
||||
{
|
||||
public:
|
||||
AggregateFunctionSumMap(const DataTypePtr & keys_type, DataTypes & values_types)
|
||||
: AggregateFunctionSumMapBase<T, AggregateFunctionSumMap<T>>{keys_type, values_types}
|
||||
{}
|
||||
|
||||
String getName() const override { return "sumMap"; }
|
||||
|
||||
bool keepKey(const T &) const { return true; }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class AggregateFunctionSumMapFiltered final : public AggregateFunctionSumMapBase<T, AggregateFunctionSumMapFiltered<T>>
|
||||
{
|
||||
private:
|
||||
std::unordered_set<T> keys_to_keep;
|
||||
|
||||
public:
|
||||
AggregateFunctionSumMapFiltered(const DataTypePtr & keys_type, const DataTypes & values_types, const Array & keys_to_keep_)
|
||||
: AggregateFunctionSumMapBase<T, AggregateFunctionSumMapFiltered<T>>{keys_type, values_types}
|
||||
{
|
||||
keys_to_keep.reserve(keys_to_keep_.size());
|
||||
for (const Field & f : keys_to_keep_)
|
||||
{
|
||||
keys_to_keep.emplace(f.safeGet<NearestFieldType<T>>());
|
||||
}
|
||||
}
|
||||
|
||||
String getName() const override { return "sumMapFiltered"; }
|
||||
|
||||
bool keepKey(const T & key) const { return keys_to_keep.count(key); }
|
||||
};
|
||||
|
||||
}
|
||||
|
@ -166,3 +166,20 @@ template <> constexpr bool IsDecimalNumber<Decimal64> = true;
|
||||
template <> constexpr bool IsDecimalNumber<Decimal128> = true;
|
||||
|
||||
}
|
||||
|
||||
/// Specialization of `std::hash` for the Decimal<T> types.
|
||||
namespace std
|
||||
{
|
||||
template <typename T>
|
||||
struct hash<DB::Decimal<T>> { size_t operator()(const DB::Decimal<T> & x) const { return hash<T>()(x.value); } };
|
||||
|
||||
template <>
|
||||
struct hash<DB::Decimal128>
|
||||
{
|
||||
size_t operator()(const DB::Decimal128 & x) const
|
||||
{
|
||||
return std::hash<DB::Int64>()(x.value >> 64)
|
||||
^ std::hash<DB::Int64>()(x.value & std::numeric_limits<DB::UInt64>::max());
|
||||
}
|
||||
};
|
||||
}
|
||||
|
@ -8,6 +8,8 @@
|
||||
2000-01-01 00:01:00 ([4,5,6,7,8],[10,10,20,10,10])
|
||||
2000-01-01 00:00:00 [1,2,3,4,5] [10,10,20,10,10]
|
||||
2000-01-01 00:01:00 [4,5,6,7,8] [10,10,20,10,10]
|
||||
([1],[10])
|
||||
([1,4,8],[10,20,10])
|
||||
([1],[1])
|
||||
([1],[1])
|
||||
(['a'],[1])
|
||||
|
@ -12,6 +12,9 @@ SELECT sumMapMerge(s) FROM (SELECT sumMapState(statusMap.status, statusMap.reque
|
||||
SELECT timeslot, sumMap(statusMap.status, statusMap.requests) FROM test.sum_map GROUP BY timeslot ORDER BY timeslot;
|
||||
SELECT timeslot, sumMap(statusMap.status, statusMap.requests).1, sumMap(statusMap.status, statusMap.requests).2 FROM test.sum_map GROUP BY timeslot ORDER BY timeslot;
|
||||
|
||||
SELECT sumMapFiltered([1])(statusMap.status, statusMap.requests) FROM test.sum_map;
|
||||
SELECT sumMapFiltered([1, 4, 8])(statusMap.status, statusMap.requests) FROM test.sum_map;
|
||||
|
||||
DROP TABLE test.sum_map;
|
||||
|
||||
select sumMap(val, cnt) from ( SELECT [ CAST(1, 'UInt64') ] as val, [1] as cnt );
|
||||
|
@ -155,3 +155,8 @@ Solution: Write in the GROUP BY query SearchPhrase HAVING uniqUpTo(4)(UserID) >=
|
||||
```
|
||||
|
||||
[Original article](https://clickhouse.yandex/docs/en/query_language/agg_functions/parametric_functions/) <!--hide-->
|
||||
|
||||
|
||||
## sumMapFiltered(keys_to_keep)(keys, values)
|
||||
|
||||
Same behavior as [sumMap](reference.md#sumMap), except that an array of keys to keep is passed as a parameter and only those keys are summed and returned. This is especially useful when the keys have high cardinality and only a few of them are of interest.
|
||||
|
Loading…
Reference in New Issue
Block a user