Statically dispatch on whether the argument is a Tuple

This commit is contained in:
Alexander Kuzmenkov 2020-04-20 18:37:28 +03:00
parent d71b57f627
commit 3ee89344af
2 changed files with 100 additions and 51 deletions

View File

@ -36,23 +36,29 @@ struct WithoutOverflowPolicy
}
};
template <typename T>
using SumMapWithOverflow = AggregateFunctionSumMap<T, WithOverflowPolicy>;
template <bool overflow, bool tuple_argument>
struct SumMap
{
template <typename T>
using F = AggregateFunctionSumMap<T,
std::conditional_t<overflow, WithOverflowPolicy, WithoutOverflowPolicy>,
tuple_argument>;
};
template <typename T>
using SumMapWithoutOverflow = AggregateFunctionSumMap<T, WithoutOverflowPolicy>;
template <bool overflow, bool tuple_argument>
struct SumMapFiltered
{
template <typename T>
using F = AggregateFunctionSumMapFiltered<T,
std::conditional_t<overflow, WithOverflowPolicy, WithoutOverflowPolicy>,
tuple_argument>;
};
template <typename T>
using SumMapFilteredWithOverflow = AggregateFunctionSumMapFiltered<T, WithOverflowPolicy>;
template <typename T>
using SumMapFilteredWithoutOverflow = AggregateFunctionSumMapFiltered<T, WithoutOverflowPolicy>;
using SumMapArgs = std::pair<DataTypePtr, DataTypes>;
SumMapArgs parseArguments(const std::string & name, const DataTypes & arguments)
auto parseArguments(const std::string & name, const DataTypes & arguments)
{
DataTypes args;
bool tuple_argument = false;
if (arguments.size() == 1)
{
@ -66,9 +72,13 @@ SumMapArgs parseArguments(const std::string & name, const DataTypes & arguments)
const auto elems = tuple_type->getElements();
args.insert(args.end(), elems.begin(), elems.end());
tuple_argument = true;
}
else
{
args.insert(args.end(), arguments.begin(), arguments.end());
tuple_argument = false;
}
if (args.size() < 2)
throw Exception("Aggregate function " + name + " requires at least two arguments of Array type or one argument of tuple of two arrays",
@ -92,28 +102,42 @@ SumMapArgs parseArguments(const std::string & name, const DataTypes & arguments)
values_types.push_back(array_type->getNestedType());
}
return {std::move(keys_type), std::move(values_types)};
return std::tuple{std::move(keys_type), std::move(values_types),
tuple_argument};
}
template <template <typename> class Function>
template <bool overflow>
AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, const DataTypes & arguments, const Array & params)
{
assertNoParameters(name, params);
auto [keys_type, values_types] = parseArguments(name, arguments);
auto [keys_type, values_types, tuple_argument] = parseArguments(name,
arguments);
AggregateFunctionPtr res(createWithNumericBasedType<Function>(*keys_type, keys_type, values_types, arguments));
if (!res)
res.reset(createWithDecimalType<Function>(*keys_type, keys_type, values_types, arguments));
if (!res)
res.reset(createWithStringType<Function>(*keys_type, keys_type, values_types, arguments));
AggregateFunctionPtr res;
if (tuple_argument)
{
res.reset(createWithNumericBasedType<SumMap<overflow, true>::template F>(*keys_type, keys_type, values_types, arguments));
if (!res)
res.reset(createWithDecimalType<SumMap<overflow, true>::template F>(*keys_type, keys_type, values_types, arguments));
if (!res)
res.reset(createWithStringType<SumMap<overflow, true>::template F>(*keys_type, keys_type, values_types, arguments));
}
else
{
res.reset(createWithNumericBasedType<SumMap<overflow, false>::template F>(*keys_type, keys_type, values_types, arguments));
if (!res)
res.reset(createWithDecimalType<SumMap<overflow, false>::template F>(*keys_type, keys_type, values_types, arguments));
if (!res)
res.reset(createWithStringType<SumMap<overflow, false>::template F>(*keys_type, keys_type, values_types, arguments));
}
if (!res)
throw Exception("Illegal type of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return res;
}
template <template <typename> class Function>
template <bool overflow>
AggregateFunctionPtr createAggregateFunctionSumMapFiltered(const std::string & name, const DataTypes & arguments, const Array & params)
{
if (params.size() != 1)
@ -125,26 +149,40 @@ AggregateFunctionPtr createAggregateFunctionSumMapFiltered(const std::string & n
throw Exception("Aggregate function " + name + " requires an Array as parameter.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
auto [keys_type, values_types] = parseArguments(name, arguments);
auto [keys_type, values_types, tuple_argument] = parseArguments(name,
arguments);
AggregateFunctionPtr res(createWithNumericBasedType<Function>(*keys_type, keys_type, values_types, keys_to_keep, arguments, params));
if (!res)
res.reset(createWithDecimalType<Function>(*keys_type, keys_type, values_types, keys_to_keep, arguments, params));
if (!res)
res.reset(createWithStringType<Function>(*keys_type, keys_type, values_types, keys_to_keep, arguments, params));
AggregateFunctionPtr res;
if (tuple_argument)
{
res.reset(createWithNumericBasedType<SumMapFiltered<overflow, true>::template F>(*keys_type, keys_type, values_types, keys_to_keep, arguments, params));
if (!res)
res.reset(createWithDecimalType<SumMapFiltered<overflow, true>::template F>(*keys_type, keys_type, values_types, keys_to_keep, arguments, params));
if (!res)
res.reset(createWithStringType<SumMapFiltered<overflow, true>::template F>(*keys_type, keys_type, values_types, keys_to_keep, arguments, params));
}
else
{
res.reset(createWithNumericBasedType<SumMapFiltered<overflow, false>::template F>(*keys_type, keys_type, values_types, keys_to_keep, arguments, params));
if (!res)
res.reset(createWithDecimalType<SumMapFiltered<overflow, false>::template F>(*keys_type, keys_type, values_types, keys_to_keep, arguments, params));
if (!res)
res.reset(createWithStringType<SumMapFiltered<overflow, false>::template F>(*keys_type, keys_type, values_types, keys_to_keep, arguments, params));
}
if (!res)
throw Exception("Illegal type of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return res;
}
}
void registerAggregateFunctionSumMap(AggregateFunctionFactory & factory)
{
factory.registerFunction("sumMap", createAggregateFunctionSumMap<SumMapWithoutOverflow>);
factory.registerFunction("sumMapWithOverflow", createAggregateFunctionSumMap<SumMapWithOverflow>);
factory.registerFunction("sumMapFiltered", createAggregateFunctionSumMapFiltered<SumMapFilteredWithoutOverflow>);
factory.registerFunction("sumMapFilteredWithOverflow", createAggregateFunctionSumMapFiltered<SumMapFilteredWithOverflow>);
factory.registerFunction("sumMap", createAggregateFunctionSumMap<false /*overflow*/>);
factory.registerFunction("sumMapWithOverflow", createAggregateFunctionSumMap<true /*overflow*/>);
factory.registerFunction("sumMapFiltered", createAggregateFunctionSumMapFiltered<false /*overflow*/>);
factory.registerFunction("sumMapFilteredWithOverflow", createAggregateFunctionSumMapFiltered<true /*overflow*/>);
}
}

View File

@ -50,7 +50,8 @@ struct AggregateFunctionSumMapData
* ([1,2,3,4,5,6,7,8,9,10],[10,10,45,20,35,20,15,30,20,20])
*/
template <typename T, typename Derived, typename OverflowPolicy>
template <typename T, typename Derived, typename OverflowPolicy,
bool tuple_argument = false>
class AggregateFunctionSumMapBase : public IAggregateFunctionDataHelper<
AggregateFunctionSumMapData<NearestFieldType<T>>, Derived>
{
@ -78,19 +79,23 @@ public:
return std::make_shared<DataTypeTuple>(types);
}
void add(AggregateDataPtr place, const IColumn ** columns, const size_t row_num, Arena *) const override
void add(AggregateDataPtr place, const IColumn** _columns, const size_t row_num, Arena *) const override
{
// Check if tuple
auto tuple_col = checkAndGetColumn<ColumnTuple>(columns[0]);
if (tuple_col)
addImpl(place, tuple_col->getColumns(), row_num);
else
addImpl(place, columns, row_num);
}
std::conditional_t<tuple_argument,
const std::vector<ColumnTuple::WrappedPtr>,
const IColumn**> * columns_ptr;
if constexpr (tuple_argument)
{
columns_ptr = &static_cast<const ColumnTuple *>(_columns[0])->getColumns();
}
else
{
columns_ptr = &_columns;
}
auto & columns = *columns_ptr;
template<typename TColumns>
void addImpl(AggregateDataPtr place, TColumns & columns, const size_t row_num) const
{
// Column 0 contains array of keys of known type
Field key_field;
const ColumnArray & array_column0 = assert_cast<const ColumnArray &>(*columns[0]);
@ -274,13 +279,14 @@ public:
bool keepKey(const T & key) const { return static_cast<const Derived &>(*this).keepKey(key); }
};
template <typename T, typename OverflowPolicy>
template <typename T, typename OverflowPolicy, bool tuple_argument = false>
class AggregateFunctionSumMap final :
public AggregateFunctionSumMapBase<T, AggregateFunctionSumMap<T, OverflowPolicy>, OverflowPolicy>
public AggregateFunctionSumMapBase<T, AggregateFunctionSumMap<T, OverflowPolicy, tuple_argument>, OverflowPolicy, tuple_argument>
{
private:
using Self = AggregateFunctionSumMap<T, OverflowPolicy>;
using Base = AggregateFunctionSumMapBase<T, Self, OverflowPolicy>;
using Self = AggregateFunctionSumMap<T, OverflowPolicy, tuple_argument>;
using Base = AggregateFunctionSumMapBase<T, Self, OverflowPolicy,
tuple_argument>;
public:
AggregateFunctionSumMap(const DataTypePtr & keys_type_, DataTypes & values_types_, const DataTypes & argument_types_)
@ -292,13 +298,18 @@ public:
bool keepKey(const T &) const { return true; }
};
template <typename T, typename OverflowPolicy>
template <typename T, typename OverflowPolicy, bool tuple_argument = false>
class AggregateFunctionSumMapFiltered final :
public AggregateFunctionSumMapBase<T, AggregateFunctionSumMapFiltered<T, OverflowPolicy>, OverflowPolicy>
public AggregateFunctionSumMapBase<T,
AggregateFunctionSumMapFiltered<T, OverflowPolicy, tuple_argument>,
OverflowPolicy,
tuple_argument>
{
private:
using Self = AggregateFunctionSumMapFiltered<T, OverflowPolicy>;
using Base = AggregateFunctionSumMapBase<T, Self, OverflowPolicy>;
using Self = AggregateFunctionSumMapFiltered<T, OverflowPolicy,
tuple_argument>;
using Base = AggregateFunctionSumMapBase<T, Self, OverflowPolicy,
tuple_argument>;
std::unordered_set<T> keys_to_keep;