Merge pull request #4129 from ercolanelli-leo/sumMapFiltered

implement sumMapFiltered
This commit is contained in:
alexey-milovidov 2019-01-24 21:27:03 +03:00 committed by GitHub
commit fbefc99fb1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 110 additions and 8 deletions

View File

@ -12,10 +12,10 @@ namespace DB
namespace
{
AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, const DataTypes & arguments, const Array & params)
{
assertNoParameters(name, params);
using SumMapArgs = std::pair<DataTypePtr, DataTypes>;
SumMapArgs parseArguments(const std::string & name, const DataTypes & arguments)
{
if (arguments.size() < 2)
throw Exception("Aggregate function " + name + " requires at least two arguments of Array type.",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
@ -25,9 +25,11 @@ AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, con
throw Exception("First argument for function " + name + " must be an array.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
const DataTypePtr & keys_type = array_type->getNestedType();
DataTypePtr keys_type = array_type->getNestedType();
DataTypes values_types;
values_types.reserve(arguments.size() - 1);
for (size_t i = 1; i < arguments.size(); ++i)
{
array_type = checkAndGetDataType<DataTypeArray>(arguments[i].get());
@ -37,6 +39,15 @@ AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, con
values_types.push_back(array_type->getNestedType());
}
return {std::move(keys_type), std::move(values_types)};
}
AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, const DataTypes & arguments, const Array & params)
{
assertNoParameters(name, params);
auto [keys_type, values_types] = parseArguments(name, arguments);
AggregateFunctionPtr res(createWithNumericBasedType<AggregateFunctionSumMap>(*keys_type, keys_type, values_types));
if (!res)
res.reset(createWithDecimalType<AggregateFunctionSumMap>(*keys_type, keys_type, values_types));
@ -46,11 +57,33 @@ AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, con
return res;
}
/// Factory callback for `sumMapFiltered`: builds the aggregate function from its single
/// Array parameter (the keys to keep) and its arguments.
/// Throws NUMBER_OF_ARGUMENTS_DOESNT_MATCH unless exactly one parameter is given, and
/// ILLEGAL_TYPE_OF_ARGUMENT when that parameter is not an Array or when neither the
/// numeric-based nor the decimal creation helper yields an instance for the keys type.
AggregateFunctionPtr createAggregateFunctionSumMapFiltered(const std::string & name, const DataTypes & arguments, const Array & params)
{
    if (params.size() != 1)
        throw Exception("Aggregate function " + name + " requires exactly one parameter of Array type.",
            ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

    /// The only parameter carries the list of keys the aggregation should retain.
    const auto & filter_param = params.front();
    Array kept_keys;
    if (!filter_param.tryGet<Array>(kept_keys))
        throw Exception("Aggregate function " + name + " requires an Array as parameter.",
            ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

    /// Arguments are parsed by the same helper that sumMap uses.
    auto [key_type, value_types] = parseArguments(name, arguments);

    /// Try the numeric-based key types first, then fall back to the decimal ones.
    AggregateFunctionPtr function(
        createWithNumericBasedType<AggregateFunctionSumMapFiltered>(*key_type, key_type, value_types, kept_keys));
    if (!function)
        function.reset(
            createWithDecimalType<AggregateFunctionSumMapFiltered>(*key_type, key_type, value_types, kept_keys));

    if (function)
        return function;

    throw Exception("Illegal type of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
}
/// Registers the two map-summing aggregate functions, `sumMap` and `sumMapFiltered`,
/// in the given factory under those exact names.
void registerAggregateFunctionSumMap(AggregateFunctionFactory & factory)
{
    /// Plain variant: every key is aggregated.
    factory.registerFunction("sumMap", createAggregateFunctionSumMap);

    /// Parametric variant: takes an Array parameter listing the keys to keep.
    factory.registerFunction("sumMapFiltered", createAggregateFunctionSumMapFiltered);
}
}

View File

@ -50,9 +50,9 @@ struct AggregateFunctionSumMapData
* ([1,2,3,4,5,6,7,8,9,10],[10,10,45,20,35,20,15,30,20,20])
*/
template <typename T>
class AggregateFunctionSumMap final : public IAggregateFunctionDataHelper<
AggregateFunctionSumMapData<NearestFieldType<T>>, AggregateFunctionSumMap<T>>
template <typename T, typename Derived>
class AggregateFunctionSumMapBase : public IAggregateFunctionDataHelper<
AggregateFunctionSumMapData<NearestFieldType<T>>, Derived>
{
private:
using ColVecType = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
@ -61,7 +61,7 @@ private:
DataTypes values_types;
public:
AggregateFunctionSumMap(const DataTypePtr & keys_type, const DataTypes & values_types)
AggregateFunctionSumMapBase(const DataTypePtr & keys_type, const DataTypes & values_types)
: keys_type(keys_type), values_types(values_types) {}
String getName() const override { return "sumMap"; }
@ -109,6 +109,11 @@ public:
array_column.getData().get(values_vec_offset + i, value);
const auto & key = keys_vec.getData()[keys_vec_offset + i];
if (!keepKey(key))
{
continue;
}
IteratorType it;
if constexpr (IsDecimalNumber<T>)
{
@ -253,6 +258,43 @@ public:
}
const char * getHeaderFilePath() const override { return __FILE__; }
bool keepKey(const T & key) const { return static_cast<const Derived &>(*this).keepKey(key); }
};
/// Unfiltered variant (`sumMap`): reuses AggregateFunctionSumMapBase and keeps every key.
template <typename T>
class AggregateFunctionSumMap final : public AggregateFunctionSumMapBase<T, AggregateFunctionSumMap<T>>
{
public:
    /// Both arguments are only forwarded to the base class, which takes them by const
    /// reference, so `values_types` is accepted by const reference as well. This matches
    /// the AggregateFunctionSumMapFiltered constructor and allows const or temporary
    /// DataTypes to be passed.
    AggregateFunctionSumMap(const DataTypePtr & keys_type, const DataTypes & values_types)
        : AggregateFunctionSumMapBase<T, AggregateFunctionSumMap<T>>{keys_type, values_types}
    {}

    String getName() const override { return "sumMap"; }

    /// Hook invoked by the base class for each key; sumMap never filters anything out.
    bool keepKey(const T &) const { return true; }
};
/// Filtered variant (`sumMapFiltered`): reuses AggregateFunctionSumMapBase, but only keys
/// listed in the function parameter pass the `keepKey` hook.
template <typename T>
class AggregateFunctionSumMapFiltered final : public AggregateFunctionSumMapBase<T, AggregateFunctionSumMapFiltered<T>>
{
private:
    /// Keys that pass the filter, converted from the parameter's Field values to T.
    std::unordered_set<T> allowed_keys;

public:
    /// `keys_to_keep_` is the Array parameter of the aggregate function; each element is
    /// read with `safeGet<NearestFieldType<T>>()` and stored in the lookup set.
    AggregateFunctionSumMapFiltered(const DataTypePtr & keys_type, const DataTypes & values_types, const Array & keys_to_keep_)
        : AggregateFunctionSumMapBase<T, AggregateFunctionSumMapFiltered<T>>{keys_type, values_types}
    {
        allowed_keys.reserve(keys_to_keep_.size());
        for (const Field & requested_key : keys_to_keep_)
            allowed_keys.emplace(requested_key.safeGet<NearestFieldType<T>>());
    }

    String getName() const override { return "sumMapFiltered"; }

    /// Hook invoked by the base class for each key: true only for keys present in the filter set.
    bool keepKey(const T & key) const { return allowed_keys.find(key) != allowed_keys.end(); }
};
}

View File

@ -166,3 +166,20 @@ template <> constexpr bool IsDecimalNumber<Decimal64> = true;
template <> constexpr bool IsDecimalNumber<Decimal128> = true;
}
/// Specialization of `std::hash` for the Decimal<T> types.
namespace std
{

/// Generic case: a Decimal hashes exactly like its `value` member.
template <typename T>
struct hash<DB::Decimal<T>>
{
    size_t operator()(const DB::Decimal<T> & x) const
    {
        return hash<T>()(x.value);
    }
};

/// Decimal128: the value is split into the part above bit 64 and the part masked by the
/// maximum UInt64; each part is hashed as an Int64 and the two hashes are XOR-ed.
/// NOTE(review): XOR is symmetric, so swapping the two parts yields the same hash.
template <>
struct hash<DB::Decimal128>
{
    size_t operator()(const DB::Decimal128 & x) const
    {
        const std::hash<DB::Int64> half_hasher{};

        const auto upper_part = x.value >> 64;
        const auto lower_part = x.value & std::numeric_limits<DB::UInt64>::max();

        return half_hasher(upper_part) ^ half_hasher(lower_part);
    }
};

}

View File

@ -8,6 +8,8 @@
2000-01-01 00:01:00 ([4,5,6,7,8],[10,10,20,10,10])
2000-01-01 00:00:00 [1,2,3,4,5] [10,10,20,10,10]
2000-01-01 00:01:00 [4,5,6,7,8] [10,10,20,10,10]
([1],[10])
([1,4,8],[10,20,10])
([1],[1])
([1],[1])
(['a'],[1])

View File

@ -12,6 +12,9 @@ SELECT sumMapMerge(s) FROM (SELECT sumMapState(statusMap.status, statusMap.reque
SELECT timeslot, sumMap(statusMap.status, statusMap.requests) FROM test.sum_map GROUP BY timeslot ORDER BY timeslot;
SELECT timeslot, sumMap(statusMap.status, statusMap.requests).1, sumMap(statusMap.status, statusMap.requests).2 FROM test.sum_map GROUP BY timeslot ORDER BY timeslot;
SELECT sumMapFiltered([1])(statusMap.status, statusMap.requests) FROM test.sum_map;
SELECT sumMapFiltered([1, 4, 8])(statusMap.status, statusMap.requests) FROM test.sum_map;
DROP TABLE test.sum_map;
select sumMap(val, cnt) from ( SELECT [ CAST(1, 'UInt64') ] as val, [1] as cnt );

View File

@ -155,3 +155,8 @@ Solution: Write in the GROUP BY query SearchPhrase HAVING uniqUpTo(4)(UserID) >=
```
[Original article](https://clickhouse.yandex/docs/en/query_language/agg_functions/parametric_functions/) <!--hide-->
## sumMapFiltered(keys_to_keep)(keys, values)
Same behavior as [sumMap](reference.md#sumMap), except that an array of keys is passed as a parameter and only those keys are included in the result. This can be especially useful when working with a high cardinality of keys.