mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-17 13:13:36 +00:00
Statically dispatch on whether the argument is a Tuple
This commit is contained in:
parent
d71b57f627
commit
3ee89344af
@ -36,23 +36,29 @@ struct WithoutOverflowPolicy
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using SumMapWithOverflow = AggregateFunctionSumMap<T, WithOverflowPolicy>;
|
||||
template <bool overflow, bool tuple_argument>
|
||||
struct SumMap
|
||||
{
|
||||
template <typename T>
|
||||
using F = AggregateFunctionSumMap<T,
|
||||
std::conditional_t<overflow, WithOverflowPolicy, WithoutOverflowPolicy>,
|
||||
tuple_argument>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using SumMapWithoutOverflow = AggregateFunctionSumMap<T, WithoutOverflowPolicy>;
|
||||
template <bool overflow, bool tuple_argument>
|
||||
struct SumMapFiltered
|
||||
{
|
||||
template <typename T>
|
||||
using F = AggregateFunctionSumMapFiltered<T,
|
||||
std::conditional_t<overflow, WithOverflowPolicy, WithoutOverflowPolicy>,
|
||||
tuple_argument>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using SumMapFilteredWithOverflow = AggregateFunctionSumMapFiltered<T, WithOverflowPolicy>;
|
||||
|
||||
template <typename T>
|
||||
using SumMapFilteredWithoutOverflow = AggregateFunctionSumMapFiltered<T, WithoutOverflowPolicy>;
|
||||
|
||||
using SumMapArgs = std::pair<DataTypePtr, DataTypes>;
|
||||
|
||||
SumMapArgs parseArguments(const std::string & name, const DataTypes & arguments)
|
||||
auto parseArguments(const std::string & name, const DataTypes & arguments)
|
||||
{
|
||||
DataTypes args;
|
||||
bool tuple_argument = false;
|
||||
|
||||
if (arguments.size() == 1)
|
||||
{
|
||||
@ -66,9 +72,13 @@ SumMapArgs parseArguments(const std::string & name, const DataTypes & arguments)
|
||||
|
||||
const auto elems = tuple_type->getElements();
|
||||
args.insert(args.end(), elems.begin(), elems.end());
|
||||
tuple_argument = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
args.insert(args.end(), arguments.begin(), arguments.end());
|
||||
tuple_argument = false;
|
||||
}
|
||||
|
||||
if (args.size() < 2)
|
||||
throw Exception("Aggregate function " + name + " requires at least two arguments of Array type or one argument of tuple of two arrays",
|
||||
@ -92,28 +102,42 @@ SumMapArgs parseArguments(const std::string & name, const DataTypes & arguments)
|
||||
values_types.push_back(array_type->getNestedType());
|
||||
}
|
||||
|
||||
return {std::move(keys_type), std::move(values_types)};
|
||||
return std::tuple{std::move(keys_type), std::move(values_types),
|
||||
tuple_argument};
|
||||
}
|
||||
|
||||
template <template <typename> class Function>
|
||||
template <bool overflow>
|
||||
AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, const DataTypes & arguments, const Array & params)
|
||||
{
|
||||
assertNoParameters(name, params);
|
||||
|
||||
auto [keys_type, values_types] = parseArguments(name, arguments);
|
||||
auto [keys_type, values_types, tuple_argument] = parseArguments(name,
|
||||
arguments);
|
||||
|
||||
AggregateFunctionPtr res(createWithNumericBasedType<Function>(*keys_type, keys_type, values_types, arguments));
|
||||
if (!res)
|
||||
res.reset(createWithDecimalType<Function>(*keys_type, keys_type, values_types, arguments));
|
||||
if (!res)
|
||||
res.reset(createWithStringType<Function>(*keys_type, keys_type, values_types, arguments));
|
||||
AggregateFunctionPtr res;
|
||||
if (tuple_argument)
|
||||
{
|
||||
res.reset(createWithNumericBasedType<SumMap<overflow, true>::template F>(*keys_type, keys_type, values_types, arguments));
|
||||
if (!res)
|
||||
res.reset(createWithDecimalType<SumMap<overflow, true>::template F>(*keys_type, keys_type, values_types, arguments));
|
||||
if (!res)
|
||||
res.reset(createWithStringType<SumMap<overflow, true>::template F>(*keys_type, keys_type, values_types, arguments));
|
||||
}
|
||||
else
|
||||
{
|
||||
res.reset(createWithNumericBasedType<SumMap<overflow, false>::template F>(*keys_type, keys_type, values_types, arguments));
|
||||
if (!res)
|
||||
res.reset(createWithDecimalType<SumMap<overflow, false>::template F>(*keys_type, keys_type, values_types, arguments));
|
||||
if (!res)
|
||||
res.reset(createWithStringType<SumMap<overflow, false>::template F>(*keys_type, keys_type, values_types, arguments));
|
||||
}
|
||||
if (!res)
|
||||
throw Exception("Illegal type of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
template <template <typename> class Function>
|
||||
template <bool overflow>
|
||||
AggregateFunctionPtr createAggregateFunctionSumMapFiltered(const std::string & name, const DataTypes & arguments, const Array & params)
|
||||
{
|
||||
if (params.size() != 1)
|
||||
@ -125,26 +149,40 @@ AggregateFunctionPtr createAggregateFunctionSumMapFiltered(const std::string & n
|
||||
throw Exception("Aggregate function " + name + " requires an Array as parameter.",
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
auto [keys_type, values_types] = parseArguments(name, arguments);
|
||||
auto [keys_type, values_types, tuple_argument] = parseArguments(name,
|
||||
arguments);
|
||||
|
||||
AggregateFunctionPtr res(createWithNumericBasedType<Function>(*keys_type, keys_type, values_types, keys_to_keep, arguments, params));
|
||||
if (!res)
|
||||
res.reset(createWithDecimalType<Function>(*keys_type, keys_type, values_types, keys_to_keep, arguments, params));
|
||||
if (!res)
|
||||
res.reset(createWithStringType<Function>(*keys_type, keys_type, values_types, keys_to_keep, arguments, params));
|
||||
AggregateFunctionPtr res;
|
||||
if (tuple_argument)
|
||||
{
|
||||
res.reset(createWithNumericBasedType<SumMapFiltered<overflow, true>::template F>(*keys_type, keys_type, values_types, keys_to_keep, arguments, params));
|
||||
if (!res)
|
||||
res.reset(createWithDecimalType<SumMapFiltered<overflow, true>::template F>(*keys_type, keys_type, values_types, keys_to_keep, arguments, params));
|
||||
if (!res)
|
||||
res.reset(createWithStringType<SumMapFiltered<overflow, true>::template F>(*keys_type, keys_type, values_types, keys_to_keep, arguments, params));
|
||||
}
|
||||
else
|
||||
{
|
||||
res.reset(createWithNumericBasedType<SumMapFiltered<overflow, false>::template F>(*keys_type, keys_type, values_types, keys_to_keep, arguments, params));
|
||||
if (!res)
|
||||
res.reset(createWithDecimalType<SumMapFiltered<overflow, false>::template F>(*keys_type, keys_type, values_types, keys_to_keep, arguments, params));
|
||||
if (!res)
|
||||
res.reset(createWithStringType<SumMapFiltered<overflow, false>::template F>(*keys_type, keys_type, values_types, keys_to_keep, arguments, params));
|
||||
}
|
||||
if (!res)
|
||||
throw Exception("Illegal type of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void registerAggregateFunctionSumMap(AggregateFunctionFactory & factory)
|
||||
{
|
||||
factory.registerFunction("sumMap", createAggregateFunctionSumMap<SumMapWithoutOverflow>);
|
||||
factory.registerFunction("sumMapWithOverflow", createAggregateFunctionSumMap<SumMapWithOverflow>);
|
||||
factory.registerFunction("sumMapFiltered", createAggregateFunctionSumMapFiltered<SumMapFilteredWithoutOverflow>);
|
||||
factory.registerFunction("sumMapFilteredWithOverflow", createAggregateFunctionSumMapFiltered<SumMapFilteredWithOverflow>);
|
||||
factory.registerFunction("sumMap", createAggregateFunctionSumMap<false /*overflow*/>);
|
||||
factory.registerFunction("sumMapWithOverflow", createAggregateFunctionSumMap<true /*overflow*/>);
|
||||
factory.registerFunction("sumMapFiltered", createAggregateFunctionSumMapFiltered<false /*overflow*/>);
|
||||
factory.registerFunction("sumMapFilteredWithOverflow", createAggregateFunctionSumMapFiltered<true /*overflow*/>);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -50,7 +50,8 @@ struct AggregateFunctionSumMapData
|
||||
* ([1,2,3,4,5,6,7,8,9,10],[10,10,45,20,35,20,15,30,20,20])
|
||||
*/
|
||||
|
||||
template <typename T, typename Derived, typename OverflowPolicy>
|
||||
template <typename T, typename Derived, typename OverflowPolicy,
|
||||
bool tuple_argument = false>
|
||||
class AggregateFunctionSumMapBase : public IAggregateFunctionDataHelper<
|
||||
AggregateFunctionSumMapData<NearestFieldType<T>>, Derived>
|
||||
{
|
||||
@ -78,19 +79,23 @@ public:
|
||||
return std::make_shared<DataTypeTuple>(types);
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr place, const IColumn ** columns, const size_t row_num, Arena *) const override
|
||||
void add(AggregateDataPtr place, const IColumn** _columns, const size_t row_num, Arena *) const override
|
||||
{
|
||||
// Check if tuple
|
||||
auto tuple_col = checkAndGetColumn<ColumnTuple>(columns[0]);
|
||||
if (tuple_col)
|
||||
addImpl(place, tuple_col->getColumns(), row_num);
|
||||
else
|
||||
addImpl(place, columns, row_num);
|
||||
}
|
||||
std::conditional_t<tuple_argument,
|
||||
const std::vector<ColumnTuple::WrappedPtr>,
|
||||
const IColumn**> * columns_ptr;
|
||||
|
||||
if constexpr (tuple_argument)
|
||||
{
|
||||
columns_ptr = &static_cast<const ColumnTuple *>(_columns[0])->getColumns();
|
||||
}
|
||||
else
|
||||
{
|
||||
columns_ptr = &_columns;
|
||||
}
|
||||
|
||||
auto & columns = *columns_ptr;
|
||||
|
||||
template<typename TColumns>
|
||||
void addImpl(AggregateDataPtr place, TColumns & columns, const size_t row_num) const
|
||||
{
|
||||
// Column 0 contains array of keys of known type
|
||||
Field key_field;
|
||||
const ColumnArray & array_column0 = assert_cast<const ColumnArray &>(*columns[0]);
|
||||
@ -274,13 +279,14 @@ public:
|
||||
bool keepKey(const T & key) const { return static_cast<const Derived &>(*this).keepKey(key); }
|
||||
};
|
||||
|
||||
template <typename T, typename OverflowPolicy>
|
||||
template <typename T, typename OverflowPolicy, bool tuple_argument = false>
|
||||
class AggregateFunctionSumMap final :
|
||||
public AggregateFunctionSumMapBase<T, AggregateFunctionSumMap<T, OverflowPolicy>, OverflowPolicy>
|
||||
public AggregateFunctionSumMapBase<T, AggregateFunctionSumMap<T, OverflowPolicy, tuple_argument>, OverflowPolicy, tuple_argument>
|
||||
{
|
||||
private:
|
||||
using Self = AggregateFunctionSumMap<T, OverflowPolicy>;
|
||||
using Base = AggregateFunctionSumMapBase<T, Self, OverflowPolicy>;
|
||||
using Self = AggregateFunctionSumMap<T, OverflowPolicy, tuple_argument>;
|
||||
using Base = AggregateFunctionSumMapBase<T, Self, OverflowPolicy,
|
||||
tuple_argument>;
|
||||
|
||||
public:
|
||||
AggregateFunctionSumMap(const DataTypePtr & keys_type_, DataTypes & values_types_, const DataTypes & argument_types_)
|
||||
@ -292,13 +298,18 @@ public:
|
||||
bool keepKey(const T &) const { return true; }
|
||||
};
|
||||
|
||||
template <typename T, typename OverflowPolicy>
|
||||
template <typename T, typename OverflowPolicy, bool tuple_argument = false>
|
||||
class AggregateFunctionSumMapFiltered final :
|
||||
public AggregateFunctionSumMapBase<T, AggregateFunctionSumMapFiltered<T, OverflowPolicy>, OverflowPolicy>
|
||||
public AggregateFunctionSumMapBase<T,
|
||||
AggregateFunctionSumMapFiltered<T, OverflowPolicy, tuple_argument>,
|
||||
OverflowPolicy,
|
||||
tuple_argument>
|
||||
{
|
||||
private:
|
||||
using Self = AggregateFunctionSumMapFiltered<T, OverflowPolicy>;
|
||||
using Base = AggregateFunctionSumMapBase<T, Self, OverflowPolicy>;
|
||||
using Self = AggregateFunctionSumMapFiltered<T, OverflowPolicy,
|
||||
tuple_argument>;
|
||||
using Base = AggregateFunctionSumMapBase<T, Self, OverflowPolicy,
|
||||
tuple_argument>;
|
||||
|
||||
std::unordered_set<T> keys_to_keep;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user