mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-21 15:12:02 +00:00
allow mismatching parameters for some functions
This commit is contained in:
parent
ebcf0844f5
commit
764701c3f3
@ -35,13 +35,9 @@ public:
|
||||
{
|
||||
const DataTypeAggregateFunction * data_type = typeid_cast<const DataTypeAggregateFunction *>(argument.get());
|
||||
|
||||
if (!data_type || data_type->getFunctionName() != nested_func->getName())
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}",
|
||||
argument->getName(), getName());
|
||||
|
||||
if (data_type->getParameters() != getParameters())
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}: "
|
||||
"parameters mismatch", argument->getName(), getName());
|
||||
if (!data_type || !nested_func->haveSameStateRepresentation(*data_type->getFunction()))
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}, "
|
||||
"expected {} or equivalent type", argument->getName(), getName(), getStateType()->getName());
|
||||
}
|
||||
|
||||
String getName() const override
|
||||
|
@ -105,6 +105,11 @@ public:
|
||||
return res;
|
||||
}
|
||||
|
||||
bool haveSameStateRepresentation(const IAggregateFunction & rhs) const override
|
||||
{
|
||||
return getName() == rhs.getName() && this->haveEqualArgumentTypes(rhs);
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
|
@ -179,6 +179,11 @@ public:
|
||||
this->data(place).deserialize(buf);
|
||||
}
|
||||
|
||||
bool haveSameStateRepresentation(const IAggregateFunction & rhs) const override
|
||||
{
|
||||
return this->getName() == rhs.getName() && this->haveEqualArgumentTypes(rhs);
|
||||
}
|
||||
|
||||
private:
|
||||
enum class PatternActionType
|
||||
{
|
||||
|
@ -31,10 +31,10 @@ namespace
|
||||
|
||||
template <typename T>
|
||||
inline AggregateFunctionPtr createAggregateFunctionSequenceNodeImpl(
|
||||
const DataTypePtr data_type, const DataTypes & argument_types, SequenceDirection direction, SequenceBase base)
|
||||
const DataTypePtr data_type, const DataTypes & argument_types, const Array & parameters, SequenceDirection direction, SequenceBase base)
|
||||
{
|
||||
return std::make_shared<SequenceNextNodeImpl<T, NodeString<max_events_size>>>(
|
||||
data_type, argument_types, base, direction, min_required_args);
|
||||
data_type, argument_types, parameters, base, direction, min_required_args);
|
||||
}
|
||||
|
||||
AggregateFunctionPtr
|
||||
@ -116,17 +116,17 @@ createAggregateFunctionSequenceNode(const std::string & name, const DataTypes &
|
||||
|
||||
WhichDataType timestamp_type(argument_types[0].get());
|
||||
if (timestamp_type.idx == TypeIndex::UInt8)
|
||||
return createAggregateFunctionSequenceNodeImpl<UInt8>(data_type, argument_types, direction, base);
|
||||
return createAggregateFunctionSequenceNodeImpl<UInt8>(data_type, argument_types, parameters, direction, base);
|
||||
if (timestamp_type.idx == TypeIndex::UInt16)
|
||||
return createAggregateFunctionSequenceNodeImpl<UInt16>(data_type, argument_types, direction, base);
|
||||
return createAggregateFunctionSequenceNodeImpl<UInt16>(data_type, argument_types, parameters, direction, base);
|
||||
if (timestamp_type.idx == TypeIndex::UInt32)
|
||||
return createAggregateFunctionSequenceNodeImpl<UInt32>(data_type, argument_types, direction, base);
|
||||
return createAggregateFunctionSequenceNodeImpl<UInt32>(data_type, argument_types, parameters, direction, base);
|
||||
if (timestamp_type.idx == TypeIndex::UInt64)
|
||||
return createAggregateFunctionSequenceNodeImpl<UInt64>(data_type, argument_types, direction, base);
|
||||
return createAggregateFunctionSequenceNodeImpl<UInt64>(data_type, argument_types, parameters, direction, base);
|
||||
if (timestamp_type.isDate())
|
||||
return createAggregateFunctionSequenceNodeImpl<DataTypeDate::FieldType>(data_type, argument_types, direction, base);
|
||||
return createAggregateFunctionSequenceNodeImpl<DataTypeDate::FieldType>(data_type, argument_types, parameters, direction, base);
|
||||
if (timestamp_type.isDateTime())
|
||||
return createAggregateFunctionSequenceNodeImpl<DataTypeDateTime::FieldType>(data_type, argument_types, direction, base);
|
||||
return createAggregateFunctionSequenceNodeImpl<DataTypeDateTime::FieldType>(data_type, argument_types, parameters, direction, base);
|
||||
|
||||
throw Exception{"Illegal type " + argument_types.front().get()->getName()
|
||||
+ " of first argument of aggregate function " + name + ", must be Unsigned Number, Date, DateTime",
|
||||
|
@ -175,11 +175,12 @@ public:
|
||||
SequenceNextNodeImpl(
|
||||
const DataTypePtr & data_type_,
|
||||
const DataTypes & arguments,
|
||||
const Array & parameters_,
|
||||
SequenceBase seq_base_kind_,
|
||||
SequenceDirection seq_direction_,
|
||||
size_t min_required_args_,
|
||||
UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
|
||||
: IAggregateFunctionDataHelper<SequenceNextNodeGeneralData<Node>, Self>({data_type_}, {})
|
||||
: IAggregateFunctionDataHelper<SequenceNextNodeGeneralData<Node>, Self>({data_type_}, parameters_)
|
||||
, seq_base_kind(seq_base_kind_)
|
||||
, seq_direction(seq_direction_)
|
||||
, min_required_args(min_required_args_)
|
||||
@ -193,6 +194,11 @@ public:
|
||||
|
||||
DataTypePtr getReturnType() const override { return data_type; }
|
||||
|
||||
bool haveSameStateRepresentation(const IAggregateFunction & rhs) const override
|
||||
{
|
||||
return this->getName() == rhs.getName() && this->haveEqualArgumentTypes(rhs);
|
||||
}
|
||||
|
||||
AggregateFunctionPtr getOwnNullAdapter(
|
||||
const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array & params,
|
||||
const AggregateFunctionProperties &) const override
|
||||
|
@ -50,4 +50,21 @@ String IAggregateFunction::getDescription() const
|
||||
|
||||
return description;
|
||||
}
|
||||
|
||||
bool IAggregateFunction::haveEqualArgumentTypes(const IAggregateFunction & rhs) const
|
||||
{
|
||||
return std::equal(argument_types.begin(), argument_types.end(),
|
||||
rhs.argument_types.begin(), rhs.argument_types.end(),
|
||||
[](const auto & t1, const auto & t2) { return t1->equals(*t2); });
|
||||
}
|
||||
|
||||
bool IAggregateFunction::haveSameStateRepresentation(const IAggregateFunction & rhs) const
|
||||
{
|
||||
bool res = getName() == rhs.getName()
|
||||
&& parameters == rhs.parameters
|
||||
&& haveEqualArgumentTypes(rhs);
|
||||
assert(res == (getStateType()->getName() == rhs.getStateType()->getName()));
|
||||
return res;
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -74,6 +74,16 @@ public:
|
||||
/// Get the data type of internal state. By default it is AggregateFunction(name(params), argument_types...).
|
||||
virtual DataTypePtr getStateType() const;
|
||||
|
||||
/// Returns true if two aggregate functions have the same state representation in memory and the same serialization,
|
||||
/// so state of one aggregate function can be safely used with another.
|
||||
/// Examples:
|
||||
/// - quantile(x), quantile(a)(x), quantile(b)(x) - parameter doesn't affect state and used for finalization only
|
||||
/// - foo(x) and fooIf(x) - If combinator doesn't affect state
|
||||
/// By default returns true only if functions have exactly the same names, combinators and parameters.
|
||||
virtual bool haveSameStateRepresentation(const IAggregateFunction & rhs) const;
|
||||
|
||||
bool haveEqualArgumentTypes(const IAggregateFunction & rhs) const;
|
||||
|
||||
/// Get type which will be used for prediction result in case if function is an ML method.
|
||||
virtual DataTypePtr getReturnTypeToPredict() const
|
||||
{
|
||||
|
@ -1,4 +1,4 @@
|
||||
with (select sumState(1)) as s select sumMerge(s);
|
||||
with (select sumState(number) from (select * from system.numbers limit 10)) as s select sumMerge(s);
|
||||
with (select quantileState(0.5)(number) from (select * from system.numbers limit 10)) as s select quantileMerge(0.5)(s);
|
||||
with (select quantileState(0.5)(number) from (select * from system.numbers limit 10)) as s select quantileMerge(s);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user