diff --git a/src/AggregateFunctions/AggregateFunctionMerge.h b/src/AggregateFunctions/AggregateFunctionMerge.h index 72f3d119883..78e5c92c917 100644 --- a/src/AggregateFunctions/AggregateFunctionMerge.h +++ b/src/AggregateFunctions/AggregateFunctionMerge.h @@ -35,13 +35,9 @@ public: { const DataTypeAggregateFunction * data_type = typeid_cast(argument.get()); - if (!data_type || data_type->getFunctionName() != nested_func->getName()) - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}", - argument->getName(), getName()); - - if (data_type->getParameters() != getParameters()) - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}: " - "parameters mismatch", argument->getName(), getName()); + if (!data_type || !nested_func->haveSameStateRepresentation(*data_type->getFunction())) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}, " + "expected {} or equivalent type", argument->getName(), getName(), getStateType()->getName()); } String getName() const override diff --git a/src/AggregateFunctions/AggregateFunctionQuantile.h b/src/AggregateFunctions/AggregateFunctionQuantile.h index 90745c7d749..a7a3d4042c2 100644 --- a/src/AggregateFunctions/AggregateFunctionQuantile.h +++ b/src/AggregateFunctions/AggregateFunctionQuantile.h @@ -105,6 +105,11 @@ public: return res; } + bool haveSameStateRepresentation(const IAggregateFunction & rhs) const override + { + return getName() == rhs.getName() && this->haveEqualArgumentTypes(rhs); + } + bool allocatesMemoryInArena() const override { return false; } void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override diff --git a/src/AggregateFunctions/AggregateFunctionSequenceMatch.h b/src/AggregateFunctions/AggregateFunctionSequenceMatch.h index 88a6809d229..d05a4ca314d 100644 --- a/src/AggregateFunctions/AggregateFunctionSequenceMatch.h +++ b/src/AggregateFunctions/AggregateFunctionSequenceMatch.h @@ -179,6 +179,11 @@ public: this->data(place).deserialize(buf); } + bool haveSameStateRepresentation(const IAggregateFunction & rhs) const override + { + return this->getName() == rhs.getName() && this->haveEqualArgumentTypes(rhs); + } + private: enum class PatternActionType { diff --git a/src/AggregateFunctions/AggregateFunctionSequenceNextNode.cpp b/src/AggregateFunctions/AggregateFunctionSequenceNextNode.cpp index 965413afb1d..d86499b90f3 100644 --- a/src/AggregateFunctions/AggregateFunctionSequenceNextNode.cpp +++ b/src/AggregateFunctions/AggregateFunctionSequenceNextNode.cpp @@ -31,10 +31,10 @@ namespace template inline AggregateFunctionPtr createAggregateFunctionSequenceNodeImpl( - const DataTypePtr data_type, const DataTypes & argument_types, SequenceDirection direction, SequenceBase base) + const DataTypePtr data_type, const DataTypes & argument_types, const Array & parameters, SequenceDirection direction, SequenceBase base) { return std::make_shared>>( - data_type, argument_types, base, direction, min_required_args); + data_type, argument_types, parameters, base, direction, min_required_args); } AggregateFunctionPtr @@ -116,17 +116,17 @@ createAggregateFunctionSequenceNode(const std::string & name, const DataTypes & WhichDataType timestamp_type(argument_types[0].get()); if (timestamp_type.idx == TypeIndex::UInt8) - return createAggregateFunctionSequenceNodeImpl(data_type, argument_types, direction, base); + return createAggregateFunctionSequenceNodeImpl(data_type, argument_types, parameters, direction, base); if (timestamp_type.idx == TypeIndex::UInt16) - return createAggregateFunctionSequenceNodeImpl(data_type, argument_types, direction, base); + return createAggregateFunctionSequenceNodeImpl(data_type, argument_types, parameters, direction, base); if (timestamp_type.idx == TypeIndex::UInt32) - return createAggregateFunctionSequenceNodeImpl(data_type, argument_types, direction, base); + return createAggregateFunctionSequenceNodeImpl(data_type, argument_types, parameters, direction, base); if (timestamp_type.idx == TypeIndex::UInt64) - return createAggregateFunctionSequenceNodeImpl(data_type, argument_types, direction, base); + return createAggregateFunctionSequenceNodeImpl(data_type, argument_types, parameters, direction, base); if (timestamp_type.isDate()) - return createAggregateFunctionSequenceNodeImpl(data_type, argument_types, direction, base); + return createAggregateFunctionSequenceNodeImpl(data_type, argument_types, parameters, direction, base); if (timestamp_type.isDateTime()) - return createAggregateFunctionSequenceNodeImpl(data_type, argument_types, direction, base); + return createAggregateFunctionSequenceNodeImpl(data_type, argument_types, parameters, direction, base); throw Exception{"Illegal type " + argument_types.front().get()->getName() + " of first argument of aggregate function " + name + ", must be Unsigned Number, Date, DateTime", diff --git a/src/AggregateFunctions/AggregateFunctionSequenceNextNode.h b/src/AggregateFunctions/AggregateFunctionSequenceNextNode.h index 116e53e95e8..e5b007232e2 100644 --- a/src/AggregateFunctions/AggregateFunctionSequenceNextNode.h +++ b/src/AggregateFunctions/AggregateFunctionSequenceNextNode.h @@ -175,11 +175,12 @@ public: SequenceNextNodeImpl( const DataTypePtr & data_type_, const DataTypes & arguments, + const Array & parameters_, SequenceBase seq_base_kind_, SequenceDirection seq_direction_, size_t min_required_args_, UInt64 max_elems_ = std::numeric_limits::max()) - : IAggregateFunctionDataHelper, Self>({data_type_}, {}) + : IAggregateFunctionDataHelper, Self>({data_type_}, parameters_) , seq_base_kind(seq_base_kind_) , seq_direction(seq_direction_) , min_required_args(min_required_args_) @@ -193,6 +194,11 @@ public: DataTypePtr getReturnType() const override { return data_type; } + bool haveSameStateRepresentation(const IAggregateFunction & rhs) const override + { + return this->getName() == rhs.getName() && this->haveEqualArgumentTypes(rhs); + } + AggregateFunctionPtr getOwnNullAdapter( const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array & params, const AggregateFunctionProperties &) const override diff --git a/src/AggregateFunctions/IAggregateFunction.cpp b/src/AggregateFunctions/IAggregateFunction.cpp index 55998d963bf..ea4f8338fb8 100644 --- a/src/AggregateFunctions/IAggregateFunction.cpp +++ b/src/AggregateFunctions/IAggregateFunction.cpp @@ -50,4 +50,21 @@ String IAggregateFunction::getDescription() const return description; } + +bool IAggregateFunction::haveEqualArgumentTypes(const IAggregateFunction & rhs) const +{ + return std::equal(argument_types.begin(), argument_types.end(), + rhs.argument_types.begin(), rhs.argument_types.end(), + [](const auto & t1, const auto & t2) { return t1->equals(*t2); }); +} + +bool IAggregateFunction::haveSameStateRepresentation(const IAggregateFunction & rhs) const +{ + bool res = getName() == rhs.getName() + && parameters == rhs.parameters + && haveEqualArgumentTypes(rhs); + assert(res == (getStateType()->getName() == rhs.getStateType()->getName())); + return res; +} + } diff --git a/src/AggregateFunctions/IAggregateFunction.h b/src/AggregateFunctions/IAggregateFunction.h index 7acfa82a139..a06f1d12c0d 100644 --- a/src/AggregateFunctions/IAggregateFunction.h +++ b/src/AggregateFunctions/IAggregateFunction.h @@ -74,6 +74,16 @@ public: /// Get the data type of internal state. By default it is AggregateFunction(name(params), argument_types...). virtual DataTypePtr getStateType() const; + /// Returns true if two aggregate functions have the same state representation in memory and the same serialization, + /// so state of one aggregate function can be safely used with another. + /// Examples: + /// - quantile(x), quantile(a)(x), quantile(b)(x) - parameter doesn't affect state and used for finalization only + /// - foo(x) and fooIf(x) - If combinator doesn't affect state + /// By default returns true only if functions have exactly the same names, combinators and parameters. + virtual bool haveSameStateRepresentation(const IAggregateFunction & rhs) const; + + bool haveEqualArgumentTypes(const IAggregateFunction & rhs) const; + /// Get type which will be used for prediction result in case if function is an ML method. virtual DataTypePtr getReturnTypeToPredict() const { diff --git a/tests/queries/0_stateless/00905_field_with_aggregate_function_state.sql b/tests/queries/0_stateless/00905_field_with_aggregate_function_state.sql index a1903268d9a..b0470ac9992 100644 --- a/tests/queries/0_stateless/00905_field_with_aggregate_function_state.sql +++ b/tests/queries/0_stateless/00905_field_with_aggregate_function_state.sql @@ -1,4 +1,4 @@ with (select sumState(1)) as s select sumMerge(s); with (select sumState(number) from (select * from system.numbers limit 10)) as s select sumMerge(s); -with (select quantileState(0.5)(number) from (select * from system.numbers limit 10)) as s select quantileMerge(0.5)(s); +with (select quantileState(0.5)(number) from (select * from system.numbers limit 10)) as s select quantileMerge(s);