diff --git a/dbms/src/AggregateFunctions/AggregateFunctionArgMinMax.h b/dbms/src/AggregateFunctions/AggregateFunctionArgMinMax.h index 9a232e2e77d..42649be78fd 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionArgMinMax.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionArgMinMax.h @@ -31,12 +31,13 @@ template class AggregateFunctionArgMinMax final : public IAggregateFunctionDataHelper> { private: - DataTypePtr type_res; - DataTypePtr type_val; + const DataTypePtr & type_res; + const DataTypePtr & type_val; public: AggregateFunctionArgMinMax(const DataTypePtr & type_res, const DataTypePtr & type_val) - : type_res(type_res), type_val(type_val) + : IAggregateFunctionDataHelper>({type_res, type_val}, {}), + type_res(argument_types[0]), type_val(argument_types[1]) { if (!type_val->isComparable()) throw Exception("Illegal type " + type_val->getName() + " of second argument of aggregate function " + getName() diff --git a/dbms/src/AggregateFunctions/AggregateFunctionArray.h b/dbms/src/AggregateFunctions/AggregateFunctionArray.h index 5dfebf13d52..08fa7c13bc3 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionArray.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionArray.h @@ -28,7 +28,8 @@ private: public: AggregateFunctionArray(AggregateFunctionPtr nested_, const DataTypes & arguments) - : nested_func(nested_), num_arguments(arguments.size()) + : IAggregateFunctionHelper(arguments, {}) + , nested_func(nested_), num_arguments(arguments.size()) { for (const auto & type : arguments) if (!isArray(type)) diff --git a/dbms/src/AggregateFunctions/AggregateFunctionAvg.h b/dbms/src/AggregateFunctions/AggregateFunctionAvg.h index 53b42c42c9a..98604f76742 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionAvg.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionAvg.h @@ -55,7 +55,8 @@ public: /// ctor for Decimals AggregateFunctionAvg(const IDataType & data_type) - : scale(getDecimalScale(data_type)) + : IAggregateFunctionDataHelper>({data_type}, {}) + , scale(getDecimalScale(data_type)) {} String getName() const override { return "avg"; } diff --git a/dbms/src/AggregateFunctions/AggregateFunctionBitwise.cpp b/dbms/src/AggregateFunctions/AggregateFunctionBitwise.cpp index 8c188bcbb8e..e92e1917bd5 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionBitwise.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionBitwise.cpp @@ -21,7 +21,7 @@ AggregateFunctionPtr createAggregateFunctionBitwise(const std::string & name, co + " is illegal, because it cannot be used in bitwise operations", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); - AggregateFunctionPtr res(createWithUnsignedIntegerType(*argument_types[0])); + AggregateFunctionPtr res(createWithUnsignedIntegerType(*argument_types[0], argument_types[0])); if (!res) throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionBitwise.h b/dbms/src/AggregateFunctions/AggregateFunctionBitwise.h index 6d33f010bd0..2788fdccd51 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionBitwise.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionBitwise.h @@ -43,6 +43,9 @@ template class AggregateFunctionBitwise final : public IAggregateFunctionDataHelper> { public: + AggregateFunctionBitwise(const DataTypePtr & type) + : IAggregateFunctionDataHelper>({type}, {}) {} + String getName() const override { return Data::name(); } DataTypePtr getReturnType() const override diff --git a/dbms/src/AggregateFunctions/AggregateFunctionBoundingRatio.h b/dbms/src/AggregateFunctions/AggregateFunctionBoundingRatio.h index 40b13acbbaa..5966993dc65 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionBoundingRatio.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionBoundingRatio.h @@ -111,6 +111,7 @@ public: } AggregateFunctionBoundingRatio(const DataTypes & arguments) + : IAggregateFunctionDataHelper(arguments, {}) { const auto x_arg = arguments.at(0).get(); const auto y_arg = arguments.at(0).get(); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionCount.cpp b/dbms/src/AggregateFunctions/AggregateFunctionCount.cpp index 1df424ecbf2..02dc796a4cf 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionCount.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionCount.cpp @@ -9,12 +9,12 @@ namespace DB namespace { -AggregateFunctionPtr createAggregateFunctionCount(const std::string & name, const DataTypes & /*argument_types*/, const Array & parameters) +AggregateFunctionPtr createAggregateFunctionCount(const std::string & name, const DataTypes & argument_types, const Array & parameters) { assertNoParameters(name, parameters); /// 'count' accept any number of arguments and (in this case of non-Nullable types) simply ignore them. - return std::make_shared(); + return std::make_shared(argument_types); } } diff --git a/dbms/src/AggregateFunctions/AggregateFunctionCount.h b/dbms/src/AggregateFunctions/AggregateFunctionCount.h index f9a1dcb45e2..82958a95fd2 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionCount.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionCount.h @@ -28,6 +28,8 @@ namespace ErrorCodes class AggregateFunctionCount final : public IAggregateFunctionDataHelper { public: + AggregateFunctionCount(const DataTypes & argument_types) : IAggregateFunctionDataHelper(argument_types, {}) {} + String getName() const override { return "count"; } DataTypePtr getReturnType() const override @@ -74,7 +76,8 @@ public: class AggregateFunctionCountNotNullUnary final : public IAggregateFunctionDataHelper { public: - AggregateFunctionCountNotNullUnary(const DataTypePtr & argument) + AggregateFunctionCountNotNullUnary(const DataTypePtr & argument, const Array & params) + : IAggregateFunctionDataHelper({argument}, params) { if (!argument->isNullable()) throw Exception("Logical error: not Nullable data type passed to AggregateFunctionCountNotNullUnary", ErrorCodes::LOGICAL_ERROR); @@ -120,7 +123,8 @@ public: class AggregateFunctionCountNotNullVariadic final : public IAggregateFunctionDataHelper { public: - AggregateFunctionCountNotNullVariadic(const DataTypes & arguments) + AggregateFunctionCountNotNullVariadic(const DataTypes & arguments, const Array & params) + : IAggregateFunctionDataHelper(arguments, params) { number_of_arguments = arguments.size(); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionEntropy.cpp b/dbms/src/AggregateFunctions/AggregateFunctionEntropy.cpp index 2f9910c97de..7ea15e11b72 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionEntropy.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionEntropy.cpp @@ -26,12 +26,12 @@ AggregateFunctionPtr createAggregateFunctionEntropy(const std::string & name, co if (num_args == 1) { /// Specialized implementation for single argument of numeric type. - if (auto res = createWithNumericBasedType(*argument_types[0], num_args)) + if (auto res = createWithNumericBasedType(*argument_types[0], argument_types)) return AggregateFunctionPtr(res); } /// Generic implementation for other types or for multiple arguments. - return std::make_shared>(num_args); + return std::make_shared>(argument_types); } } diff --git a/dbms/src/AggregateFunctions/AggregateFunctionEntropy.h b/dbms/src/AggregateFunctions/AggregateFunctionEntropy.h index 1adeefc6397..91ec6d4d5a6 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionEntropy.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionEntropy.h @@ -97,7 +97,9 @@ private: size_t num_args; public: - AggregateFunctionEntropy(size_t num_args) : num_args(num_args) + AggregateFunctionEntropy(const DataTypes & argument_types) + : IAggregateFunctionDataHelper, AggregateFunctionEntropy>(argument_types, {}) + , num_args(argument_types.size()) { } diff --git a/dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp b/dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp index 932d6615385..6aeaaef2bfa 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp @@ -86,17 +86,12 @@ AggregateFunctionPtr AggregateFunctionFactory::get( [](const auto & type) { return type->onlyNull(); })) nested_function = getImpl(name, nested_types, parameters, recursion_level); - auto res = combinator->transformAggregateFunction(nested_function, type_without_low_cardinality, parameters); - res->setArguments(type_without_low_cardinality, parameters); - return res; + return combinator->transformAggregateFunction(nested_function, argument_types, parameters); } auto res = getImpl(name, type_without_low_cardinality, parameters, recursion_level); if (!res) throw Exception("Logical error: AggregateFunctionFactory returned nullptr", ErrorCodes::LOGICAL_ERROR); - - res->setArguments(type_without_low_cardinality, parameters); - return res; } diff --git a/dbms/src/AggregateFunctions/AggregateFunctionForEach.h b/dbms/src/AggregateFunctions/AggregateFunctionForEach.h index 519d1911a8a..39a52a7fa6e 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionForEach.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionForEach.h @@ -97,7 +97,8 @@ private: public: AggregateFunctionForEach(AggregateFunctionPtr nested_, const DataTypes & arguments) - : nested_func(nested_), num_arguments(arguments.size()) + : IAggregateFunctionDataHelper(arguments, {}) + , nested_func(nested_), num_arguments(arguments.size()) { nested_size_of_data = nested_func->sizeOfData(); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionGroupArray.h b/dbms/src/AggregateFunctions/AggregateFunctionGroupArray.h index 26708c87520..c496e90844d 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionGroupArray.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionGroupArray.h @@ -48,12 +48,13 @@ class GroupArrayNumericImpl final : public IAggregateFunctionDataHelper, GroupArrayNumericImpl> { static constexpr bool limit_num_elems = Tlimit_num_elems::value; - DataTypePtr data_type; + DataTypePtr & data_type; UInt64 max_elems; public: explicit GroupArrayNumericImpl(const DataTypePtr & data_type_, UInt64 max_elems_ = std::numeric_limits::max()) - : data_type(data_type_), max_elems(max_elems_) {} + : IAggregateFunctionDataHelper, GroupArrayNumericImpl>({data_type}, {}) + , data_type(argument_types[0]), max_elems(max_elems_) {} String getName() const override { return "groupArray"; } diff --git a/dbms/src/AggregateFunctions/AggregateFunctionGroupArrayInsertAt.cpp b/dbms/src/AggregateFunctions/AggregateFunctionGroupArrayInsertAt.cpp index bc8fac86d6d..ea42c129dea 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionGroupArrayInsertAt.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionGroupArrayInsertAt.cpp @@ -13,6 +13,10 @@ namespace AggregateFunctionPtr createAggregateFunctionGroupArrayInsertAt(const std::string & name, const DataTypes & argument_types, const Array & parameters) { assertBinary(name, argument_types); + + if (argument_types.size() != 2) + throw Exception("Aggregate function groupArrayInsertAt requires two arguments.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + return std::make_shared(argument_types, parameters); } diff --git a/dbms/src/AggregateFunctions/AggregateFunctionGroupArrayInsertAt.h b/dbms/src/AggregateFunctions/AggregateFunctionGroupArrayInsertAt.h index 90b19266e4c..c7dab21a4cb 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionGroupArrayInsertAt.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionGroupArrayInsertAt.h @@ -54,12 +54,14 @@ class AggregateFunctionGroupArrayInsertAtGeneric final : public IAggregateFunctionDataHelper { private: - DataTypePtr type; + DataTypePtr & type; Field default_value; UInt64 length_to_resize = 0; /// zero means - do not do resizing. public: AggregateFunctionGroupArrayInsertAtGeneric(const DataTypes & arguments, const Array & params) + : IAggregateFunctionDataHelper(arguments, params) + , type(argument_types[0]) { if (!params.empty()) { @@ -76,14 +78,9 @@ public: } } - if (arguments.size() != 2) - throw Exception("Aggregate function " + getName() + " requires two arguments.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); - if (!isUnsignedInteger(arguments[1])) throw Exception("Second argument of aggregate function " + getName() + " must be integer.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); - type = arguments.front(); - if (default_value.isNull()) default_value = type->getDefault(); else diff --git a/dbms/src/AggregateFunctions/AggregateFunctionGroupUniqArray.cpp b/dbms/src/AggregateFunctions/AggregateFunctionGroupUniqArray.cpp index a84ba2b28a2..f80a45afaa9 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionGroupUniqArray.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionGroupUniqArray.cpp @@ -15,11 +15,15 @@ namespace /// Substitute return type for Date and DateTime class AggregateFunctionGroupUniqArrayDate : public AggregateFunctionGroupUniqArray { +public: + AggregateFunctionGroupUniqArrayDate(const DataTypePtr & argument_type) : AggregateFunctionGroupUniqArray(argument_type) {} DataTypePtr getReturnType() const override { return std::make_shared(std::make_shared()); } }; class AggregateFunctionGroupUniqArrayDateTime : public AggregateFunctionGroupUniqArray { +public: + AggregateFunctionGroupUniqArrayDateTime(const DataTypePtr & argument_type) : AggregateFunctionGroupUniqArray(argument_type) {} DataTypePtr getReturnType() const override { return std::make_shared(std::make_shared()); } }; @@ -27,8 +31,8 @@ class AggregateFunctionGroupUniqArrayDateTime : public AggregateFunctionGroupUni static IAggregateFunction * createWithExtraTypes(const DataTypePtr & argument_type) { WhichDataType which(argument_type); - if (which.idx == TypeIndex::Date) return new AggregateFunctionGroupUniqArrayDate; - else if (which.idx == TypeIndex::DateTime) return new AggregateFunctionGroupUniqArrayDateTime; + if (which.idx == TypeIndex::Date) return new AggregateFunctionGroupUniqArrayDate(argument_type); + else if (which.idx == TypeIndex::DateTime) return new AggregateFunctionGroupUniqArrayDateTime(argument_type); else { /// Check that we can use plain version of AggreagteFunctionGroupUniqArrayGeneric diff --git a/dbms/src/AggregateFunctions/AggregateFunctionGroupUniqArray.h b/dbms/src/AggregateFunctions/AggregateFunctionGroupUniqArray.h index b638996f553..c0ef1fe0fa8 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionGroupUniqArray.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionGroupUniqArray.h @@ -44,6 +44,9 @@ private: using State = AggregateFunctionGroupUniqArrayData; public: + AggregateFunctionGroupUniqArray(const DataTypePtr & argument_type) + : IAggregateFunctionDataHelper, AggregateFunctionGroupUniqArray>({argument_type}, {}) {} + String getName() const override { return "groupUniqArray"; } DataTypePtr getReturnType() const override @@ -115,7 +118,7 @@ template class AggreagteFunctionGroupUniqArrayGeneric : public IAggregateFunctionDataHelper> { - DataTypePtr input_data_type; + DataTypePtr & input_data_type; using State = AggreagteFunctionGroupUniqArrayGenericData; @@ -125,7 +128,8 @@ class AggreagteFunctionGroupUniqArrayGeneric public: AggreagteFunctionGroupUniqArrayGeneric(const DataTypePtr & input_data_type) - : input_data_type(input_data_type) {} + : IAggregateFunctionDataHelper>({input_data_type}, {}) + , input_data_type(argument_types[0]) {} String getName() const override { return "groupUniqArray"; } diff --git a/dbms/src/AggregateFunctions/AggregateFunctionHistogram.cpp b/dbms/src/AggregateFunctions/AggregateFunctionHistogram.cpp index 05c4fe86320..384298b16a8 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionHistogram.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionHistogram.cpp @@ -39,7 +39,7 @@ AggregateFunctionPtr createAggregateFunctionHistogram(const std::string & name, throw Exception("Bin count should be positive", ErrorCodes::BAD_ARGUMENTS); assertUnary(name, arguments); - AggregateFunctionPtr res(createWithNumericType(*arguments[0], bins_count)); + AggregateFunctionPtr res(createWithNumericType(*arguments[0], bins_count, arguments, params)); if (!res) throw Exception("Illegal type " + arguments[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionHistogram.h b/dbms/src/AggregateFunctions/AggregateFunctionHistogram.h index 3d03821cc65..60385f4788a 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionHistogram.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionHistogram.h @@ -304,8 +304,9 @@ private: const UInt32 max_bins; public: - AggregateFunctionHistogram(UInt32 max_bins) - : max_bins(max_bins) + AggregateFunctionHistogram(const DataTypes & arguments, const Array & params, UInt32 max_bins) + : IAggregateFunctionDataHelper>(arguments, params) + , max_bins(max_bins) { } diff --git a/dbms/src/AggregateFunctions/AggregateFunctionIf.h b/dbms/src/AggregateFunctions/AggregateFunctionIf.h index 594193eac87..8daf9505ae6 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionIf.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionIf.h @@ -28,7 +28,8 @@ private: public: AggregateFunctionIf(AggregateFunctionPtr nested, const DataTypes & types) - : nested_func(nested), num_arguments(types.size()) + : IAggregateFunctionHelper(types, nested->getParameters()) + , nested_func(nested), num_arguments(types.size()) { if (num_arguments == 0) throw Exception("Aggregate function " + getName() + " require at least one argument", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMaxIntersections.h b/dbms/src/AggregateFunctions/AggregateFunctionMaxIntersections.h index 9b81ce01f30..dbb727b7d9a 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionMaxIntersections.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionMaxIntersections.h @@ -59,7 +59,7 @@ private: public: AggregateFunctionIntersectionsMax(AggregateFunctionIntersectionsKind kind_, const DataTypes & arguments) - : kind(kind_) + : IAggregateFunctionDataHelper, AggregateFunctionIntersectionsMax>(arguments, {}), kind(kind_) { if (!isNumber(arguments[0])) throw Exception{getName() + ": first argument must be represented by integer", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMerge.cpp b/dbms/src/AggregateFunctions/AggregateFunctionMerge.cpp index 256c7bc9a84..f9c2eb8c9dd 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionMerge.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionMerge.cpp @@ -47,7 +47,7 @@ public: + ", because it corresponds to different aggregate function: " + function->getFunctionName() + " instead of " + nested_function->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); - return std::make_shared(nested_function, *argument); + return std::make_shared(nested_function, argument); } }; diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMerge.h b/dbms/src/AggregateFunctions/AggregateFunctionMerge.h index 2d92db98e17..c94d4d3cf3c 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionMerge.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionMerge.h @@ -22,13 +22,14 @@ private: AggregateFunctionPtr nested_func; public: - AggregateFunctionMerge(const AggregateFunctionPtr & nested_, const IDataType & argument) - : nested_func(nested_) + AggregateFunctionMerge(const AggregateFunctionPtr & nested_, const DataTypePtr & argument) + : IAggregateFunctionHelper({argument}, nested_->getParameters()) + , nested_func(nested_) { - const DataTypeAggregateFunction * data_type = typeid_cast(&argument); + const DataTypeAggregateFunction * data_type = typeid_cast(argument.get()); if (!data_type || data_type->getFunctionName() != nested_func->getName()) - throw Exception("Illegal type " + argument.getName() + " of argument for aggregate function " + getName(), + throw Exception("Illegal type " + argument->getName() + " of argument for aggregate function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); } diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.h b/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.h index 51d1e8d1dd7..426ee8ee479 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.h @@ -676,10 +676,12 @@ template class AggregateFunctionsSingleValue final : public IAggregateFunctionDataHelper> { private: - DataTypePtr type; + DataTypePtr & type; public: - AggregateFunctionsSingleValue(const DataTypePtr & type) : type(type) + AggregateFunctionsSingleValue(const DataTypePtr & type) + : IAggregateFunctionDataHelper>({type}, {}) + , type(argument_types[0]) { if (StringRef(Data::name()) == StringRef("min") || StringRef(Data::name()) == StringRef("max")) diff --git a/dbms/src/AggregateFunctions/AggregateFunctionNothing.h b/dbms/src/AggregateFunctions/AggregateFunctionNothing.h index 3a98807bb4a..aa54d95f158 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionNothing.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionNothing.h @@ -15,6 +15,9 @@ namespace DB class AggregateFunctionNothing final : public IAggregateFunctionHelper { public: + AggregateFunctionNothing(const DataTypes & arguments, const Array & params) + : IAggregateFunctionHelper(arguments, params) {} + String getName() const override { return "nothing"; diff --git a/dbms/src/AggregateFunctions/AggregateFunctionNull.cpp b/dbms/src/AggregateFunctions/AggregateFunctionNull.cpp index 6ce7d94d970..7011ebbde09 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionNull.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionNull.cpp @@ -30,7 +30,7 @@ public: } AggregateFunctionPtr transformAggregateFunction( - const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array &) const override + const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array & params) const override { bool has_nullable_types = false; bool has_null_types = false; @@ -55,29 +55,29 @@ public: if (nested_function && nested_function->getName() == "count") { if (arguments.size() == 1) - return std::make_shared(arguments[0]); + return std::make_shared(arguments[0], params); else - return std::make_shared(arguments); + return std::make_shared(arguments, params); } if (has_null_types) - return std::make_shared(); + return std::make_shared(arguments, params); bool return_type_is_nullable = nested_function->getReturnType()->canBeInsideNullable(); if (arguments.size() == 1) { if (return_type_is_nullable) - return std::make_shared>(nested_function); + return std::make_shared>(nested_function, arguments, params); else - return std::make_shared>(nested_function); + return std::make_shared>(nested_function, arguments, params); } else { if (return_type_is_nullable) - return std::make_shared>(nested_function, arguments); + return std::make_shared>(nested_function, arguments, params); else - return std::make_shared>(nested_function, arguments); + return std::make_shared>(nested_function, arguments, params); } } }; diff --git a/dbms/src/AggregateFunctions/AggregateFunctionNull.h b/dbms/src/AggregateFunctions/AggregateFunctionNull.h index c8676230500..ab4b5b27844 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionNull.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionNull.h @@ -68,8 +68,8 @@ protected: } public: - AggregateFunctionNullBase(AggregateFunctionPtr nested_function_) - : nested_function{nested_function_} + AggregateFunctionNullBase(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params) + : IAggregateFunctionHelper(arguments, params), nested_function{nested_function_} { if (result_is_nullable) prefix_size = nested_function->alignOfData(); @@ -187,8 +187,8 @@ template class AggregateFunctionNullUnary final : public AggregateFunctionNullBase> { public: - AggregateFunctionNullUnary(AggregateFunctionPtr nested_function_) - : AggregateFunctionNullBase>(std::move(nested_function_)) + AggregateFunctionNullUnary(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params) + : AggregateFunctionNullBase>(std::move(nested_function_), arguments, params) { } @@ -209,8 +209,8 @@ template class AggregateFunctionNullVariadic final : public AggregateFunctionNullBase> { public: - AggregateFunctionNullVariadic(AggregateFunctionPtr nested_function_, const DataTypes & arguments) - : AggregateFunctionNullBase>(std::move(nested_function_)), + AggregateFunctionNullVariadic(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params) + : AggregateFunctionNullBase>(std::move(nested_function_), arguments, params), number_of_arguments(arguments.size()) { if (number_of_arguments == 1) diff --git a/dbms/src/AggregateFunctions/AggregateFunctionQuantile.h b/dbms/src/AggregateFunctions/AggregateFunctionQuantile.h index cee2b6fe0c0..a87f520d395 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionQuantile.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionQuantile.h @@ -73,11 +73,12 @@ private: /// Used when there are single level to get. Float64 level = 0.5; - DataTypePtr argument_type; + DataTypePtr & argument_type; public: AggregateFunctionQuantile(const DataTypePtr & argument_type, const Array & params) - : levels(params, returns_many), level(levels.levels[0]), argument_type(argument_type) + : IAggregateFunctionDataHelper>({argument_type}, params) + , levels(params, returns_many), level(levels.levels[0]), argument_type(argument_types[0]) { if (!returns_many && levels.size() > 1) throw Exception("Aggregate function " + getName() + " require one parameter or less", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionRetention.h b/dbms/src/AggregateFunctions/AggregateFunctionRetention.h index 688f7f1404c..525a4d848d2 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionRetention.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionRetention.h @@ -76,6 +76,7 @@ public: } AggregateFunctionRetention(const DataTypes & arguments) + : IAggregateFunctionDataHelper(arguments, {}) { for (const auto i : ext::range(0, arguments.size())) { diff --git a/dbms/src/AggregateFunctions/AggregateFunctionSequenceMatch.cpp b/dbms/src/AggregateFunctions/AggregateFunctionSequenceMatch.cpp index 0b7a4b6b357..be139d9e633 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionSequenceMatch.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionSequenceMatch.cpp @@ -19,7 +19,7 @@ AggregateFunctionPtr createAggregateFunctionSequenceCount(const std::string & na ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH}; String pattern = params.front().safeGet(); - return std::make_shared(argument_types, pattern); + return std::make_shared(argument_types, params, pattern); } AggregateFunctionPtr createAggregateFunctionSequenceMatch(const std::string & name, const DataTypes & argument_types, const Array & params) @@ -29,7 +29,7 @@ AggregateFunctionPtr createAggregateFunctionSequenceMatch(const std::string & na ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH}; String pattern = params.front().safeGet(); - return std::make_shared(argument_types, pattern); + return std::make_shared(argument_types, params, pattern); } } diff --git a/dbms/src/AggregateFunctions/AggregateFunctionSequenceMatch.h b/dbms/src/AggregateFunctions/AggregateFunctionSequenceMatch.h index 86627a453c2..5c443c72b63 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionSequenceMatch.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionSequenceMatch.h @@ -139,8 +139,9 @@ template class AggregateFunctionSequenceBase : public IAggregateFunctionDataHelper { public: - AggregateFunctionSequenceBase(const DataTypes & arguments, const String & pattern) - : pattern(pattern) + AggregateFunctionSequenceBase(const DataTypes & arguments, const Array & params, const String & pattern) + : IAggregateFunctionDataHelper(arguments, params) + , pattern(pattern) { arg_count = arguments.size(); @@ -578,6 +579,9 @@ private: class AggregateFunctionSequenceMatch final : public AggregateFunctionSequenceBase { public: + AggregateFunctionSequenceMatch(const DataTypes & arguments, const Array & params, const String & pattern) + : AggregateFunctionSequenceBase(arguments, params, pattern) {} + using AggregateFunctionSequenceBase::AggregateFunctionSequenceBase; String getName() const override { return "sequenceMatch"; } @@ -603,6 +607,9 @@ public: class AggregateFunctionSequenceCount final : public AggregateFunctionSequenceBase { public: + AggregateFunctionSequenceCount(const DataTypes & arguments, const Array & params, const String & pattern) + : AggregateFunctionSequenceBase(arguments, params, pattern) {} + using AggregateFunctionSequenceBase::AggregateFunctionSequenceBase; String getName() const override { return "sequenceCount"; } diff --git a/dbms/src/AggregateFunctions/AggregateFunctionState.h b/dbms/src/AggregateFunctions/AggregateFunctionState.h index 30755ce3896..2d8e5c6a537 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionState.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionState.h @@ -24,7 +24,8 @@ private: public: AggregateFunctionState(AggregateFunctionPtr nested, const DataTypes & arguments, const Array & params) - : nested_func(nested), arguments(arguments), params(params) {} + : IAggregateFunctionHelper(arguments, params) + , nested_func(nested), arguments(arguments), params(params) {} String getName() const override { diff --git a/dbms/src/AggregateFunctions/AggregateFunctionStatistics.cpp b/dbms/src/AggregateFunctions/AggregateFunctionStatistics.cpp index ae73013d29d..1530ad25cf3 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionStatistics.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionStatistics.cpp @@ -21,7 +21,7 @@ AggregateFunctionPtr createAggregateFunctionStatisticsUnary(const std::string & assertNoParameters(name, parameters); assertUnary(name, argument_types); - AggregateFunctionPtr res(createWithNumericType(*argument_types[0])); + AggregateFunctionPtr res(createWithNumericType(*argument_types[0], argument_types[0])); if (!res) throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); @@ -35,7 +35,7 @@ AggregateFunctionPtr createAggregateFunctionStatisticsBinary(const std::string & assertNoParameters(name, parameters); assertBinary(name, argument_types); - AggregateFunctionPtr res(createWithTwoNumericTypes(*argument_types[0], *argument_types[1])); + AggregateFunctionPtr res(createWithTwoNumericTypes(*argument_types[0], *argument_types[1], argument_types)); if (!res) throw Exception("Illegal types " + argument_types[0]->getName() + " and " + argument_types[1]->getName() + " of arguments for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionStatistics.h b/dbms/src/AggregateFunctions/AggregateFunctionStatistics.h index 82d34fc2954..d1112ec0831 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionStatistics.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionStatistics.h @@ -111,6 +111,9 @@ class AggregateFunctionVariance final : public IAggregateFunctionDataHelper, AggregateFunctionVariance> { public: + AggregateFunctionVariance(const DataTypePtr & arg) + : IAggregateFunctionDataHelper, AggregateFunctionVariance>({arg}, {}) {} + String getName() const override { return Op::name; } DataTypePtr getReturnType() const override @@ -361,6 +364,10 @@ class AggregateFunctionCovariance final AggregateFunctionCovariance> { public: + AggregateFunctionCovariance(const DataTypes & args) : IAggregateFunctionDataHelper< + CovarianceData, + AggregateFunctionCovariance>(args, {}) {} + String getName() const override { return Op::name; } DataTypePtr getReturnType() const override diff --git a/dbms/src/AggregateFunctions/AggregateFunctionStatisticsSimple.h b/dbms/src/AggregateFunctions/AggregateFunctionStatisticsSimple.h index 0580a5131a2..4ab6a4d51ed 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionStatisticsSimple.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionStatisticsSimple.h @@ -288,12 +288,14 @@ public: using ResultType = typename StatFunc::ResultType; using ColVecResult = ColumnVector; - AggregateFunctionVarianceSimple() - : src_scale(0) + AggregateFunctionVarianceSimple(const DataTypes & argument_types) + : IAggregateFunctionDataHelper>(argument_types, {}) + , src_scale(0) {} - AggregateFunctionVarianceSimple(const IDataType & data_type) - : src_scale(getDecimalScale(data_type)) + AggregateFunctionVarianceSimple(const IDataType & data_type, const DataTypes & argument_types) + : IAggregateFunctionDataHelper>(argument_types, {}) + , src_scale(getDecimalScale(data_type)) {} String getName() const override diff --git a/dbms/src/AggregateFunctions/AggregateFunctionSum.cpp b/dbms/src/AggregateFunctions/AggregateFunctionSum.cpp index f21c60eeae6..5e060d7b7df 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionSum.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionSum.cpp @@ -50,9 +50,9 @@ AggregateFunctionPtr createAggregateFunctionSum(const std::string & name, const AggregateFunctionPtr res; DataTypePtr data_type = argument_types[0]; if (isDecimal(data_type)) - res.reset(createWithDecimalType(*data_type, *data_type)); + res.reset(createWithDecimalType(*data_type, *data_type, argument_types)); else - res.reset(createWithNumericType(*data_type)); + res.reset(createWithNumericType(*data_type, argument_types)); if (!res) throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, diff --git a/dbms/src/AggregateFunctions/AggregateFunctionSum.h b/dbms/src/AggregateFunctions/AggregateFunctionSum.h index 5bd2d10917a..1860088cd93 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionSum.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionSum.h @@ -102,12 +102,14 @@ public: String getName() const override { return "sum"; } - AggregateFunctionSum() - : scale(0) + AggregateFunctionSum(const DataTypes & argument_types) + : IAggregateFunctionDataHelper>(argument_types, {}) + , scale(0) {} - AggregateFunctionSum(const IDataType & data_type) - : scale(getDecimalScale(data_type)) + AggregateFunctionSum(const IDataType & data_type, const DataTypes & argument_types) + : IAggregateFunctionDataHelper>(argument_types, {}) + , scale(getDecimalScale(data_type)) {} DataTypePtr getReturnType() const override diff --git a/dbms/src/AggregateFunctions/AggregateFunctionSumMap.cpp b/dbms/src/AggregateFunctions/AggregateFunctionSumMap.cpp index 75cd62c00f1..5a10ae62324 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionSumMap.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionSumMap.cpp @@ -80,7 +80,7 @@ AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, con auto [keys_type, values_types] = parseArguments(name, arguments); - AggregateFunctionPtr res(createWithNumericBasedType(*keys_type, keys_type, values_types)); + AggregateFunctionPtr res(createWithNumericBasedType(*keys_type, keys_type, values_types, arguments)); if (!res) res.reset(createWithDecimalType(*keys_type, keys_type, values_types)); if (!res) @@ -103,7 +103,7 @@ AggregateFunctionPtr createAggregateFunctionSumMapFiltered(const std::string & n auto [keys_type, values_types] = parseArguments(name, arguments); - AggregateFunctionPtr res(createWithNumericBasedType(*keys_type, keys_type, values_types, keys_to_keep)); + AggregateFunctionPtr res(createWithNumericBasedType(*keys_type, keys_type, values_types, keys_to_keep, arguments, params)); if (!res) res.reset(createWithDecimalType(*keys_type, keys_type, values_types, keys_to_keep)); if (!res) diff --git a/dbms/src/AggregateFunctions/AggregateFunctionSumMap.h b/dbms/src/AggregateFunctions/AggregateFunctionSumMap.h index c239b74630e..ef6cae9babc 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionSumMap.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionSumMap.h @@ -61,8 +61,11 @@ private: DataTypes values_types; public: - AggregateFunctionSumMapBase(const DataTypePtr & keys_type, const DataTypes & values_types) - : keys_type(keys_type), values_types(values_types) {} + AggregateFunctionSumMapBase( + const DataTypePtr & keys_type, const DataTypes & values_types, + const DataTypes & argument_types, const Array & params) + : IAggregateFunctionDataHelper>, Derived>(argument_types, params) + , keys_type(keys_type), values_types(values_types) {} String getName() const override { return "sumMap"; } @@ -271,8 +274,8 @@ private: using Base = AggregateFunctionSumMapBase; public: - AggregateFunctionSumMap(const DataTypePtr & keys_type, DataTypes & values_types) - : Base{keys_type, values_types} + AggregateFunctionSumMap(const DataTypePtr & keys_type, DataTypes & values_types, const DataTypes & argument_types) + : Base{keys_type, values_types, argument_types, {}} {} String getName() const override { return "sumMap"; } @@ -291,8 +294,10 @@ private: std::unordered_set keys_to_keep; public: - AggregateFunctionSumMapFiltered(const DataTypePtr & keys_type, const DataTypes & values_types, const Array & keys_to_keep_) - : Base{keys_type, values_types} + AggregateFunctionSumMapFiltered( + const DataTypePtr & keys_type, const DataTypes & values_types, const Array & keys_to_keep_, + const DataTypes & argument_types, const Array & params) + : Base{keys_type, values_types, argument_types, params} { keys_to_keep.reserve(keys_to_keep_.size()); for (const Field & f : keys_to_keep_) diff --git a/dbms/src/AggregateFunctions/AggregateFunctionTopK.cpp b/dbms/src/AggregateFunctions/AggregateFunctionTopK.cpp index 168dba4ebd5..04e74c17434 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionTopK.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionTopK.cpp @@ -39,19 +39,19 @@ class AggregateFunctionTopKDateTime : public AggregateFunctionTopK -static IAggregateFunction * createWithExtraTypes(const DataTypePtr & argument_type, UInt64 threshold) +static IAggregateFunction * createWithExtraTypes(const DataTypePtr & argument_type, UInt64 threshold, const Array & params) { WhichDataType which(argument_type); if (which.idx == TypeIndex::Date) - return new AggregateFunctionTopKDate(threshold); + return new AggregateFunctionTopKDate(threshold, {argument_type}, params); if (which.idx == TypeIndex::DateTime) - return new AggregateFunctionTopKDateTime(threshold); + return new AggregateFunctionTopKDateTime(threshold, {argument_type}, params); /// Check that we can use plain version of AggregateFunctionTopKGeneric if (argument_type->isValueUnambiguouslyRepresentedInContiguousMemoryRegion()) - return new AggregateFunctionTopKGeneric(threshold, argument_type); + return new AggregateFunctionTopKGeneric(threshold, argument_type, params); else - return new AggregateFunctionTopKGeneric(threshold, argument_type); + return new AggregateFunctionTopKGeneric(threshold, argument_type, params); } @@ -90,10 +90,10 @@ AggregateFunctionPtr createAggregateFunctionTopK(const std::string & name, const threshold = k; } - AggregateFunctionPtr res(createWithNumericType(*argument_types[0], threshold)); + AggregateFunctionPtr res(createWithNumericType(*argument_types[0], threshold, argument_types, params)); if (!res) - res = AggregateFunctionPtr(createWithExtraTypes(argument_types[0], threshold)); + res = AggregateFunctionPtr(createWithExtraTypes(argument_types[0], threshold, params)); if (!res) throw Exception("Illegal type " + argument_types[0]->getName() + diff --git a/dbms/src/AggregateFunctions/AggregateFunctionTopK.h b/dbms/src/AggregateFunctions/AggregateFunctionTopK.h index 09897f5ccd2..846a3e2b2a1 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionTopK.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionTopK.h @@ -48,8 +48,9 @@ protected: UInt64 reserved; public: - AggregateFunctionTopK(UInt64 threshold) - : threshold(threshold), reserved(TOP_K_LOAD_FACTOR * threshold) {} + AggregateFunctionTopK(UInt64 threshold, const DataTypes & argument_types, const Array & params) + : IAggregateFunctionDataHelper, AggregateFunctionTopK>(argument_types, params) + , threshold(threshold), reserved(TOP_K_LOAD_FACTOR * threshold) {} String getName() const override { return is_weighted ? "topKWeighted" : "topK"; } @@ -136,13 +137,15 @@ private: UInt64 threshold; UInt64 reserved; - DataTypePtr input_data_type; + DataTypePtr & input_data_type; static void deserializeAndInsert(StringRef str, IColumn & data_to); public: - AggregateFunctionTopKGeneric(UInt64 threshold, const DataTypePtr & input_data_type) - : threshold(threshold), reserved(TOP_K_LOAD_FACTOR * threshold), input_data_type(input_data_type) {} + AggregateFunctionTopKGeneric( + UInt64 threshold, const DataTypePtr & input_data_type, const Array & params) + : IAggregateFunctionDataHelper>({input_data_type}, params) + , threshold(threshold), reserved(TOP_K_LOAD_FACTOR * threshold), input_data_type(argument_types[0]) {} String getName() const override { return is_weighted ? "topKWeighted" : "topK"; } diff --git a/dbms/src/AggregateFunctions/AggregateFunctionUniq.cpp b/dbms/src/AggregateFunctions/AggregateFunctionUniq.cpp index 6b63a719b8f..eaf021d8735 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionUniq.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionUniq.cpp @@ -43,19 +43,19 @@ AggregateFunctionPtr createAggregateFunctionUniq(const std::string & name, const { const IDataType & argument_type = *argument_types[0]; - AggregateFunctionPtr res(createWithNumericType(*argument_types[0])); + AggregateFunctionPtr res(createWithNumericType(*argument_types[0], argument_types)); WhichDataType which(argument_type); if (res) return res; else if (which.isDate()) - return std::make_shared>(); + return std::make_shared>(argument_types); else if (which.isDateTime()) - return std::make_shared>(); + return std::make_shared>(argument_types); else if (which.isStringOrFixedString()) - return std::make_shared>(); + return std::make_shared>(argument_types); else if (which.isUUID()) - return std::make_shared>(); + return std::make_shared>(argument_types); else if (which.isTuple()) { if (use_exact_hash_function) @@ -89,19 +89,19 @@ AggregateFunctionPtr createAggregateFunctionUniq(const std::string & name, const { const IDataType & argument_type = *argument_types[0]; - AggregateFunctionPtr res(createWithNumericType(*argument_types[0])); + AggregateFunctionPtr res(createWithNumericType(*argument_types[0], argument_types)); WhichDataType which(argument_type); if (res) return res; else if (which.isDate()) - return std::make_shared>>(); + return std::make_shared>>(argument_types); else if (which.isDateTime()) - return std::make_shared>>(); + return std::make_shared>>(argument_types); else if (which.isStringOrFixedString()) return std::make_shared>>(); else if (which.isUUID()) - return std::make_shared>>(); + return std::make_shared>>(argument_types); else if (which.isTuple()) { if (use_exact_hash_function) diff --git a/dbms/src/AggregateFunctions/AggregateFunctionUniq.h b/dbms/src/AggregateFunctions/AggregateFunctionUniq.h index fea79a920a9..56a855aabb9 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionUniq.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionUniq.h @@ -209,6 +209,9 @@ template class AggregateFunctionUniq final : public IAggregateFunctionDataHelper> { public: + AggregateFunctionUniq(const DataTypes & argument_types) + : IAggregateFunctionDataHelper>(argument_types, {}) {} + String getName() const override { return Data::getName(); } DataTypePtr getReturnType() const override @@ -257,6 +260,7 @@ private: public: AggregateFunctionUniqVariadic(const DataTypes & arguments) + : IAggregateFunctionDataHelper>(arguments) { if (argument_is_tuple) num_args = typeid_cast(*arguments[0]).getElements().size(); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionUniqCombined.cpp b/dbms/src/AggregateFunctions/AggregateFunctionUniqCombined.cpp index 90b84d3b927..38982b8130e 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionUniqCombined.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionUniqCombined.cpp @@ -28,7 +28,7 @@ namespace }; template - AggregateFunctionPtr createAggregateFunctionWithK(const DataTypes & argument_types) + AggregateFunctionPtr createAggregateFunctionWithK(const DataTypes & argument_types, const Array & params) { /// We use exact hash function if the arguments are not contiguous in memory, because only exact hash function has support for this case. bool use_exact_hash_function = !isAllArgumentsContiguousInMemory(argument_types); @@ -37,33 +37,33 @@ namespace { const IDataType & argument_type = *argument_types[0]; - AggregateFunctionPtr res(createWithNumericType::template AggregateFunction>(*argument_types[0])); + AggregateFunctionPtr res(createWithNumericType::template AggregateFunction>(*argument_types[0], argument_types, params)); WhichDataType which(argument_type); if (res) return res; else if (which.isDate()) - return std::make_shared::template AggregateFunction>(); + return std::make_shared::template AggregateFunction>(argument_types, params); else if (which.isDateTime()) - return std::make_shared::template AggregateFunction>(); + return std::make_shared::template AggregateFunction>(argument_types, params); else if (which.isStringOrFixedString()) - return std::make_shared::template AggregateFunction>(); + return std::make_shared::template AggregateFunction>(argument_types, params); else if (which.isUUID()) - return std::make_shared::template AggregateFunction>(); + return std::make_shared::template AggregateFunction>(argument_types, params); else if (which.isTuple()) { if (use_exact_hash_function) - return std::make_shared::template AggregateFunctionVariadic>(argument_types); + return std::make_shared::template AggregateFunctionVariadic>(argument_types, params); else - return std::make_shared::template AggregateFunctionVariadic>(argument_types); + return std::make_shared::template AggregateFunctionVariadic>(argument_types, params); } } /// "Variadic" method also works as a fallback generic case for a single argument. if (use_exact_hash_function) - return std::make_shared::template AggregateFunctionVariadic>(argument_types); + return std::make_shared::template AggregateFunctionVariadic>(argument_types, params); else - return std::make_shared::template AggregateFunctionVariadic>(argument_types); + return std::make_shared::template AggregateFunctionVariadic>(argument_types, params); } AggregateFunctionPtr createAggregateFunctionUniqCombined( @@ -95,23 +95,23 @@ namespace switch (precision) { case 12: - return createAggregateFunctionWithK<12>(argument_types); + return createAggregateFunctionWithK<12>(argument_types, params); case 13: - return createAggregateFunctionWithK<13>(argument_types); + return createAggregateFunctionWithK<13>(argument_types, params); case 14: - return createAggregateFunctionWithK<14>(argument_types); + return createAggregateFunctionWithK<14>(argument_types, params); case 15: - return createAggregateFunctionWithK<15>(argument_types); + return createAggregateFunctionWithK<15>(argument_types, params); case 16: - return createAggregateFunctionWithK<16>(argument_types); + return createAggregateFunctionWithK<16>(argument_types, params); case 17: - return createAggregateFunctionWithK<17>(argument_types); + return createAggregateFunctionWithK<17>(argument_types, params); case 18: - return createAggregateFunctionWithK<18>(argument_types); + return createAggregateFunctionWithK<18>(argument_types, params); case 19: - return createAggregateFunctionWithK<19>(argument_types); + return createAggregateFunctionWithK<19>(argument_types, params); case 20: - return createAggregateFunctionWithK<20>(argument_types); + return createAggregateFunctionWithK<20>(argument_types, params); } __builtin_unreachable(); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionUniqCombined.h b/dbms/src/AggregateFunctions/AggregateFunctionUniqCombined.h index 001f4e7f289..3b7aee95186 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionUniqCombined.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionUniqCombined.h @@ -114,6 +114,9 @@ class AggregateFunctionUniqCombined final : public IAggregateFunctionDataHelper, AggregateFunctionUniqCombined> { public: + AggregateFunctionUniqCombined(const DataTypes & argument_types, const Array & params) + : IAggregateFunctionDataHelper, AggregateFunctionUniqCombined>(argument_types, params) {} + String getName() const override { return "uniqCombined"; @@ -176,7 +179,9 @@ private: size_t num_args = 0; public: - explicit AggregateFunctionUniqCombinedVariadic(const DataTypes & arguments) + explicit AggregateFunctionUniqCombinedVariadic(const DataTypes & arguments, const Array & params) + : IAggregateFunctionDataHelper, + AggregateFunctionUniqCombinedVariadic>(arguments, params) { if (argument_is_tuple) num_args = typeid_cast(*arguments[0]).getElements().size(); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionUniqUpTo.cpp b/dbms/src/AggregateFunctions/AggregateFunctionUniqUpTo.cpp index b9cdcaa4eae..ba4f337839e 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionUniqUpTo.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionUniqUpTo.cpp @@ -52,33 +52,33 @@ AggregateFunctionPtr createAggregateFunctionUniqUpTo(const std::string & name, c { const IDataType & argument_type = *argument_types[0]; - AggregateFunctionPtr res(createWithNumericType(*argument_types[0], threshold)); + AggregateFunctionPtr res(createWithNumericType(*argument_types[0], threshold, argument_types, params)); WhichDataType which(argument_type); if (res) return res; else if (which.isDate()) - return std::make_shared>(threshold); + return std::make_shared>(threshold, argument_types, params); else if (which.isDateTime()) - return std::make_shared>(threshold); + return std::make_shared>(threshold, argument_types, params); else if (which.isStringOrFixedString()) - return std::make_shared>(threshold); + return std::make_shared>(threshold, argument_types, params); else if (which.isUUID()) - return std::make_shared>(threshold); + return std::make_shared>(threshold, argument_types, params); else if (which.isTuple()) { if (use_exact_hash_function) - return std::make_shared>(argument_types, threshold); + return std::make_shared>(argument_types, params, threshold); else - return std::make_shared>(argument_types, threshold); + return std::make_shared>(argument_types, params, threshold); } } /// "Variadic" method also works as a fallback generic case for single argument. if (use_exact_hash_function) - return std::make_shared>(argument_types, threshold); + return std::make_shared>(argument_types, params, threshold); else - return std::make_shared>(argument_types, threshold); + return std::make_shared>(argument_types, params, threshold); } } diff --git a/dbms/src/AggregateFunctions/AggregateFunctionUniqUpTo.h b/dbms/src/AggregateFunctions/AggregateFunctionUniqUpTo.h index 6b6a645024a..477a729894d 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionUniqUpTo.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionUniqUpTo.h @@ -136,8 +136,9 @@ private: UInt8 threshold; public: - AggregateFunctionUniqUpTo(UInt8 threshold) - : threshold(threshold) + AggregateFunctionUniqUpTo(UInt8 threshold, const DataTypes & argument_types, const Array & params) + : IAggregateFunctionDataHelper, AggregateFunctionUniqUpTo>(argument_types, params) + , threshold(threshold) { } @@ -195,8 +196,9 @@ private: UInt8 threshold; public: - AggregateFunctionUniqUpToVariadic(const DataTypes & arguments, UInt8 threshold) - : threshold(threshold) + AggregateFunctionUniqUpToVariadic(const DataTypes & arguments, const Array & params, UInt8 threshold) + : IAggregateFunctionDataHelper, AggregateFunctionUniqUpToVariadic>(arguments, params) + , threshold(threshold) { if (argument_is_tuple) num_args = typeid_cast(*arguments[0]).getElements().size(); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionWindowFunnel.h b/dbms/src/AggregateFunctions/AggregateFunctionWindowFunnel.h index 317637b1b69..556f9bb1ae1 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionWindowFunnel.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionWindowFunnel.h @@ -189,6 +189,7 @@ public: } AggregateFunctionWindowFunnel(const DataTypes & arguments, const Array & params) + : IAggregateFunctionDataHelper(arguments, params) { const auto time_arg = arguments.front().get(); if (!WhichDataType(time_arg).isDateTime() && !WhichDataType(time_arg).isUInt32()) diff --git a/dbms/src/AggregateFunctions/AggregateFunctionsStatisticsSimple.cpp b/dbms/src/AggregateFunctions/AggregateFunctionsStatisticsSimple.cpp index 4159403afc7..1fafa6e00c9 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionsStatisticsSimple.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionsStatisticsSimple.cpp @@ -24,9 +24,9 @@ AggregateFunctionPtr createAggregateFunctionStatisticsUnary(const std::string & AggregateFunctionPtr res; DataTypePtr data_type = argument_types[0]; if (isDecimal(data_type)) - res.reset(createWithDecimalType(*data_type, *data_type)); + res.reset(createWithDecimalType(*data_type, *data_type, argument_types)); else - res.reset(createWithNumericType(*data_type)); + res.reset(createWithNumericType(*data_type, argument_types)); if (!res) throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, @@ -40,7 +40,7 @@ AggregateFunctionPtr createAggregateFunctionStatisticsBinary(const std::string & assertNoParameters(name, parameters); assertBinary(name, argument_types); - AggregateFunctionPtr res(createWithTwoNumericTypes(*argument_types[0], *argument_types[1])); + AggregateFunctionPtr res(createWithTwoNumericTypes(*argument_types[0], *argument_types[1], argument_types)); if (!res) throw Exception("Illegal types " + argument_types[0]->getName() + " and " + argument_types[1]->getName() + " of arguments for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); diff --git a/dbms/src/AggregateFunctions/IAggregateFunction.h b/dbms/src/AggregateFunctions/IAggregateFunction.h index f5def066058..17620f7493d 100644 --- a/dbms/src/AggregateFunctions/IAggregateFunction.h +++ b/dbms/src/AggregateFunctions/IAggregateFunction.h @@ -37,6 +37,9 @@ using ConstAggregateDataPtr = const char *; class IAggregateFunction { public: + IAggregateFunction(const DataTypes & argument_types_, const Array & parameters_) + : argument_types(argument_types_), parameters(parameters_) {} + /// Get main function name. virtual String getName() const = 0; @@ -112,17 +115,9 @@ public: const DataTypes & getArgumentTypes() const { return argument_types; } const Array & getParameters() const { return parameters; } -private: +protected: DataTypes argument_types; Array parameters; - - friend class AggregateFunctionFactory; - - void setArguments(DataTypes argument_types_, Array parameters_) - { - argument_types = std::move(argument_types_); - parameters = std::move(parameters_); - } }; @@ -137,6 +132,8 @@ private: } public: + IAggregateFunctionHelper(const DataTypes & argument_types_, const Array & parameters_) + : IAggregateFunction(argument_types_, parameters_) {} AddFunc getAddressOfAddFunction() const override { return &addFree; } }; @@ -152,6 +149,10 @@ protected: static const Data & data(ConstAggregateDataPtr place) { return *reinterpret_cast(place); } public: + + IAggregateFunctionDataHelper(const DataTypes & argument_types_, const Array & parameters_) + : IAggregateFunctionHelper(argument_types_, parameters_) {} + void create(AggregateDataPtr place) const override { new (place) Data;