allow mismatching parameters for some functions

This commit is contained in:
Alexander Tokmakov 2021-07-28 20:55:13 +03:00
parent ebcf0844f5
commit 764701c3f3
8 changed files with 56 additions and 17 deletions

View File

@ -35,13 +35,9 @@ public:
{ {
const DataTypeAggregateFunction * data_type = typeid_cast<const DataTypeAggregateFunction *>(argument.get()); const DataTypeAggregateFunction * data_type = typeid_cast<const DataTypeAggregateFunction *>(argument.get());
if (!data_type || data_type->getFunctionName() != nested_func->getName()) if (!data_type || !nested_func->haveSameStateRepresentation(*data_type->getFunction()))
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}", throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}, "
argument->getName(), getName()); "expected {} or equivalent type", argument->getName(), getName(), getStateType()->getName());
if (data_type->getParameters() != getParameters())
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}: "
"parameters mismatch", argument->getName(), getName());
} }
String getName() const override String getName() const override

View File

@ -105,6 +105,11 @@ public:
return res; return res;
} }
bool haveSameStateRepresentation(const IAggregateFunction & rhs) const override
{
return getName() == rhs.getName() && this->haveEqualArgumentTypes(rhs);
}
bool allocatesMemoryInArena() const override { return false; } bool allocatesMemoryInArena() const override { return false; }
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override

View File

@ -179,6 +179,11 @@ public:
this->data(place).deserialize(buf); this->data(place).deserialize(buf);
} }
bool haveSameStateRepresentation(const IAggregateFunction & rhs) const override
{
return this->getName() == rhs.getName() && this->haveEqualArgumentTypes(rhs);
}
private: private:
enum class PatternActionType enum class PatternActionType
{ {

View File

@ -31,10 +31,10 @@ namespace
template <typename T> template <typename T>
inline AggregateFunctionPtr createAggregateFunctionSequenceNodeImpl( inline AggregateFunctionPtr createAggregateFunctionSequenceNodeImpl(
const DataTypePtr data_type, const DataTypes & argument_types, SequenceDirection direction, SequenceBase base) const DataTypePtr data_type, const DataTypes & argument_types, const Array & parameters, SequenceDirection direction, SequenceBase base)
{ {
return std::make_shared<SequenceNextNodeImpl<T, NodeString<max_events_size>>>( return std::make_shared<SequenceNextNodeImpl<T, NodeString<max_events_size>>>(
data_type, argument_types, base, direction, min_required_args); data_type, argument_types, parameters, base, direction, min_required_args);
} }
AggregateFunctionPtr AggregateFunctionPtr
@ -116,17 +116,17 @@ createAggregateFunctionSequenceNode(const std::string & name, const DataTypes &
WhichDataType timestamp_type(argument_types[0].get()); WhichDataType timestamp_type(argument_types[0].get());
if (timestamp_type.idx == TypeIndex::UInt8) if (timestamp_type.idx == TypeIndex::UInt8)
return createAggregateFunctionSequenceNodeImpl<UInt8>(data_type, argument_types, direction, base); return createAggregateFunctionSequenceNodeImpl<UInt8>(data_type, argument_types, parameters, direction, base);
if (timestamp_type.idx == TypeIndex::UInt16) if (timestamp_type.idx == TypeIndex::UInt16)
return createAggregateFunctionSequenceNodeImpl<UInt16>(data_type, argument_types, direction, base); return createAggregateFunctionSequenceNodeImpl<UInt16>(data_type, argument_types, parameters, direction, base);
if (timestamp_type.idx == TypeIndex::UInt32) if (timestamp_type.idx == TypeIndex::UInt32)
return createAggregateFunctionSequenceNodeImpl<UInt32>(data_type, argument_types, direction, base); return createAggregateFunctionSequenceNodeImpl<UInt32>(data_type, argument_types, parameters, direction, base);
if (timestamp_type.idx == TypeIndex::UInt64) if (timestamp_type.idx == TypeIndex::UInt64)
return createAggregateFunctionSequenceNodeImpl<UInt64>(data_type, argument_types, direction, base); return createAggregateFunctionSequenceNodeImpl<UInt64>(data_type, argument_types, parameters, direction, base);
if (timestamp_type.isDate()) if (timestamp_type.isDate())
return createAggregateFunctionSequenceNodeImpl<DataTypeDate::FieldType>(data_type, argument_types, direction, base); return createAggregateFunctionSequenceNodeImpl<DataTypeDate::FieldType>(data_type, argument_types, parameters, direction, base);
if (timestamp_type.isDateTime()) if (timestamp_type.isDateTime())
return createAggregateFunctionSequenceNodeImpl<DataTypeDateTime::FieldType>(data_type, argument_types, direction, base); return createAggregateFunctionSequenceNodeImpl<DataTypeDateTime::FieldType>(data_type, argument_types, parameters, direction, base);
throw Exception{"Illegal type " + argument_types.front().get()->getName() throw Exception{"Illegal type " + argument_types.front().get()->getName()
+ " of first argument of aggregate function " + name + ", must be Unsigned Number, Date, DateTime", + " of first argument of aggregate function " + name + ", must be Unsigned Number, Date, DateTime",

View File

@ -175,11 +175,12 @@ public:
SequenceNextNodeImpl( SequenceNextNodeImpl(
const DataTypePtr & data_type_, const DataTypePtr & data_type_,
const DataTypes & arguments, const DataTypes & arguments,
const Array & parameters_,
SequenceBase seq_base_kind_, SequenceBase seq_base_kind_,
SequenceDirection seq_direction_, SequenceDirection seq_direction_,
size_t min_required_args_, size_t min_required_args_,
UInt64 max_elems_ = std::numeric_limits<UInt64>::max()) UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
: IAggregateFunctionDataHelper<SequenceNextNodeGeneralData<Node>, Self>({data_type_}, {}) : IAggregateFunctionDataHelper<SequenceNextNodeGeneralData<Node>, Self>({data_type_}, parameters_)
, seq_base_kind(seq_base_kind_) , seq_base_kind(seq_base_kind_)
, seq_direction(seq_direction_) , seq_direction(seq_direction_)
, min_required_args(min_required_args_) , min_required_args(min_required_args_)
@ -193,6 +194,11 @@ public:
DataTypePtr getReturnType() const override { return data_type; } DataTypePtr getReturnType() const override { return data_type; }
bool haveSameStateRepresentation(const IAggregateFunction & rhs) const override
{
return this->getName() == rhs.getName() && this->haveEqualArgumentTypes(rhs);
}
AggregateFunctionPtr getOwnNullAdapter( AggregateFunctionPtr getOwnNullAdapter(
const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array & params, const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array & params,
const AggregateFunctionProperties &) const override const AggregateFunctionProperties &) const override

View File

@ -50,4 +50,21 @@ String IAggregateFunction::getDescription() const
return description; return description;
} }
bool IAggregateFunction::haveEqualArgumentTypes(const IAggregateFunction & rhs) const
{
return std::equal(argument_types.begin(), argument_types.end(),
rhs.argument_types.begin(), rhs.argument_types.end(),
[](const auto & t1, const auto & t2) { return t1->equals(*t2); });
}
bool IAggregateFunction::haveSameStateRepresentation(const IAggregateFunction & rhs) const
{
bool res = getName() == rhs.getName()
&& parameters == rhs.parameters
&& haveEqualArgumentTypes(rhs);
assert(res == (getStateType()->getName() == rhs.getStateType()->getName()));
return res;
}
} }

View File

@ -74,6 +74,16 @@ public:
/// Get the data type of internal state. By default it is AggregateFunction(name(params), argument_types...). /// Get the data type of internal state. By default it is AggregateFunction(name(params), argument_types...).
virtual DataTypePtr getStateType() const; virtual DataTypePtr getStateType() const;
/// Returns true if two aggregate functions have the same state representation in memory and the same serialization,
/// so state of one aggregate function can be safely used with another.
/// Examples:
/// - quantile(x), quantile(a)(x), quantile(b)(x) - parameter doesn't affect state and used for finalization only
/// - foo(x) and fooIf(x) - If combinator doesn't affect state
/// By default returns true only if functions have exactly the same names, combinators and parameters.
virtual bool haveSameStateRepresentation(const IAggregateFunction & rhs) const;
bool haveEqualArgumentTypes(const IAggregateFunction & rhs) const;
/// Get type which will be used for prediction result in case if function is an ML method. /// Get type which will be used for prediction result in case if function is an ML method.
virtual DataTypePtr getReturnTypeToPredict() const virtual DataTypePtr getReturnTypeToPredict() const
{ {

View File

@ -1,4 +1,4 @@
with (select sumState(1)) as s select sumMerge(s); with (select sumState(1)) as s select sumMerge(s);
with (select sumState(number) from (select * from system.numbers limit 10)) as s select sumMerge(s); with (select sumState(number) from (select * from system.numbers limit 10)) as s select sumMerge(s);
with (select quantileState(0.5)(number) from (select * from system.numbers limit 10)) as s select quantileMerge(0.5)(s); with (select quantileState(0.5)(number) from (select * from system.numbers limit 10)) as s select quantileMerge(s);