mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-24 00:22:29 +00:00
allow mismatching parameters for some functions
This commit is contained in:
parent
ebcf0844f5
commit
764701c3f3
@ -35,13 +35,9 @@ public:
|
|||||||
{
|
{
|
||||||
const DataTypeAggregateFunction * data_type = typeid_cast<const DataTypeAggregateFunction *>(argument.get());
|
const DataTypeAggregateFunction * data_type = typeid_cast<const DataTypeAggregateFunction *>(argument.get());
|
||||||
|
|
||||||
if (!data_type || data_type->getFunctionName() != nested_func->getName())
|
if (!data_type || !nested_func->haveSameStateRepresentation(*data_type->getFunction()))
|
||||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}",
|
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}, "
|
||||||
argument->getName(), getName());
|
"expected {} or equivalent type", argument->getName(), getName(), getStateType()->getName());
|
||||||
|
|
||||||
if (data_type->getParameters() != getParameters())
|
|
||||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}: "
|
|
||||||
"parameters mismatch", argument->getName(), getName());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
String getName() const override
|
String getName() const override
|
||||||
|
@ -105,6 +105,11 @@ public:
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool haveSameStateRepresentation(const IAggregateFunction & rhs) const override
|
||||||
|
{
|
||||||
|
return getName() == rhs.getName() && this->haveEqualArgumentTypes(rhs);
|
||||||
|
}
|
||||||
|
|
||||||
bool allocatesMemoryInArena() const override { return false; }
|
bool allocatesMemoryInArena() const override { return false; }
|
||||||
|
|
||||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||||
|
@ -179,6 +179,11 @@ public:
|
|||||||
this->data(place).deserialize(buf);
|
this->data(place).deserialize(buf);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool haveSameStateRepresentation(const IAggregateFunction & rhs) const override
|
||||||
|
{
|
||||||
|
return this->getName() == rhs.getName() && this->haveEqualArgumentTypes(rhs);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
enum class PatternActionType
|
enum class PatternActionType
|
||||||
{
|
{
|
||||||
|
@ -31,10 +31,10 @@ namespace
|
|||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
inline AggregateFunctionPtr createAggregateFunctionSequenceNodeImpl(
|
inline AggregateFunctionPtr createAggregateFunctionSequenceNodeImpl(
|
||||||
const DataTypePtr data_type, const DataTypes & argument_types, SequenceDirection direction, SequenceBase base)
|
const DataTypePtr data_type, const DataTypes & argument_types, const Array & parameters, SequenceDirection direction, SequenceBase base)
|
||||||
{
|
{
|
||||||
return std::make_shared<SequenceNextNodeImpl<T, NodeString<max_events_size>>>(
|
return std::make_shared<SequenceNextNodeImpl<T, NodeString<max_events_size>>>(
|
||||||
data_type, argument_types, base, direction, min_required_args);
|
data_type, argument_types, parameters, base, direction, min_required_args);
|
||||||
}
|
}
|
||||||
|
|
||||||
AggregateFunctionPtr
|
AggregateFunctionPtr
|
||||||
@ -116,17 +116,17 @@ createAggregateFunctionSequenceNode(const std::string & name, const DataTypes &
|
|||||||
|
|
||||||
WhichDataType timestamp_type(argument_types[0].get());
|
WhichDataType timestamp_type(argument_types[0].get());
|
||||||
if (timestamp_type.idx == TypeIndex::UInt8)
|
if (timestamp_type.idx == TypeIndex::UInt8)
|
||||||
return createAggregateFunctionSequenceNodeImpl<UInt8>(data_type, argument_types, direction, base);
|
return createAggregateFunctionSequenceNodeImpl<UInt8>(data_type, argument_types, parameters, direction, base);
|
||||||
if (timestamp_type.idx == TypeIndex::UInt16)
|
if (timestamp_type.idx == TypeIndex::UInt16)
|
||||||
return createAggregateFunctionSequenceNodeImpl<UInt16>(data_type, argument_types, direction, base);
|
return createAggregateFunctionSequenceNodeImpl<UInt16>(data_type, argument_types, parameters, direction, base);
|
||||||
if (timestamp_type.idx == TypeIndex::UInt32)
|
if (timestamp_type.idx == TypeIndex::UInt32)
|
||||||
return createAggregateFunctionSequenceNodeImpl<UInt32>(data_type, argument_types, direction, base);
|
return createAggregateFunctionSequenceNodeImpl<UInt32>(data_type, argument_types, parameters, direction, base);
|
||||||
if (timestamp_type.idx == TypeIndex::UInt64)
|
if (timestamp_type.idx == TypeIndex::UInt64)
|
||||||
return createAggregateFunctionSequenceNodeImpl<UInt64>(data_type, argument_types, direction, base);
|
return createAggregateFunctionSequenceNodeImpl<UInt64>(data_type, argument_types, parameters, direction, base);
|
||||||
if (timestamp_type.isDate())
|
if (timestamp_type.isDate())
|
||||||
return createAggregateFunctionSequenceNodeImpl<DataTypeDate::FieldType>(data_type, argument_types, direction, base);
|
return createAggregateFunctionSequenceNodeImpl<DataTypeDate::FieldType>(data_type, argument_types, parameters, direction, base);
|
||||||
if (timestamp_type.isDateTime())
|
if (timestamp_type.isDateTime())
|
||||||
return createAggregateFunctionSequenceNodeImpl<DataTypeDateTime::FieldType>(data_type, argument_types, direction, base);
|
return createAggregateFunctionSequenceNodeImpl<DataTypeDateTime::FieldType>(data_type, argument_types, parameters, direction, base);
|
||||||
|
|
||||||
throw Exception{"Illegal type " + argument_types.front().get()->getName()
|
throw Exception{"Illegal type " + argument_types.front().get()->getName()
|
||||||
+ " of first argument of aggregate function " + name + ", must be Unsigned Number, Date, DateTime",
|
+ " of first argument of aggregate function " + name + ", must be Unsigned Number, Date, DateTime",
|
||||||
|
@ -175,11 +175,12 @@ public:
|
|||||||
SequenceNextNodeImpl(
|
SequenceNextNodeImpl(
|
||||||
const DataTypePtr & data_type_,
|
const DataTypePtr & data_type_,
|
||||||
const DataTypes & arguments,
|
const DataTypes & arguments,
|
||||||
|
const Array & parameters_,
|
||||||
SequenceBase seq_base_kind_,
|
SequenceBase seq_base_kind_,
|
||||||
SequenceDirection seq_direction_,
|
SequenceDirection seq_direction_,
|
||||||
size_t min_required_args_,
|
size_t min_required_args_,
|
||||||
UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
|
UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
|
||||||
: IAggregateFunctionDataHelper<SequenceNextNodeGeneralData<Node>, Self>({data_type_}, {})
|
: IAggregateFunctionDataHelper<SequenceNextNodeGeneralData<Node>, Self>({data_type_}, parameters_)
|
||||||
, seq_base_kind(seq_base_kind_)
|
, seq_base_kind(seq_base_kind_)
|
||||||
, seq_direction(seq_direction_)
|
, seq_direction(seq_direction_)
|
||||||
, min_required_args(min_required_args_)
|
, min_required_args(min_required_args_)
|
||||||
@ -193,6 +194,11 @@ public:
|
|||||||
|
|
||||||
DataTypePtr getReturnType() const override { return data_type; }
|
DataTypePtr getReturnType() const override { return data_type; }
|
||||||
|
|
||||||
|
bool haveSameStateRepresentation(const IAggregateFunction & rhs) const override
|
||||||
|
{
|
||||||
|
return this->getName() == rhs.getName() && this->haveEqualArgumentTypes(rhs);
|
||||||
|
}
|
||||||
|
|
||||||
AggregateFunctionPtr getOwnNullAdapter(
|
AggregateFunctionPtr getOwnNullAdapter(
|
||||||
const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array & params,
|
const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array & params,
|
||||||
const AggregateFunctionProperties &) const override
|
const AggregateFunctionProperties &) const override
|
||||||
|
@ -50,4 +50,21 @@ String IAggregateFunction::getDescription() const
|
|||||||
|
|
||||||
return description;
|
return description;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool IAggregateFunction::haveEqualArgumentTypes(const IAggregateFunction & rhs) const
|
||||||
|
{
|
||||||
|
return std::equal(argument_types.begin(), argument_types.end(),
|
||||||
|
rhs.argument_types.begin(), rhs.argument_types.end(),
|
||||||
|
[](const auto & t1, const auto & t2) { return t1->equals(*t2); });
|
||||||
|
}
|
||||||
|
|
||||||
|
bool IAggregateFunction::haveSameStateRepresentation(const IAggregateFunction & rhs) const
|
||||||
|
{
|
||||||
|
bool res = getName() == rhs.getName()
|
||||||
|
&& parameters == rhs.parameters
|
||||||
|
&& haveEqualArgumentTypes(rhs);
|
||||||
|
assert(res == (getStateType()->getName() == rhs.getStateType()->getName()));
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -74,6 +74,16 @@ public:
|
|||||||
/// Get the data type of internal state. By default it is AggregateFunction(name(params), argument_types...).
|
/// Get the data type of internal state. By default it is AggregateFunction(name(params), argument_types...).
|
||||||
virtual DataTypePtr getStateType() const;
|
virtual DataTypePtr getStateType() const;
|
||||||
|
|
||||||
|
/// Returns true if two aggregate functions have the same state representation in memory and the same serialization,
|
||||||
|
/// so state of one aggregate function can be safely used with another.
|
||||||
|
/// Examples:
|
||||||
|
/// - quantile(x), quantile(a)(x), quantile(b)(x) - parameter doesn't affect state and used for finalization only
|
||||||
|
/// - foo(x) and fooIf(x) - If combinator doesn't affect state
|
||||||
|
/// By default returns true only if functions have exactly the same names, combinators and parameters.
|
||||||
|
virtual bool haveSameStateRepresentation(const IAggregateFunction & rhs) const;
|
||||||
|
|
||||||
|
bool haveEqualArgumentTypes(const IAggregateFunction & rhs) const;
|
||||||
|
|
||||||
/// Get type which will be used for prediction result in case if function is an ML method.
|
/// Get type which will be used for prediction result in case if function is an ML method.
|
||||||
virtual DataTypePtr getReturnTypeToPredict() const
|
virtual DataTypePtr getReturnTypeToPredict() const
|
||||||
{
|
{
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
with (select sumState(1)) as s select sumMerge(s);
|
with (select sumState(1)) as s select sumMerge(s);
|
||||||
with (select sumState(number) from (select * from system.numbers limit 10)) as s select sumMerge(s);
|
with (select sumState(number) from (select * from system.numbers limit 10)) as s select sumMerge(s);
|
||||||
with (select quantileState(0.5)(number) from (select * from system.numbers limit 10)) as s select quantileMerge(0.5)(s);
|
with (select quantileState(0.5)(number) from (select * from system.numbers limit 10)) as s select quantileMerge(s);
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user