Update IAggregateFunction interface.

This commit is contained in:
Nikolai Kochetov 2019-02-11 22:26:32 +03:00
parent 0a6f75a1b6
commit 2b8b342ccd
50 changed files with 220 additions and 151 deletions

View File

@ -31,12 +31,13 @@ template <typename Data>
class AggregateFunctionArgMinMax final : public IAggregateFunctionDataHelper<Data, AggregateFunctionArgMinMax<Data>>
{
private:
DataTypePtr type_res;
DataTypePtr type_val;
const DataTypePtr & type_res;
const DataTypePtr & type_val;
public:
AggregateFunctionArgMinMax(const DataTypePtr & type_res, const DataTypePtr & type_val)
: type_res(type_res), type_val(type_val)
: IAggregateFunctionDataHelper<Data, AggregateFunctionArgMinMax<Data>>({type_res, type_val}, {}),
type_res(argument_types[0]), type_val(argument_types[1])
{
if (!type_val->isComparable())
throw Exception("Illegal type " + type_val->getName() + " of second argument of aggregate function " + getName()

View File

@ -28,7 +28,8 @@ private:
public:
AggregateFunctionArray(AggregateFunctionPtr nested_, const DataTypes & arguments)
: nested_func(nested_), num_arguments(arguments.size())
: IAggregateFunctionHelper<AggregateFunctionArray>(arguments, {})
, nested_func(nested_), num_arguments(arguments.size())
{
for (const auto & type : arguments)
if (!isArray(type))

View File

@ -55,7 +55,8 @@ public:
/// ctor for Decimals
AggregateFunctionAvg(const IDataType & data_type)
: scale(getDecimalScale(data_type))
: IAggregateFunctionDataHelper<Data, AggregateFunctionAvg<T, Data>>({data_type}, {})
, scale(getDecimalScale(data_type))
{}
String getName() const override { return "avg"; }

View File

@ -21,7 +21,7 @@ AggregateFunctionPtr createAggregateFunctionBitwise(const std::string & name, co
+ " is illegal, because it cannot be used in bitwise operations",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
AggregateFunctionPtr res(createWithUnsignedIntegerType<AggregateFunctionBitwise, Data>(*argument_types[0]));
AggregateFunctionPtr res(createWithUnsignedIntegerType<AggregateFunctionBitwise, Data>(*argument_types[0], argument_types[0]));
if (!res)
throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

View File

@ -43,6 +43,9 @@ template <typename T, typename Data>
class AggregateFunctionBitwise final : public IAggregateFunctionDataHelper<Data, AggregateFunctionBitwise<T, Data>>
{
public:
AggregateFunctionBitwise(const DataTypePtr & type)
: IAggregateFunctionDataHelper<Data, AggregateFunctionBitwise<T, Data>>({type}, {}) {}
String getName() const override { return Data::name(); }
DataTypePtr getReturnType() const override

View File

@ -111,6 +111,7 @@ public:
}
AggregateFunctionBoundingRatio(const DataTypes & arguments)
: IAggregateFunctionDataHelper<AggregateFunctionBoundingRatioData, AggregateFunctionBoundingRatio>(arguments, {})
{
const auto x_arg = arguments.at(0).get();
const auto y_arg = arguments.at(0).get();

View File

@ -9,12 +9,12 @@ namespace DB
namespace
{
AggregateFunctionPtr createAggregateFunctionCount(const std::string & name, const DataTypes & /*argument_types*/, const Array & parameters)
AggregateFunctionPtr createAggregateFunctionCount(const std::string & name, const DataTypes & argument_types, const Array & parameters)
{
assertNoParameters(name, parameters);
/// 'count' accept any number of arguments and (in this case of non-Nullable types) simply ignore them.
return std::make_shared<AggregateFunctionCount>();
return std::make_shared<AggregateFunctionCount>(argument_types);
}
}

View File

@ -28,6 +28,8 @@ namespace ErrorCodes
class AggregateFunctionCount final : public IAggregateFunctionDataHelper<AggregateFunctionCountData, AggregateFunctionCount>
{
public:
AggregateFunctionCount(const DataTypes & argument_types) : IAggregateFunctionDataHelper(argument_types, {}) {}
String getName() const override { return "count"; }
DataTypePtr getReturnType() const override
@ -74,7 +76,8 @@ public:
class AggregateFunctionCountNotNullUnary final : public IAggregateFunctionDataHelper<AggregateFunctionCountData, AggregateFunctionCountNotNullUnary>
{
public:
AggregateFunctionCountNotNullUnary(const DataTypePtr & argument)
AggregateFunctionCountNotNullUnary(const DataTypePtr & argument, const Array & params)
: IAggregateFunctionDataHelper<AggregateFunctionCountData, AggregateFunctionCountNotNullUnary>({argument}, params)
{
if (!argument->isNullable())
throw Exception("Logical error: not Nullable data type passed to AggregateFunctionCountNotNullUnary", ErrorCodes::LOGICAL_ERROR);
@ -120,7 +123,8 @@ public:
class AggregateFunctionCountNotNullVariadic final : public IAggregateFunctionDataHelper<AggregateFunctionCountData, AggregateFunctionCountNotNullVariadic>
{
public:
AggregateFunctionCountNotNullVariadic(const DataTypes & arguments)
AggregateFunctionCountNotNullVariadic(const DataTypes & arguments, const Array & params)
: IAggregateFunctionDataHelper<AggregateFunctionCountData, AggregateFunctionCountNotNullVariadic>(arguments, params)
{
number_of_arguments = arguments.size();

View File

@ -26,12 +26,12 @@ AggregateFunctionPtr createAggregateFunctionEntropy(const std::string & name, co
if (num_args == 1)
{
/// Specialized implementation for single argument of numeric type.
if (auto res = createWithNumericBasedType<AggregateFunctionEntropy>(*argument_types[0], num_args))
if (auto res = createWithNumericBasedType<AggregateFunctionEntropy>(*argument_types[0], argument_types))
return AggregateFunctionPtr(res);
}
/// Generic implementation for other types or for multiple arguments.
return std::make_shared<AggregateFunctionEntropy<UInt128>>(num_args);
return std::make_shared<AggregateFunctionEntropy<UInt128>>(argument_types);
}
}

View File

@ -97,7 +97,9 @@ private:
size_t num_args;
public:
AggregateFunctionEntropy(size_t num_args) : num_args(num_args)
AggregateFunctionEntropy(const DataTypes & argument_types)
: IAggregateFunctionDataHelper<EntropyData<Value>, AggregateFunctionEntropy<Value>>(argument_types, {})
, num_args(argument_types.size())
{
}

View File

@ -86,17 +86,12 @@ AggregateFunctionPtr AggregateFunctionFactory::get(
[](const auto & type) { return type->onlyNull(); }))
nested_function = getImpl(name, nested_types, parameters, recursion_level);
auto res = combinator->transformAggregateFunction(nested_function, type_without_low_cardinality, parameters);
res->setArguments(type_without_low_cardinality, parameters);
return res;
return combinator->transformAggregateFunction(nested_function, argument_types, parameters);
}
auto res = getImpl(name, type_without_low_cardinality, parameters, recursion_level);
if (!res)
throw Exception("Logical error: AggregateFunctionFactory returned nullptr", ErrorCodes::LOGICAL_ERROR);
res->setArguments(type_without_low_cardinality, parameters);
return res;
}

View File

@ -97,7 +97,8 @@ private:
public:
AggregateFunctionForEach(AggregateFunctionPtr nested_, const DataTypes & arguments)
: nested_func(nested_), num_arguments(arguments.size())
: IAggregateFunctionDataHelper<AggregateFunctionForEachData, AggregateFunctionForEach>(arguments, {})
, nested_func(nested_), num_arguments(arguments.size())
{
nested_size_of_data = nested_func->sizeOfData();

View File

@ -48,12 +48,13 @@ class GroupArrayNumericImpl final
: public IAggregateFunctionDataHelper<GroupArrayNumericData<T>, GroupArrayNumericImpl<T, Tlimit_num_elems>>
{
static constexpr bool limit_num_elems = Tlimit_num_elems::value;
DataTypePtr data_type;
DataTypePtr & data_type;
UInt64 max_elems;
public:
explicit GroupArrayNumericImpl(const DataTypePtr & data_type_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
: data_type(data_type_), max_elems(max_elems_) {}
: IAggregateFunctionDataHelper<GroupArrayNumericData<T>, GroupArrayNumericImpl<T, Tlimit_num_elems>>({data_type}, {})
, data_type(argument_types[0]), max_elems(max_elems_) {}
String getName() const override { return "groupArray"; }

View File

@ -13,6 +13,10 @@ namespace
AggregateFunctionPtr createAggregateFunctionGroupArrayInsertAt(const std::string & name, const DataTypes & argument_types, const Array & parameters)
{
assertBinary(name, argument_types);
if (argument_types.size() != 2)
throw Exception("Aggregate function groupArrayInsertAt requires two arguments.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
return std::make_shared<AggregateFunctionGroupArrayInsertAtGeneric>(argument_types, parameters);
}

View File

@ -54,12 +54,14 @@ class AggregateFunctionGroupArrayInsertAtGeneric final
: public IAggregateFunctionDataHelper<AggregateFunctionGroupArrayInsertAtDataGeneric, AggregateFunctionGroupArrayInsertAtGeneric>
{
private:
DataTypePtr type;
DataTypePtr & type;
Field default_value;
UInt64 length_to_resize = 0; /// zero means - do not do resizing.
public:
AggregateFunctionGroupArrayInsertAtGeneric(const DataTypes & arguments, const Array & params)
: IAggregateFunctionDataHelper<AggregateFunctionGroupArrayInsertAtDataGeneric, AggregateFunctionGroupArrayInsertAtGeneric>(arguments, params)
, type(argument_types[0])
{
if (!params.empty())
{
@ -76,14 +78,9 @@ public:
}
}
if (arguments.size() != 2)
throw Exception("Aggregate function " + getName() + " requires two arguments.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
if (!isUnsignedInteger(arguments[1]))
throw Exception("Second argument of aggregate function " + getName() + " must be integer.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
type = arguments.front();
if (default_value.isNull())
default_value = type->getDefault();
else

View File

@ -15,11 +15,15 @@ namespace
/// Substitute return type for Date and DateTime
class AggregateFunctionGroupUniqArrayDate : public AggregateFunctionGroupUniqArray<DataTypeDate::FieldType>
{
public:
AggregateFunctionGroupUniqArrayDate(const DataTypePtr & argument_type) : AggregateFunctionGroupUniqArray<DataTypeDate::FieldType>(argument_type) {}
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeArray>(std::make_shared<DataTypeDate>()); }
};
class AggregateFunctionGroupUniqArrayDateTime : public AggregateFunctionGroupUniqArray<DataTypeDateTime::FieldType>
{
public:
AggregateFunctionGroupUniqArrayDateTime(const DataTypePtr & argument_type) : AggregateFunctionGroupUniqArray<DataTypeDateTime::FieldType>(argument_type) {}
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeArray>(std::make_shared<DataTypeDateTime>()); }
};
@ -27,8 +31,8 @@ class AggregateFunctionGroupUniqArrayDateTime : public AggregateFunctionGroupUni
static IAggregateFunction * createWithExtraTypes(const DataTypePtr & argument_type)
{
WhichDataType which(argument_type);
if (which.idx == TypeIndex::Date) return new AggregateFunctionGroupUniqArrayDate;
else if (which.idx == TypeIndex::DateTime) return new AggregateFunctionGroupUniqArrayDateTime;
if (which.idx == TypeIndex::Date) return new AggregateFunctionGroupUniqArrayDate(argument_type);
else if (which.idx == TypeIndex::DateTime) return new AggregateFunctionGroupUniqArrayDateTime(argument_type);
else
{
/// Check that we can use plain version of AggreagteFunctionGroupUniqArrayGeneric

View File

@ -44,6 +44,9 @@ private:
using State = AggregateFunctionGroupUniqArrayData<T>;
public:
AggregateFunctionGroupUniqArray(const DataTypePtr & argument_type)
: IAggregateFunctionDataHelper<AggregateFunctionGroupUniqArrayData<T>, AggregateFunctionGroupUniqArray<T>>({argument_type}, {}) {}
String getName() const override { return "groupUniqArray"; }
DataTypePtr getReturnType() const override
@ -115,7 +118,7 @@ template <bool is_plain_column = false>
class AggreagteFunctionGroupUniqArrayGeneric
: public IAggregateFunctionDataHelper<AggreagteFunctionGroupUniqArrayGenericData, AggreagteFunctionGroupUniqArrayGeneric<is_plain_column>>
{
DataTypePtr input_data_type;
DataTypePtr & input_data_type;
using State = AggreagteFunctionGroupUniqArrayGenericData;
@ -125,7 +128,8 @@ class AggreagteFunctionGroupUniqArrayGeneric
public:
AggreagteFunctionGroupUniqArrayGeneric(const DataTypePtr & input_data_type)
: input_data_type(input_data_type) {}
: IAggregateFunctionDataHelper<AggreagteFunctionGroupUniqArrayGenericData, AggreagteFunctionGroupUniqArrayGeneric<is_plain_column>>({input_data_type}, {})
, input_data_type(argument_types[0]) {}
String getName() const override { return "groupUniqArray"; }

View File

@ -39,7 +39,7 @@ AggregateFunctionPtr createAggregateFunctionHistogram(const std::string & name,
throw Exception("Bin count should be positive", ErrorCodes::BAD_ARGUMENTS);
assertUnary(name, arguments);
AggregateFunctionPtr res(createWithNumericType<AggregateFunctionHistogram>(*arguments[0], bins_count));
AggregateFunctionPtr res(createWithNumericType<AggregateFunctionHistogram>(*arguments[0], bins_count, arguments, params));
if (!res)
throw Exception("Illegal type " + arguments[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

View File

@ -304,8 +304,9 @@ private:
const UInt32 max_bins;
public:
AggregateFunctionHistogram(UInt32 max_bins)
: max_bins(max_bins)
AggregateFunctionHistogram(const DataTypes & arguments, const Array & params, UInt32 max_bins)
: IAggregateFunctionDataHelper<AggregateFunctionHistogramData, AggregateFunctionHistogram<T>>(arguments, params)
, max_bins(max_bins)
{
}

View File

@ -28,7 +28,8 @@ private:
public:
AggregateFunctionIf(AggregateFunctionPtr nested, const DataTypes & types)
: nested_func(nested), num_arguments(types.size())
: IAggregateFunctionHelper<AggregateFunctionIf>(types, nested->getParameters())
, nested_func(nested), num_arguments(types.size())
{
if (num_arguments == 0)
throw Exception("Aggregate function " + getName() + " require at least one argument", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

View File

@ -59,7 +59,7 @@ private:
public:
AggregateFunctionIntersectionsMax(AggregateFunctionIntersectionsKind kind_, const DataTypes & arguments)
: kind(kind_)
: IAggregateFunctionDataHelper<MaxIntersectionsData<PointType>, AggregateFunctionIntersectionsMax<PointType>>(arguments, {}), kind(kind_)
{
if (!isNumber(arguments[0]))
throw Exception{getName() + ": first argument must be represented by integer", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};

View File

@ -47,7 +47,7 @@ public:
+ ", because it corresponds to different aggregate function: " + function->getFunctionName() + " instead of " + nested_function->getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return std::make_shared<AggregateFunctionMerge>(nested_function, *argument);
return std::make_shared<AggregateFunctionMerge>(nested_function, argument);
}
};

View File

@ -22,13 +22,14 @@ private:
AggregateFunctionPtr nested_func;
public:
AggregateFunctionMerge(const AggregateFunctionPtr & nested_, const IDataType & argument)
: nested_func(nested_)
AggregateFunctionMerge(const AggregateFunctionPtr & nested_, const DataTypePtr & argument)
: IAggregateFunctionHelper<AggregateFunctionMerge>({argument}, nested_->getParameters())
, nested_func(nested_)
{
const DataTypeAggregateFunction * data_type = typeid_cast<const DataTypeAggregateFunction *>(&argument);
const DataTypeAggregateFunction * data_type = typeid_cast<const DataTypeAggregateFunction *>(argument.get());
if (!data_type || data_type->getFunctionName() != nested_func->getName())
throw Exception("Illegal type " + argument.getName() + " of argument for aggregate function " + getName(),
throw Exception("Illegal type " + argument->getName() + " of argument for aggregate function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}

View File

@ -676,10 +676,12 @@ template <typename Data>
class AggregateFunctionsSingleValue final : public IAggregateFunctionDataHelper<Data, AggregateFunctionsSingleValue<Data>>
{
private:
DataTypePtr type;
DataTypePtr & type;
public:
AggregateFunctionsSingleValue(const DataTypePtr & type) : type(type)
AggregateFunctionsSingleValue(const DataTypePtr & type)
: IAggregateFunctionDataHelper<Data, AggregateFunctionsSingleValue<Data>>({type}, {})
, type(argument_types[0])
{
if (StringRef(Data::name()) == StringRef("min")
|| StringRef(Data::name()) == StringRef("max"))

View File

@ -15,6 +15,9 @@ namespace DB
class AggregateFunctionNothing final : public IAggregateFunctionHelper<AggregateFunctionNothing>
{
public:
AggregateFunctionNothing(const DataTypes & arguments, const Array & params)
: IAggregateFunctionHelper<AggregateFunctionNothing>(arguments, params) {}
String getName() const override
{
return "nothing";

View File

@ -30,7 +30,7 @@ public:
}
AggregateFunctionPtr transformAggregateFunction(
const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array &) const override
const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array & params) const override
{
bool has_nullable_types = false;
bool has_null_types = false;
@ -55,29 +55,29 @@ public:
if (nested_function && nested_function->getName() == "count")
{
if (arguments.size() == 1)
return std::make_shared<AggregateFunctionCountNotNullUnary>(arguments[0]);
return std::make_shared<AggregateFunctionCountNotNullUnary>(arguments[0], params);
else
return std::make_shared<AggregateFunctionCountNotNullVariadic>(arguments);
return std::make_shared<AggregateFunctionCountNotNullVariadic>(arguments, params);
}
if (has_null_types)
return std::make_shared<AggregateFunctionNothing>();
return std::make_shared<AggregateFunctionNothing>(arguments, params);
bool return_type_is_nullable = nested_function->getReturnType()->canBeInsideNullable();
if (arguments.size() == 1)
{
if (return_type_is_nullable)
return std::make_shared<AggregateFunctionNullUnary<true>>(nested_function);
return std::make_shared<AggregateFunctionNullUnary<true>>(nested_function, arguments, params);
else
return std::make_shared<AggregateFunctionNullUnary<false>>(nested_function);
return std::make_shared<AggregateFunctionNullUnary<false>>(nested_function, arguments, params);
}
else
{
if (return_type_is_nullable)
return std::make_shared<AggregateFunctionNullVariadic<true>>(nested_function, arguments);
return std::make_shared<AggregateFunctionNullVariadic<true>>(nested_function, arguments, params);
else
return std::make_shared<AggregateFunctionNullVariadic<false>>(nested_function, arguments);
return std::make_shared<AggregateFunctionNullVariadic<false>>(nested_function, arguments, params);
}
}
};

View File

@ -68,8 +68,8 @@ protected:
}
public:
AggregateFunctionNullBase(AggregateFunctionPtr nested_function_)
: nested_function{nested_function_}
AggregateFunctionNullBase(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)
: IAggregateFunctionHelper<Derived>(arguments, params), nested_function{nested_function_}
{
if (result_is_nullable)
prefix_size = nested_function->alignOfData();
@ -187,8 +187,8 @@ template <bool result_is_nullable>
class AggregateFunctionNullUnary final : public AggregateFunctionNullBase<result_is_nullable, AggregateFunctionNullUnary<result_is_nullable>>
{
public:
AggregateFunctionNullUnary(AggregateFunctionPtr nested_function_)
: AggregateFunctionNullBase<result_is_nullable, AggregateFunctionNullUnary<result_is_nullable>>(std::move(nested_function_))
AggregateFunctionNullUnary(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)
: AggregateFunctionNullBase<result_is_nullable, AggregateFunctionNullUnary<result_is_nullable>>(std::move(nested_function_), arguments, params)
{
}
@ -209,8 +209,8 @@ template <bool result_is_nullable>
class AggregateFunctionNullVariadic final : public AggregateFunctionNullBase<result_is_nullable, AggregateFunctionNullVariadic<result_is_nullable>>
{
public:
AggregateFunctionNullVariadic(AggregateFunctionPtr nested_function_, const DataTypes & arguments)
: AggregateFunctionNullBase<result_is_nullable, AggregateFunctionNullVariadic<result_is_nullable>>(std::move(nested_function_)),
AggregateFunctionNullVariadic(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)
: AggregateFunctionNullBase<result_is_nullable, AggregateFunctionNullVariadic<result_is_nullable>>(std::move(nested_function_), arguments, params),
number_of_arguments(arguments.size())
{
if (number_of_arguments == 1)

View File

@ -73,11 +73,12 @@ private:
/// Used when there are single level to get.
Float64 level = 0.5;
DataTypePtr argument_type;
DataTypePtr & argument_type;
public:
AggregateFunctionQuantile(const DataTypePtr & argument_type, const Array & params)
: levels(params, returns_many), level(levels.levels[0]), argument_type(argument_type)
: IAggregateFunctionDataHelper<Data, AggregateFunctionQuantile<Value, Data, Name, has_second_arg, FloatReturnType, returns_many>>({argument_type}, params)
, levels(params, returns_many), level(levels.levels[0]), argument_type(argument_types[0])
{
if (!returns_many && levels.size() > 1)
throw Exception("Aggregate function " + getName() + " require one parameter or less", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

View File

@ -76,6 +76,7 @@ public:
}
AggregateFunctionRetention(const DataTypes & arguments)
: IAggregateFunctionDataHelper<AggregateFunctionRetentionData, AggregateFunctionRetention>(arguments, {})
{
for (const auto i : ext::range(0, arguments.size()))
{

View File

@ -19,7 +19,7 @@ AggregateFunctionPtr createAggregateFunctionSequenceCount(const std::string & na
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH};
String pattern = params.front().safeGet<std::string>();
return std::make_shared<AggregateFunctionSequenceCount>(argument_types, pattern);
return std::make_shared<AggregateFunctionSequenceCount>(argument_types, params, pattern);
}
AggregateFunctionPtr createAggregateFunctionSequenceMatch(const std::string & name, const DataTypes & argument_types, const Array & params)
@ -29,7 +29,7 @@ AggregateFunctionPtr createAggregateFunctionSequenceMatch(const std::string & na
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH};
String pattern = params.front().safeGet<std::string>();
return std::make_shared<AggregateFunctionSequenceMatch>(argument_types, pattern);
return std::make_shared<AggregateFunctionSequenceMatch>(argument_types, params, pattern);
}
}

View File

@ -139,8 +139,9 @@ template <typename Derived>
class AggregateFunctionSequenceBase : public IAggregateFunctionDataHelper<AggregateFunctionSequenceMatchData, Derived>
{
public:
AggregateFunctionSequenceBase(const DataTypes & arguments, const String & pattern)
: pattern(pattern)
AggregateFunctionSequenceBase(const DataTypes & arguments, const Array & params, const String & pattern)
: IAggregateFunctionDataHelper<AggregateFunctionSequenceMatchData, Derived>(arguments, params)
, pattern(pattern)
{
arg_count = arguments.size();
@ -578,6 +579,9 @@ private:
class AggregateFunctionSequenceMatch final : public AggregateFunctionSequenceBase<AggregateFunctionSequenceMatch>
{
public:
AggregateFunctionSequenceMatch(const DataTypes & arguments, const Array & params, const String & pattern)
: AggregateFunctionSequenceBase<AggregateFunctionSequenceMatch>(arguments, params, pattern) {}
using AggregateFunctionSequenceBase<AggregateFunctionSequenceMatch>::AggregateFunctionSequenceBase;
String getName() const override { return "sequenceMatch"; }
@ -603,6 +607,9 @@ public:
class AggregateFunctionSequenceCount final : public AggregateFunctionSequenceBase<AggregateFunctionSequenceCount>
{
public:
AggregateFunctionSequenceCount(const DataTypes & arguments, const Array & params, const String & pattern)
: AggregateFunctionSequenceBase<AggregateFunctionSequenceCount>(arguments, params, pattern) {}
using AggregateFunctionSequenceBase<AggregateFunctionSequenceCount>::AggregateFunctionSequenceBase;
String getName() const override { return "sequenceCount"; }

View File

@ -24,7 +24,8 @@ private:
public:
AggregateFunctionState(AggregateFunctionPtr nested, const DataTypes & arguments, const Array & params)
: nested_func(nested), arguments(arguments), params(params) {}
: IAggregateFunctionHelper<AggregateFunctionState>(arguments, params)
, nested_func(nested), arguments(arguments), params(params) {}
String getName() const override
{

View File

@ -21,7 +21,7 @@ AggregateFunctionPtr createAggregateFunctionStatisticsUnary(const std::string &
assertNoParameters(name, parameters);
assertUnary(name, argument_types);
AggregateFunctionPtr res(createWithNumericType<FunctionTemplate>(*argument_types[0]));
AggregateFunctionPtr res(createWithNumericType<FunctionTemplate>(*argument_types[0], argument_types[0]));
if (!res)
throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
@ -35,7 +35,7 @@ AggregateFunctionPtr createAggregateFunctionStatisticsBinary(const std::string &
assertNoParameters(name, parameters);
assertBinary(name, argument_types);
AggregateFunctionPtr res(createWithTwoNumericTypes<FunctionTemplate>(*argument_types[0], *argument_types[1]));
AggregateFunctionPtr res(createWithTwoNumericTypes<FunctionTemplate>(*argument_types[0], *argument_types[1], argument_types));
if (!res)
throw Exception("Illegal types " + argument_types[0]->getName() + " and " + argument_types[1]->getName()
+ " of arguments for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

View File

@ -111,6 +111,9 @@ class AggregateFunctionVariance final
: public IAggregateFunctionDataHelper<AggregateFunctionVarianceData<T, Op>, AggregateFunctionVariance<T, Op>>
{
public:
AggregateFunctionVariance(const DataTypePtr & arg)
: IAggregateFunctionDataHelper<AggregateFunctionVarianceData<T, Op>, AggregateFunctionVariance<T, Op>>({arg}, {}) {}
String getName() const override { return Op::name; }
DataTypePtr getReturnType() const override
@ -361,6 +364,10 @@ class AggregateFunctionCovariance final
AggregateFunctionCovariance<T, U, Op, compute_marginal_moments>>
{
public:
AggregateFunctionCovariance(const DataTypes & args) : IAggregateFunctionDataHelper<
CovarianceData<T, U, Op, compute_marginal_moments>,
AggregateFunctionCovariance<T, U, Op, compute_marginal_moments>>(args, {}) {}
String getName() const override { return Op::name; }
DataTypePtr getReturnType() const override

View File

@ -288,12 +288,14 @@ public:
using ResultType = typename StatFunc::ResultType;
using ColVecResult = ColumnVector<ResultType>;
AggregateFunctionVarianceSimple()
: src_scale(0)
AggregateFunctionVarianceSimple(const DataTypes & argument_types)
: IAggregateFunctionDataHelper<typename StatFunc::Data, AggregateFunctionVarianceSimple<StatFunc>>(argument_types, {})
, src_scale(0)
{}
AggregateFunctionVarianceSimple(const IDataType & data_type)
: src_scale(getDecimalScale(data_type))
AggregateFunctionVarianceSimple(const IDataType & data_type, const DataTypes & argument_types)
: IAggregateFunctionDataHelper<typename StatFunc::Data, AggregateFunctionVarianceSimple<StatFunc>>(argument_types, {})
, src_scale(getDecimalScale(data_type))
{}
String getName() const override

View File

@ -50,9 +50,9 @@ AggregateFunctionPtr createAggregateFunctionSum(const std::string & name, const
AggregateFunctionPtr res;
DataTypePtr data_type = argument_types[0];
if (isDecimal(data_type))
res.reset(createWithDecimalType<Function>(*data_type, *data_type));
res.reset(createWithDecimalType<Function>(*data_type, *data_type, argument_types));
else
res.reset(createWithNumericType<Function>(*data_type));
res.reset(createWithNumericType<Function>(*data_type, argument_types));
if (!res)
throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name,

View File

@ -102,12 +102,14 @@ public:
String getName() const override { return "sum"; }
AggregateFunctionSum()
: scale(0)
AggregateFunctionSum(const DataTypes & argument_types)
: IAggregateFunctionDataHelper<Data, AggregateFunctionSum<T, TResult, Data>>(argument_types, {})
, scale(0)
{}
AggregateFunctionSum(const IDataType & data_type)
: scale(getDecimalScale(data_type))
AggregateFunctionSum(const IDataType & data_type, const DataTypes & argument_types)
: IAggregateFunctionDataHelper<Data, AggregateFunctionSum<T, TResult, Data>>(argument_types, {})
, scale(getDecimalScale(data_type))
{}
DataTypePtr getReturnType() const override

View File

@ -80,7 +80,7 @@ AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, con
auto [keys_type, values_types] = parseArguments(name, arguments);
AggregateFunctionPtr res(createWithNumericBasedType<Function>(*keys_type, keys_type, values_types));
AggregateFunctionPtr res(createWithNumericBasedType<Function>(*keys_type, keys_type, values_types, arguments));
if (!res)
res.reset(createWithDecimalType<Function>(*keys_type, keys_type, values_types));
if (!res)
@ -103,7 +103,7 @@ AggregateFunctionPtr createAggregateFunctionSumMapFiltered(const std::string & n
auto [keys_type, values_types] = parseArguments(name, arguments);
AggregateFunctionPtr res(createWithNumericBasedType<Function>(*keys_type, keys_type, values_types, keys_to_keep));
AggregateFunctionPtr res(createWithNumericBasedType<Function>(*keys_type, keys_type, values_types, keys_to_keep, arguments, params));
if (!res)
res.reset(createWithDecimalType<Function>(*keys_type, keys_type, values_types, keys_to_keep));
if (!res)

View File

@ -61,8 +61,11 @@ private:
DataTypes values_types;
public:
AggregateFunctionSumMapBase(const DataTypePtr & keys_type, const DataTypes & values_types)
: keys_type(keys_type), values_types(values_types) {}
AggregateFunctionSumMapBase(
const DataTypePtr & keys_type, const DataTypes & values_types,
const DataTypes & argument_types, const Array & params)
: IAggregateFunctionDataHelper<AggregateFunctionSumMapData<NearestFieldType<T>>, Derived>(argument_types, params)
, keys_type(keys_type), values_types(values_types) {}
String getName() const override { return "sumMap"; }
@ -271,8 +274,8 @@ private:
using Base = AggregateFunctionSumMapBase<T, Self, OverflowPolicy>;
public:
AggregateFunctionSumMap(const DataTypePtr & keys_type, DataTypes & values_types)
: Base{keys_type, values_types}
AggregateFunctionSumMap(const DataTypePtr & keys_type, DataTypes & values_types, const DataTypes & argument_types)
: Base{keys_type, values_types, argument_types, {}}
{}
String getName() const override { return "sumMap"; }
@ -291,8 +294,10 @@ private:
std::unordered_set<T> keys_to_keep;
public:
AggregateFunctionSumMapFiltered(const DataTypePtr & keys_type, const DataTypes & values_types, const Array & keys_to_keep_)
: Base{keys_type, values_types}
AggregateFunctionSumMapFiltered(
const DataTypePtr & keys_type, const DataTypes & values_types, const Array & keys_to_keep_,
const DataTypes & argument_types, const Array & params)
: Base{keys_type, values_types, argument_types, params}
{
keys_to_keep.reserve(keys_to_keep_.size());
for (const Field & f : keys_to_keep_)

View File

@ -39,19 +39,19 @@ class AggregateFunctionTopKDateTime : public AggregateFunctionTopK<DataTypeDateT
template <bool is_weighted>
static IAggregateFunction * createWithExtraTypes(const DataTypePtr & argument_type, UInt64 threshold)
static IAggregateFunction * createWithExtraTypes(const DataTypePtr & argument_type, UInt64 threshold, const Array & params)
{
WhichDataType which(argument_type);
if (which.idx == TypeIndex::Date)
return new AggregateFunctionTopKDate<is_weighted>(threshold);
return new AggregateFunctionTopKDate<is_weighted>(threshold, {argument_type}, params);
if (which.idx == TypeIndex::DateTime)
return new AggregateFunctionTopKDateTime<is_weighted>(threshold);
return new AggregateFunctionTopKDateTime<is_weighted>(threshold, {argument_type}, params);
/// Check that we can use plain version of AggregateFunctionTopKGeneric
if (argument_type->isValueUnambiguouslyRepresentedInContiguousMemoryRegion())
return new AggregateFunctionTopKGeneric<true, is_weighted>(threshold, argument_type);
return new AggregateFunctionTopKGeneric<true, is_weighted>(threshold, argument_type, params);
else
return new AggregateFunctionTopKGeneric<false, is_weighted>(threshold, argument_type);
return new AggregateFunctionTopKGeneric<false, is_weighted>(threshold, argument_type, params);
}
@ -90,10 +90,10 @@ AggregateFunctionPtr createAggregateFunctionTopK(const std::string & name, const
threshold = k;
}
AggregateFunctionPtr res(createWithNumericType<AggregateFunctionTopK, is_weighted>(*argument_types[0], threshold));
AggregateFunctionPtr res(createWithNumericType<AggregateFunctionTopK, is_weighted>(*argument_types[0], threshold, argument_types, params));
if (!res)
res = AggregateFunctionPtr(createWithExtraTypes<is_weighted>(argument_types[0], threshold));
res = AggregateFunctionPtr(createWithExtraTypes<is_weighted>(argument_types[0], threshold, params));
if (!res)
throw Exception("Illegal type " + argument_types[0]->getName() +

View File

@ -48,8 +48,9 @@ protected:
UInt64 reserved;
public:
AggregateFunctionTopK(UInt64 threshold)
: threshold(threshold), reserved(TOP_K_LOAD_FACTOR * threshold) {}
AggregateFunctionTopK(UInt64 threshold, const DataTypes & argument_types, const Array & params)
: IAggregateFunctionDataHelper<AggregateFunctionTopKData<T>, AggregateFunctionTopK<T, is_weighted>>(argument_types, params)
, threshold(threshold), reserved(TOP_K_LOAD_FACTOR * threshold) {}
String getName() const override { return is_weighted ? "topKWeighted" : "topK"; }
@ -136,13 +137,15 @@ private:
UInt64 threshold;
UInt64 reserved;
DataTypePtr input_data_type;
DataTypePtr & input_data_type;
static void deserializeAndInsert(StringRef str, IColumn & data_to);
public:
AggregateFunctionTopKGeneric(UInt64 threshold, const DataTypePtr & input_data_type)
: threshold(threshold), reserved(TOP_K_LOAD_FACTOR * threshold), input_data_type(input_data_type) {}
AggregateFunctionTopKGeneric(
UInt64 threshold, const DataTypePtr & input_data_type, const Array & params)
: IAggregateFunctionDataHelper<AggregateFunctionTopKGenericData, AggregateFunctionTopKGeneric<is_plain_column, is_weighted>>({input_data_type}, params)
, threshold(threshold), reserved(TOP_K_LOAD_FACTOR * threshold), input_data_type(argument_types[0]) {}
String getName() const override { return is_weighted ? "topKWeighted" : "topK"; }

View File

@ -43,19 +43,19 @@ AggregateFunctionPtr createAggregateFunctionUniq(const std::string & name, const
{
const IDataType & argument_type = *argument_types[0];
AggregateFunctionPtr res(createWithNumericType<AggregateFunctionUniq, Data>(*argument_types[0]));
AggregateFunctionPtr res(createWithNumericType<AggregateFunctionUniq, Data>(*argument_types[0], argument_types));
WhichDataType which(argument_type);
if (res)
return res;
else if (which.isDate())
return std::make_shared<AggregateFunctionUniq<DataTypeDate::FieldType, Data>>();
return std::make_shared<AggregateFunctionUniq<DataTypeDate::FieldType, Data>>(argument_types);
else if (which.isDateTime())
return std::make_shared<AggregateFunctionUniq<DataTypeDateTime::FieldType, Data>>();
return std::make_shared<AggregateFunctionUniq<DataTypeDateTime::FieldType, Data>>(argument_types);
else if (which.isStringOrFixedString())
return std::make_shared<AggregateFunctionUniq<String, Data>>();
return std::make_shared<AggregateFunctionUniq<String, Data>>(argument_types);
else if (which.isUUID())
return std::make_shared<AggregateFunctionUniq<DataTypeUUID::FieldType, Data>>();
return std::make_shared<AggregateFunctionUniq<DataTypeUUID::FieldType, Data>>(argument_types);
else if (which.isTuple())
{
if (use_exact_hash_function)
@ -89,19 +89,19 @@ AggregateFunctionPtr createAggregateFunctionUniq(const std::string & name, const
{
const IDataType & argument_type = *argument_types[0];
AggregateFunctionPtr res(createWithNumericType<AggregateFunctionUniq, Data>(*argument_types[0]));
AggregateFunctionPtr res(createWithNumericType<AggregateFunctionUniq, Data>(*argument_types[0], argument_types));
WhichDataType which(argument_type);
if (res)
return res;
else if (which.isDate())
return std::make_shared<AggregateFunctionUniq<DataTypeDate::FieldType, Data<DataTypeDate::FieldType>>>();
return std::make_shared<AggregateFunctionUniq<DataTypeDate::FieldType, Data<DataTypeDate::FieldType>>>(argument_types);
else if (which.isDateTime())
return std::make_shared<AggregateFunctionUniq<DataTypeDateTime::FieldType, Data<DataTypeDateTime::FieldType>>>();
return std::make_shared<AggregateFunctionUniq<DataTypeDateTime::FieldType, Data<DataTypeDateTime::FieldType>>>(argument_types);
else if (which.isStringOrFixedString())
return std::make_shared<AggregateFunctionUniq<String, Data<String>>>();
else if (which.isUUID())
return std::make_shared<AggregateFunctionUniq<DataTypeUUID::FieldType, Data<DataTypeUUID::FieldType>>>();
return std::make_shared<AggregateFunctionUniq<DataTypeUUID::FieldType, Data<DataTypeUUID::FieldType>>>(argument_types);
else if (which.isTuple())
{
if (use_exact_hash_function)

View File

@ -209,6 +209,9 @@ template <typename T, typename Data>
class AggregateFunctionUniq final : public IAggregateFunctionDataHelper<Data, AggregateFunctionUniq<T, Data>>
{
public:
AggregateFunctionUniq(const DataTypes & argument_types)
: IAggregateFunctionDataHelper<Data, AggregateFunctionUniq<T, Data>>(argument_types, {}) {}
String getName() const override { return Data::getName(); }
DataTypePtr getReturnType() const override
@ -257,6 +260,7 @@ private:
public:
AggregateFunctionUniqVariadic(const DataTypes & arguments)
: IAggregateFunctionDataHelper<Data, AggregateFunctionUniqVariadic<Data, is_exact, argument_is_tuple>>(arguments)
{
if (argument_is_tuple)
num_args = typeid_cast<const DataTypeTuple &>(*arguments[0]).getElements().size();

View File

@ -28,7 +28,7 @@ namespace
};
template <UInt8 K>
AggregateFunctionPtr createAggregateFunctionWithK(const DataTypes & argument_types)
AggregateFunctionPtr createAggregateFunctionWithK(const DataTypes & argument_types, const Array & params)
{
/// We use exact hash function if the arguments are not contiguous in memory, because only exact hash function has support for this case.
bool use_exact_hash_function = !isAllArgumentsContiguousInMemory(argument_types);
@ -37,33 +37,33 @@ namespace
{
const IDataType & argument_type = *argument_types[0];
AggregateFunctionPtr res(createWithNumericType<WithK<K>::template AggregateFunction>(*argument_types[0]));
AggregateFunctionPtr res(createWithNumericType<WithK<K>::template AggregateFunction>(*argument_types[0], argument_types, params));
WhichDataType which(argument_type);
if (res)
return res;
else if (which.isDate())
return std::make_shared<typename WithK<K>::template AggregateFunction<DataTypeDate::FieldType>>();
return std::make_shared<typename WithK<K>::template AggregateFunction<DataTypeDate::FieldType>>(argument_types, params);
else if (which.isDateTime())
return std::make_shared<typename WithK<K>::template AggregateFunction<DataTypeDateTime::FieldType>>();
return std::make_shared<typename WithK<K>::template AggregateFunction<DataTypeDateTime::FieldType>>(argument_types, params);
else if (which.isStringOrFixedString())
return std::make_shared<typename WithK<K>::template AggregateFunction<String>>();
return std::make_shared<typename WithK<K>::template AggregateFunction<String>>(argument_types, params);
else if (which.isUUID())
return std::make_shared<typename WithK<K>::template AggregateFunction<DataTypeUUID::FieldType>>();
return std::make_shared<typename WithK<K>::template AggregateFunction<DataTypeUUID::FieldType>>(argument_types, params);
else if (which.isTuple())
{
if (use_exact_hash_function)
return std::make_shared<typename WithK<K>::template AggregateFunctionVariadic<true, true>>(argument_types);
return std::make_shared<typename WithK<K>::template AggregateFunctionVariadic<true, true>>(argument_types, params);
else
return std::make_shared<typename WithK<K>::template AggregateFunctionVariadic<false, true>>(argument_types);
return std::make_shared<typename WithK<K>::template AggregateFunctionVariadic<false, true>>(argument_types, params);
}
}
/// "Variadic" method also works as a fallback generic case for a single argument.
if (use_exact_hash_function)
return std::make_shared<typename WithK<K>::template AggregateFunctionVariadic<true, false>>(argument_types);
return std::make_shared<typename WithK<K>::template AggregateFunctionVariadic<true, false>>(argument_types, params);
else
return std::make_shared<typename WithK<K>::template AggregateFunctionVariadic<false, false>>(argument_types);
return std::make_shared<typename WithK<K>::template AggregateFunctionVariadic<false, false>>(argument_types, params);
}
AggregateFunctionPtr createAggregateFunctionUniqCombined(
@ -95,23 +95,23 @@ namespace
switch (precision)
{
case 12:
return createAggregateFunctionWithK<12>(argument_types);
return createAggregateFunctionWithK<12>(argument_types, params);
case 13:
return createAggregateFunctionWithK<13>(argument_types);
return createAggregateFunctionWithK<13>(argument_types, params);
case 14:
return createAggregateFunctionWithK<14>(argument_types);
return createAggregateFunctionWithK<14>(argument_types, params);
case 15:
return createAggregateFunctionWithK<15>(argument_types);
return createAggregateFunctionWithK<15>(argument_types, params);
case 16:
return createAggregateFunctionWithK<16>(argument_types);
return createAggregateFunctionWithK<16>(argument_types, params);
case 17:
return createAggregateFunctionWithK<17>(argument_types);
return createAggregateFunctionWithK<17>(argument_types, params);
case 18:
return createAggregateFunctionWithK<18>(argument_types);
return createAggregateFunctionWithK<18>(argument_types, params);
case 19:
return createAggregateFunctionWithK<19>(argument_types);
return createAggregateFunctionWithK<19>(argument_types, params);
case 20:
return createAggregateFunctionWithK<20>(argument_types);
return createAggregateFunctionWithK<20>(argument_types, params);
}
__builtin_unreachable();

View File

@ -114,6 +114,9 @@ class AggregateFunctionUniqCombined final
: public IAggregateFunctionDataHelper<AggregateFunctionUniqCombinedData<T, K>, AggregateFunctionUniqCombined<T, K>>
{
public:
AggregateFunctionUniqCombined(const DataTypes & argument_types, const Array & params)
: IAggregateFunctionDataHelper<AggregateFunctionUniqCombinedData<T, K>, AggregateFunctionUniqCombined<T, K>>(argument_types, params) {}
String getName() const override
{
return "uniqCombined";
@ -176,7 +179,9 @@ private:
size_t num_args = 0;
public:
explicit AggregateFunctionUniqCombinedVariadic(const DataTypes & arguments)
explicit AggregateFunctionUniqCombinedVariadic(const DataTypes & arguments, const Array & params)
: IAggregateFunctionDataHelper<AggregateFunctionUniqCombinedData<UInt64, K>,
AggregateFunctionUniqCombinedVariadic<is_exact, argument_is_tuple, K>>(arguments, params)
{
if (argument_is_tuple)
num_args = typeid_cast<const DataTypeTuple &>(*arguments[0]).getElements().size();

View File

@ -52,33 +52,33 @@ AggregateFunctionPtr createAggregateFunctionUniqUpTo(const std::string & name, c
{
const IDataType & argument_type = *argument_types[0];
AggregateFunctionPtr res(createWithNumericType<AggregateFunctionUniqUpTo>(*argument_types[0], threshold));
AggregateFunctionPtr res(createWithNumericType<AggregateFunctionUniqUpTo>(*argument_types[0], threshold, argument_types, params));
WhichDataType which(argument_type);
if (res)
return res;
else if (which.isDate())
return std::make_shared<AggregateFunctionUniqUpTo<DataTypeDate::FieldType>>(threshold);
return std::make_shared<AggregateFunctionUniqUpTo<DataTypeDate::FieldType>>(threshold, argument_types, params);
else if (which.isDateTime())
return std::make_shared<AggregateFunctionUniqUpTo<DataTypeDateTime::FieldType>>(threshold);
return std::make_shared<AggregateFunctionUniqUpTo<DataTypeDateTime::FieldType>>(threshold, argument_types, params);
else if (which.isStringOrFixedString())
return std::make_shared<AggregateFunctionUniqUpTo<String>>(threshold);
return std::make_shared<AggregateFunctionUniqUpTo<String>>(threshold, argument_types, params);
else if (which.isUUID())
return std::make_shared<AggregateFunctionUniqUpTo<DataTypeUUID::FieldType>>(threshold);
return std::make_shared<AggregateFunctionUniqUpTo<DataTypeUUID::FieldType>>(threshold, argument_types, params);
else if (which.isTuple())
{
if (use_exact_hash_function)
return std::make_shared<AggregateFunctionUniqUpToVariadic<true, true>>(argument_types, threshold);
return std::make_shared<AggregateFunctionUniqUpToVariadic<true, true>>(argument_types, params, threshold);
else
return std::make_shared<AggregateFunctionUniqUpToVariadic<false, true>>(argument_types, threshold);
return std::make_shared<AggregateFunctionUniqUpToVariadic<false, true>>(argument_types, params, threshold);
}
}
/// "Variadic" method also works as a fallback generic case for single argument.
if (use_exact_hash_function)
return std::make_shared<AggregateFunctionUniqUpToVariadic<true, false>>(argument_types, threshold);
return std::make_shared<AggregateFunctionUniqUpToVariadic<true, false>>(argument_types, params, threshold);
else
return std::make_shared<AggregateFunctionUniqUpToVariadic<false, false>>(argument_types, threshold);
return std::make_shared<AggregateFunctionUniqUpToVariadic<false, false>>(argument_types, params, threshold);
}
}

View File

@ -136,8 +136,9 @@ private:
UInt8 threshold;
public:
AggregateFunctionUniqUpTo(UInt8 threshold)
: threshold(threshold)
AggregateFunctionUniqUpTo(UInt8 threshold, const DataTypes & argument_types, const Array & params)
: IAggregateFunctionDataHelper<AggregateFunctionUniqUpToData<T>, AggregateFunctionUniqUpTo<T>>(argument_types, params)
, threshold(threshold)
{
}
@ -195,8 +196,9 @@ private:
UInt8 threshold;
public:
AggregateFunctionUniqUpToVariadic(const DataTypes & arguments, UInt8 threshold)
: threshold(threshold)
AggregateFunctionUniqUpToVariadic(const DataTypes & arguments, const Array & params, UInt8 threshold)
: IAggregateFunctionDataHelper<AggregateFunctionUniqUpToData<UInt64>, AggregateFunctionUniqUpToVariadic<is_exact, argument_is_tuple>>(arguments, params)
, threshold(threshold)
{
if (argument_is_tuple)
num_args = typeid_cast<const DataTypeTuple &>(*arguments[0]).getElements().size();

View File

@ -189,6 +189,7 @@ public:
}
AggregateFunctionWindowFunnel(const DataTypes & arguments, const Array & params)
: IAggregateFunctionDataHelper<AggregateFunctionWindowFunnelData, AggregateFunctionWindowFunnel>(arguments, params)
{
const auto time_arg = arguments.front().get();
if (!WhichDataType(time_arg).isDateTime() && !WhichDataType(time_arg).isUInt32())

View File

@ -24,9 +24,9 @@ AggregateFunctionPtr createAggregateFunctionStatisticsUnary(const std::string &
AggregateFunctionPtr res;
DataTypePtr data_type = argument_types[0];
if (isDecimal(data_type))
res.reset(createWithDecimalType<FunctionTemplate>(*data_type, *data_type));
res.reset(createWithDecimalType<FunctionTemplate>(*data_type, *data_type, argument_types));
else
res.reset(createWithNumericType<FunctionTemplate>(*data_type));
res.reset(createWithNumericType<FunctionTemplate>(*data_type, argument_types));
if (!res)
throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name,
@ -40,7 +40,7 @@ AggregateFunctionPtr createAggregateFunctionStatisticsBinary(const std::string &
assertNoParameters(name, parameters);
assertBinary(name, argument_types);
AggregateFunctionPtr res(createWithTwoNumericTypes<FunctionTemplate>(*argument_types[0], *argument_types[1]));
AggregateFunctionPtr res(createWithTwoNumericTypes<FunctionTemplate>(*argument_types[0], *argument_types[1], argument_types));
if (!res)
throw Exception("Illegal types " + argument_types[0]->getName() + " and " + argument_types[1]->getName()
+ " of arguments for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

View File

@ -37,6 +37,9 @@ using ConstAggregateDataPtr = const char *;
class IAggregateFunction
{
public:
IAggregateFunction(const DataTypes & argument_types_, const Array & parameters_)
: argument_types(argument_types_), parameters(parameters_) {}
/// Get main function name.
virtual String getName() const = 0;
@ -112,17 +115,9 @@ public:
const DataTypes & getArgumentTypes() const { return argument_types; }
const Array & getParameters() const { return parameters; }
private:
protected:
DataTypes argument_types;
Array parameters;
friend class AggregateFunctionFactory;
void setArguments(DataTypes argument_types_, Array parameters_)
{
argument_types = std::move(argument_types_);
parameters = std::move(parameters_);
}
};
@ -137,6 +132,8 @@ private:
}
public:
IAggregateFunctionHelper(const DataTypes & argument_types_, const Array & parameters_)
: IAggregateFunction(argument_types_, parameters_) {}
AddFunc getAddressOfAddFunction() const override { return &addFree; }
};
@ -152,6 +149,10 @@ protected:
static const Data & data(ConstAggregateDataPtr place) { return *reinterpret_cast<const Data*>(place); }
public:
IAggregateFunctionDataHelper(const DataTypes & argument_types_, const Array & parameters_)
: IAggregateFunctionHelper<Derived>(argument_types_, parameters_) {}
void create(AggregateDataPtr place) const override
{
new (place) Data;