Refactor FunctionNode

This commit is contained in:
Dmitry Novik 2022-11-28 15:02:59 +00:00
parent b6eddbac0d
commit 2c70dbc76a
106 changed files with 636 additions and 558 deletions

View File

@ -49,14 +49,16 @@ private:
public: public:
AggregateFunctionThrow(const DataTypes & argument_types_, const Array & parameters_, Float64 throw_probability_) AggregateFunctionThrow(const DataTypes & argument_types_, const Array & parameters_, Float64 throw_probability_)
: IAggregateFunctionDataHelper(argument_types_, parameters_), throw_probability(throw_probability_) {} : IAggregateFunctionDataHelper(argument_types_, parameters_, createResultType())
, throw_probability(throw_probability_)
{}
String getName() const override String getName() const override
{ {
return "aggThrow"; return "aggThrow";
} }
DataTypePtr getReturnType() const override static DataTypePtr createResultType()
{ {
return std::make_shared<DataTypeUInt8>(); return std::make_shared<DataTypeUInt8>();
} }

View File

@ -37,10 +37,10 @@ class AggregateFunctionAnalysisOfVariance final : public IAggregateFunctionDataH
{ {
public: public:
explicit AggregateFunctionAnalysisOfVariance(const DataTypes & arguments, const Array & params) explicit AggregateFunctionAnalysisOfVariance(const DataTypes & arguments, const Array & params)
: IAggregateFunctionDataHelper(arguments, params) : IAggregateFunctionDataHelper(arguments, params, createResultType())
{} {}
DataTypePtr getReturnType() const override DataTypePtr createResultType() const
{ {
DataTypes types {std::make_shared<DataTypeNumber<Float64>>(), std::make_shared<DataTypeNumber<Float64>>() }; DataTypes types {std::make_shared<DataTypeNumber<Float64>>(), std::make_shared<DataTypeNumber<Float64>>() };
Strings names {"f_statistic", "p_value"}; Strings names {"f_statistic", "p_value"};

View File

@ -38,7 +38,6 @@ template <typename Data>
class AggregateFunctionArgMinMax final : public IAggregateFunctionDataHelper<Data, AggregateFunctionArgMinMax<Data>> class AggregateFunctionArgMinMax final : public IAggregateFunctionDataHelper<Data, AggregateFunctionArgMinMax<Data>>
{ {
private: private:
const DataTypePtr & type_res;
const DataTypePtr & type_val; const DataTypePtr & type_val;
const SerializationPtr serialization_res; const SerializationPtr serialization_res;
const SerializationPtr serialization_val; const SerializationPtr serialization_val;
@ -47,10 +46,9 @@ private:
public: public:
AggregateFunctionArgMinMax(const DataTypePtr & type_res_, const DataTypePtr & type_val_) AggregateFunctionArgMinMax(const DataTypePtr & type_res_, const DataTypePtr & type_val_)
: Base({type_res_, type_val_}, {}) : Base({type_res_, type_val_}, {}, type_res_)
, type_res(this->argument_types[0])
, type_val(this->argument_types[1]) , type_val(this->argument_types[1])
, serialization_res(type_res->getDefaultSerialization()) , serialization_res(type_res_->getDefaultSerialization())
, serialization_val(type_val->getDefaultSerialization()) , serialization_val(type_val->getDefaultSerialization())
{ {
if (!type_val->isComparable()) if (!type_val->isComparable())
@ -63,11 +61,6 @@ public:
return StringRef(Data::ValueData_t::name()) == StringRef("min") ? "argMin" : "argMax"; return StringRef(Data::ValueData_t::name()) == StringRef("min") ? "argMin" : "argMax";
} }
DataTypePtr getReturnType() const override
{
return type_res;
}
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{ {
if (this->data(place).value.changeIfBetter(*columns[1], row_num, arena)) if (this->data(place).value.changeIfBetter(*columns[1], row_num, arena))

View File

@ -30,7 +30,7 @@ private:
public: public:
AggregateFunctionArray(AggregateFunctionPtr nested_, const DataTypes & arguments, const Array & params_) AggregateFunctionArray(AggregateFunctionPtr nested_, const DataTypes & arguments, const Array & params_)
: IAggregateFunctionHelper<AggregateFunctionArray>(arguments, params_) : IAggregateFunctionHelper<AggregateFunctionArray>(arguments, params_, createResultType())
, nested_func(nested_), num_arguments(arguments.size()) , nested_func(nested_), num_arguments(arguments.size())
{ {
assert(parameters == nested_func->getParameters()); assert(parameters == nested_func->getParameters());
@ -44,9 +44,9 @@ public:
return nested_func->getName() + "Array"; return nested_func->getName() + "Array";
} }
DataTypePtr getReturnType() const override DataTypePtr createResultType() const
{ {
return nested_func->getReturnType(); return nested_func->getResultType();
} }
const IAggregateFunction & getBaseAggregateFunctionWithSameStateRepresentation() const override const IAggregateFunction & getBaseAggregateFunctionWithSameStateRepresentation() const override

View File

@ -11,6 +11,9 @@
#include <AggregateFunctions/AggregateFunctionSum.h> #include <AggregateFunctions/AggregateFunctionSum.h>
#include <Core/DecimalFunctions.h> #include <Core/DecimalFunctions.h>
#include "Core/IResolvedFunction.h"
#include "DataTypes/IDataType.h"
#include "DataTypes/Serializations/ISerialization.h"
#include "config.h" #include "config.h"
#if USE_EMBEDDED_COMPILER #if USE_EMBEDDED_COMPILER
@ -83,10 +86,20 @@ public:
using Fraction = AvgFraction<Numerator, Denominator>; using Fraction = AvgFraction<Numerator, Denominator>;
explicit AggregateFunctionAvgBase(const DataTypes & argument_types_, explicit AggregateFunctionAvgBase(const DataTypes & argument_types_,
UInt32 num_scale_ = 0, UInt32 denom_scale_ = 0) UInt32 num_scale_ = 0, UInt32 denom_scale_ = 0)
: Base(argument_types_, {}), num_scale(num_scale_), denom_scale(denom_scale_) {} : Base(argument_types_, {}, createResultType())
, num_scale(num_scale_)
, denom_scale(denom_scale_)
{}
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeNumber<Float64>>(); } AggregateFunctionAvgBase(const DataTypes & argument_types_, const DataTypePtr & result_type_,
UInt32 num_scale_ = 0, UInt32 denom_scale_ = 0)
: Base(argument_types_, {}, result_type_)
, num_scale(num_scale_)
, denom_scale(denom_scale_)
{}
DataTypePtr createResultType() const { return std::make_shared<DataTypeNumber<Float64>>(); }
bool allocatesMemoryInArena() const override { return false; } bool allocatesMemoryInArena() const override { return false; }
@ -135,7 +148,7 @@ public:
for (const auto & argument : this->argument_types) for (const auto & argument : this->argument_types)
can_be_compiled &= canBeNativeType(*argument); can_be_compiled &= canBeNativeType(*argument);
auto return_type = getReturnType(); auto return_type = this->getResultType();
can_be_compiled &= canBeNativeType(*return_type); can_be_compiled &= canBeNativeType(*return_type);
return can_be_compiled; return can_be_compiled;

View File

@ -97,11 +97,12 @@ class AggregateFunctionBitwise final : public IAggregateFunctionDataHelper<Data,
{ {
public: public:
explicit AggregateFunctionBitwise(const DataTypePtr & type) explicit AggregateFunctionBitwise(const DataTypePtr & type)
: IAggregateFunctionDataHelper<Data, AggregateFunctionBitwise<T, Data>>({type}, {}) {} : IAggregateFunctionDataHelper<Data, AggregateFunctionBitwise<T, Data>>({type}, {}, createResultType())
{}
String getName() const override { return Data::name(); } String getName() const override { return Data::name(); }
DataTypePtr getReturnType() const override static DataTypePtr createResultType()
{ {
return std::make_shared<DataTypeNumber<T>>(); return std::make_shared<DataTypeNumber<T>>();
} }
@ -137,7 +138,7 @@ public:
bool isCompilable() const override bool isCompilable() const override
{ {
auto return_type = getReturnType(); auto return_type = this->getResultType();
return canBeNativeType(*return_type); return canBeNativeType(*return_type);
} }
@ -151,7 +152,7 @@ public:
{ {
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder); llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_type = toNativeType(b, getReturnType()); auto * return_type = toNativeType(b, this->getResultType());
auto * value_ptr = aggregate_data_ptr; auto * value_ptr = aggregate_data_ptr;
auto * value = b.CreateLoad(return_type, value_ptr); auto * value = b.CreateLoad(return_type, value_ptr);
@ -166,7 +167,7 @@ public:
{ {
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder); llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_type = toNativeType(b, getReturnType()); auto * return_type = toNativeType(b, this->getResultType());
auto * value_dst_ptr = aggregate_data_dst_ptr; auto * value_dst_ptr = aggregate_data_dst_ptr;
auto * value_dst = b.CreateLoad(return_type, value_dst_ptr); auto * value_dst = b.CreateLoad(return_type, value_dst_ptr);
@ -183,7 +184,7 @@ public:
{ {
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder); llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_type = toNativeType(b, getReturnType()); auto * return_type = toNativeType(b, this->getResultType());
auto * value_ptr = aggregate_data_ptr; auto * value_ptr = aggregate_data_ptr;
return b.CreateLoad(return_type, value_ptr); return b.CreateLoad(return_type, value_ptr);

View File

@ -112,7 +112,7 @@ public:
} }
explicit AggregateFunctionBoundingRatio(const DataTypes & arguments) explicit AggregateFunctionBoundingRatio(const DataTypes & arguments)
: IAggregateFunctionDataHelper<AggregateFunctionBoundingRatioData, AggregateFunctionBoundingRatio>(arguments, {}) : IAggregateFunctionDataHelper<AggregateFunctionBoundingRatioData, AggregateFunctionBoundingRatio>(arguments, {}, std::make_shared<DataTypeFloat64>())
{ {
const auto * x_arg = arguments.at(0).get(); const auto * x_arg = arguments.at(0).get();
const auto * y_arg = arguments.at(1).get(); const auto * y_arg = arguments.at(1).get();
@ -122,11 +122,6 @@ public:
ErrorCodes::BAD_ARGUMENTS); ErrorCodes::BAD_ARGUMENTS);
} }
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeFloat64>();
}
bool allocatesMemoryInArena() const override { return false; } bool allocatesMemoryInArena() const override { return false; }
void add(AggregateDataPtr __restrict place, const IColumn ** columns, const size_t row_num, Arena *) const override void add(AggregateDataPtr __restrict place, const IColumn ** columns, const size_t row_num, Arena *) const override

View File

@ -46,9 +46,9 @@ private:
} }
public: public:
AggregateFunctionCategoricalIV(const DataTypes & arguments_, const Array & params_) : AggregateFunctionCategoricalIV(const DataTypes & arguments_, const Array & params_)
IAggregateFunctionHelper<AggregateFunctionCategoricalIV>{arguments_, params_}, : IAggregateFunctionHelper<AggregateFunctionCategoricalIV>{arguments_, params_, createResultType()}
category_count{arguments_.size() - 1} , category_count{arguments_.size() - 1}
{ {
// notice: argument types has been checked before // notice: argument types has been checked before
} }
@ -121,7 +121,7 @@ public:
buf.readStrict(place, sizeOfData()); buf.readStrict(place, sizeOfData());
} }
DataTypePtr getReturnType() const override static DataTypePtr createResultType()
{ {
return std::make_shared<DataTypeArray>( return std::make_shared<DataTypeArray>(
std::make_shared<DataTypeNumber<Float64>>()); std::make_shared<DataTypeNumber<Float64>>());

View File

@ -39,11 +39,13 @@ namespace ErrorCodes
class AggregateFunctionCount final : public IAggregateFunctionDataHelper<AggregateFunctionCountData, AggregateFunctionCount> class AggregateFunctionCount final : public IAggregateFunctionDataHelper<AggregateFunctionCountData, AggregateFunctionCount>
{ {
public: public:
explicit AggregateFunctionCount(const DataTypes & argument_types_) : IAggregateFunctionDataHelper(argument_types_, {}) {} explicit AggregateFunctionCount(const DataTypes & argument_types_)
: IAggregateFunctionDataHelper(argument_types_, {}, createResultType())
{}
String getName() const override { return "count"; } String getName() const override { return "count"; }
DataTypePtr getReturnType() const override static DataTypePtr createResultType()
{ {
return std::make_shared<DataTypeUInt64>(); return std::make_shared<DataTypeUInt64>();
} }
@ -167,7 +169,7 @@ public:
{ {
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder); llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_type = toNativeType(b, getReturnType()); auto * return_type = toNativeType(b, this->getResultType());
auto * count_value_ptr = aggregate_data_ptr; auto * count_value_ptr = aggregate_data_ptr;
auto * count_value = b.CreateLoad(return_type, count_value_ptr); auto * count_value = b.CreateLoad(return_type, count_value_ptr);
@ -180,7 +182,7 @@ public:
{ {
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder); llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_type = toNativeType(b, getReturnType()); auto * return_type = toNativeType(b, this->getResultType());
auto * count_value_dst_ptr = aggregate_data_dst_ptr; auto * count_value_dst_ptr = aggregate_data_dst_ptr;
auto * count_value_dst = b.CreateLoad(return_type, count_value_dst_ptr); auto * count_value_dst = b.CreateLoad(return_type, count_value_dst_ptr);
@ -197,7 +199,7 @@ public:
{ {
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder); llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_type = toNativeType(b, getReturnType()); auto * return_type = toNativeType(b, this->getResultType());
auto * count_value_ptr = aggregate_data_ptr; auto * count_value_ptr = aggregate_data_ptr;
return b.CreateLoad(return_type, count_value_ptr); return b.CreateLoad(return_type, count_value_ptr);
@ -214,7 +216,7 @@ class AggregateFunctionCountNotNullUnary final
{ {
public: public:
AggregateFunctionCountNotNullUnary(const DataTypePtr & argument, const Array & params) AggregateFunctionCountNotNullUnary(const DataTypePtr & argument, const Array & params)
: IAggregateFunctionDataHelper<AggregateFunctionCountData, AggregateFunctionCountNotNullUnary>({argument}, params) : IAggregateFunctionDataHelper<AggregateFunctionCountData, AggregateFunctionCountNotNullUnary>({argument}, params, createResultType())
{ {
if (!argument->isNullable()) if (!argument->isNullable())
throw Exception("Logical error: not Nullable data type passed to AggregateFunctionCountNotNullUnary", ErrorCodes::LOGICAL_ERROR); throw Exception("Logical error: not Nullable data type passed to AggregateFunctionCountNotNullUnary", ErrorCodes::LOGICAL_ERROR);
@ -222,7 +224,7 @@ public:
String getName() const override { return "count"; } String getName() const override { return "count"; }
DataTypePtr getReturnType() const override static DataTypePtr createResultType()
{ {
return std::make_shared<DataTypeUInt64>(); return std::make_shared<DataTypeUInt64>();
} }
@ -311,7 +313,7 @@ public:
{ {
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder); llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_type = toNativeType(b, getReturnType()); auto * return_type = toNativeType(b, this->getResultType());
auto * is_null_value = b.CreateExtractValue(values[0], {1}); auto * is_null_value = b.CreateExtractValue(values[0], {1});
auto * increment_value = b.CreateSelect(is_null_value, llvm::ConstantInt::get(return_type, 0), llvm::ConstantInt::get(return_type, 1)); auto * increment_value = b.CreateSelect(is_null_value, llvm::ConstantInt::get(return_type, 0), llvm::ConstantInt::get(return_type, 1));
@ -327,7 +329,7 @@ public:
{ {
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder); llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_type = toNativeType(b, getReturnType()); auto * return_type = toNativeType(b, this->getResultType());
auto * count_value_dst_ptr = aggregate_data_dst_ptr; auto * count_value_dst_ptr = aggregate_data_dst_ptr;
auto * count_value_dst = b.CreateLoad(return_type, count_value_dst_ptr); auto * count_value_dst = b.CreateLoad(return_type, count_value_dst_ptr);
@ -344,7 +346,7 @@ public:
{ {
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder); llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_type = toNativeType(b, getReturnType()); auto * return_type = toNativeType(b, this->getResultType());
auto * count_value_ptr = aggregate_data_ptr; auto * count_value_ptr = aggregate_data_ptr;
return b.CreateLoad(return_type, count_value_ptr); return b.CreateLoad(return_type, count_value_ptr);

View File

@ -31,7 +31,7 @@ class AggregationFunctionDeltaSum final
{ {
public: public:
AggregationFunctionDeltaSum(const DataTypes & arguments, const Array & params) AggregationFunctionDeltaSum(const DataTypes & arguments, const Array & params)
: IAggregateFunctionDataHelper<AggregationFunctionDeltaSumData<T>, AggregationFunctionDeltaSum<T>>{arguments, params} : IAggregateFunctionDataHelper<AggregationFunctionDeltaSumData<T>, AggregationFunctionDeltaSum<T>>{arguments, params, createResultType()}
{} {}
AggregationFunctionDeltaSum() AggregationFunctionDeltaSum()
@ -40,7 +40,7 @@ public:
String getName() const override { return "deltaSum"; } String getName() const override { return "deltaSum"; }
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeNumber<T>>(); } static DataTypePtr createResultType() { return std::make_shared<DataTypeNumber<T>>(); }
bool allocatesMemoryInArena() const override { return false; } bool allocatesMemoryInArena() const override { return false; }

View File

@ -38,7 +38,7 @@ public:
: IAggregateFunctionDataHelper< : IAggregateFunctionDataHelper<
AggregationFunctionDeltaSumTimestampData<ValueType, TimestampType>, AggregationFunctionDeltaSumTimestampData<ValueType, TimestampType>,
AggregationFunctionDeltaSumTimestamp<ValueType, TimestampType> AggregationFunctionDeltaSumTimestamp<ValueType, TimestampType>
>{arguments, params} >{arguments, params, createResultType()}
{} {}
AggregationFunctionDeltaSumTimestamp() AggregationFunctionDeltaSumTimestamp()
@ -52,7 +52,7 @@ public:
String getName() const override { return "deltaSumTimestamp"; } String getName() const override { return "deltaSumTimestamp"; }
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeNumber<ValueType>>(); } static DataTypePtr createResultType() { return std::make_shared<DataTypeNumber<ValueType>>(); }
void NO_SANITIZE_UNDEFINED ALWAYS_INLINE add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override void NO_SANITIZE_UNDEFINED ALWAYS_INLINE add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
{ {

View File

@ -168,7 +168,7 @@ private:
public: public:
AggregateFunctionDistinct(AggregateFunctionPtr nested_func_, const DataTypes & arguments, const Array & params_) AggregateFunctionDistinct(AggregateFunctionPtr nested_func_, const DataTypes & arguments, const Array & params_)
: IAggregateFunctionDataHelper<Data, AggregateFunctionDistinct>(arguments, params_) : IAggregateFunctionDataHelper<Data, AggregateFunctionDistinct>(arguments, params_, nested_func_->getResultType())
, nested_func(nested_func_) , nested_func(nested_func_)
, arguments_num(arguments.size()) , arguments_num(arguments.size())
{ {
@ -255,11 +255,6 @@ public:
return nested_func->getName() + "Distinct"; return nested_func->getName() + "Distinct";
} }
DataTypePtr getReturnType() const override
{
return nested_func->getReturnType();
}
bool allocatesMemoryInArena() const override bool allocatesMemoryInArena() const override
{ {
return true; return true;

View File

@ -92,14 +92,14 @@ private:
public: public:
explicit AggregateFunctionEntropy(const DataTypes & argument_types_) explicit AggregateFunctionEntropy(const DataTypes & argument_types_)
: IAggregateFunctionDataHelper<EntropyData<Value>, AggregateFunctionEntropy<Value>>(argument_types_, {}) : IAggregateFunctionDataHelper<EntropyData<Value>, AggregateFunctionEntropy<Value>>(argument_types_, {}, createResultType())
, num_args(argument_types_.size()) , num_args(argument_types_.size())
{ {
} }
String getName() const override { return "entropy"; } String getName() const override { return "entropy"; }
DataTypePtr getReturnType() const override static DataTypePtr createResultType()
{ {
return std::make_shared<DataTypeNumber<Float64>>(); return std::make_shared<DataTypeNumber<Float64>>();
} }

View File

@ -29,7 +29,7 @@ private:
public: public:
AggregateFunctionExponentialMovingAverage(const DataTypes & argument_types_, const Array & params) AggregateFunctionExponentialMovingAverage(const DataTypes & argument_types_, const Array & params)
: IAggregateFunctionDataHelper<ExponentiallySmoothedAverage, AggregateFunctionExponentialMovingAverage>(argument_types_, params) : IAggregateFunctionDataHelper<ExponentiallySmoothedAverage, AggregateFunctionExponentialMovingAverage>(argument_types_, params, createResultType())
{ {
if (params.size() != 1) if (params.size() != 1)
throw Exception{"Aggregate function " + getName() + " requires exactly one parameter: half decay time.", throw Exception{"Aggregate function " + getName() + " requires exactly one parameter: half decay time.",
@ -43,7 +43,7 @@ public:
return "exponentialMovingAverage"; return "exponentialMovingAverage";
} }
DataTypePtr getReturnType() const override static DataTypePtr createResultType()
{ {
return std::make_shared<DataTypeNumber<Float64>>(); return std::make_shared<DataTypeNumber<Float64>>();
} }

View File

@ -107,7 +107,7 @@ private:
public: public:
AggregateFunctionForEach(AggregateFunctionPtr nested_, const DataTypes & arguments, const Array & params_) AggregateFunctionForEach(AggregateFunctionPtr nested_, const DataTypes & arguments, const Array & params_)
: IAggregateFunctionDataHelper<AggregateFunctionForEachData, AggregateFunctionForEach>(arguments, params_) : IAggregateFunctionDataHelper<AggregateFunctionForEachData, AggregateFunctionForEach>(arguments, params_, nested_->getResultType())
, nested_func(nested_), num_arguments(arguments.size()) , nested_func(nested_), num_arguments(arguments.size())
{ {
nested_size_of_data = nested_func->sizeOfData(); nested_size_of_data = nested_func->sizeOfData();
@ -125,11 +125,6 @@ public:
return nested_func->getName() + "ForEach"; return nested_func->getName() + "ForEach";
} }
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeArray>(nested_func->getReturnType());
}
bool isVersioned() const override bool isVersioned() const override
{ {
return nested_func->isVersioned(); return nested_func->isVersioned();

View File

@ -121,7 +121,7 @@ public:
explicit GroupArrayNumericImpl( explicit GroupArrayNumericImpl(
const DataTypePtr & data_type_, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max(), UInt64 seed_ = 123456) const DataTypePtr & data_type_, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max(), UInt64 seed_ = 123456)
: IAggregateFunctionDataHelper<GroupArrayNumericData<T, Trait::sampler != Sampler::NONE>, GroupArrayNumericImpl<T, Trait>>( : IAggregateFunctionDataHelper<GroupArrayNumericData<T, Trait::sampler != Sampler::NONE>, GroupArrayNumericImpl<T, Trait>>(
{data_type_}, parameters_) {data_type_}, parameters_, std::make_shared<DataTypeArray>(data_type_))
, max_elems(max_elems_) , max_elems(max_elems_)
, seed(seed_) , seed(seed_)
{ {
@ -129,8 +129,6 @@ public:
String getName() const override { return getNameByTrait<Trait>(); } String getName() const override { return getNameByTrait<Trait>(); }
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeArray>(this->argument_types[0]); }
void insert(Data & a, const T & v, Arena * arena) const void insert(Data & a, const T & v, Arena * arena) const
{ {
++a.total_values; ++a.total_values;
@ -423,7 +421,7 @@ class GroupArrayGeneralImpl final
public: public:
GroupArrayGeneralImpl(const DataTypePtr & data_type_, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max(), UInt64 seed_ = 123456) GroupArrayGeneralImpl(const DataTypePtr & data_type_, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max(), UInt64 seed_ = 123456)
: IAggregateFunctionDataHelper<GroupArrayGeneralData<Node, Trait::sampler != Sampler::NONE>, GroupArrayGeneralImpl<Node, Trait>>( : IAggregateFunctionDataHelper<GroupArrayGeneralData<Node, Trait::sampler != Sampler::NONE>, GroupArrayGeneralImpl<Node, Trait>>(
{data_type_}, parameters_) {data_type_}, parameters_, std::make_shared<DataTypeArray>(data_type_))
, data_type(this->argument_types[0]) , data_type(this->argument_types[0])
, max_elems(max_elems_) , max_elems(max_elems_)
, seed(seed_) , seed(seed_)
@ -432,8 +430,6 @@ public:
String getName() const override { return getNameByTrait<Trait>(); } String getName() const override { return getNameByTrait<Trait>(); }
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeArray>(data_type); }
void insert(Data & a, const Node * v, Arena * arena) const void insert(Data & a, const Node * v, Arena * arena) const
{ {
++a.total_values; ++a.total_values;
@ -697,7 +693,7 @@ class GroupArrayGeneralListImpl final
public: public:
GroupArrayGeneralListImpl(const DataTypePtr & data_type_, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max()) GroupArrayGeneralListImpl(const DataTypePtr & data_type_, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
: IAggregateFunctionDataHelper<GroupArrayGeneralListData<Node>, GroupArrayGeneralListImpl<Node, Trait>>({data_type_}, parameters_) : IAggregateFunctionDataHelper<GroupArrayGeneralListData<Node>, GroupArrayGeneralListImpl<Node, Trait>>({data_type_}, parameters_, std::make_shared<DataTypeArray>(data_type_))
, data_type(this->argument_types[0]) , data_type(this->argument_types[0])
, max_elems(max_elems_) , max_elems(max_elems_)
{ {
@ -705,8 +701,6 @@ public:
String getName() const override { return getNameByTrait<Trait>(); } String getName() const override { return getNameByTrait<Trait>(); }
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeArray>(data_type); }
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{ {
if (limit_num_elems && data(place).elems >= max_elems) if (limit_num_elems && data(place).elems >= max_elems)

View File

@ -64,7 +64,7 @@ private:
public: public:
AggregateFunctionGroupArrayInsertAtGeneric(const DataTypes & arguments, const Array & params) AggregateFunctionGroupArrayInsertAtGeneric(const DataTypes & arguments, const Array & params)
: IAggregateFunctionDataHelper<AggregateFunctionGroupArrayInsertAtDataGeneric, AggregateFunctionGroupArrayInsertAtGeneric>(arguments, params) : IAggregateFunctionDataHelper<AggregateFunctionGroupArrayInsertAtDataGeneric, AggregateFunctionGroupArrayInsertAtGeneric>(arguments, params, std::make_shared<DataTypeArray>(arguments[0]))
, type(argument_types[0]) , type(argument_types[0])
, serialization(type->getDefaultSerialization()) , serialization(type->getDefaultSerialization())
{ {
@ -101,11 +101,6 @@ public:
String getName() const override { return "groupArrayInsertAt"; } String getName() const override { return "groupArrayInsertAt"; }
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeArray>(type);
}
bool allocatesMemoryInArena() const override { return false; } bool allocatesMemoryInArena() const override { return false; }
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override

View File

@ -93,12 +93,15 @@ public:
using ColumnResult = ColumnVectorOrDecimal<ResultT>; using ColumnResult = ColumnVectorOrDecimal<ResultT>;
explicit MovingImpl(const DataTypePtr & data_type_, UInt64 window_size_ = std::numeric_limits<UInt64>::max()) explicit MovingImpl(const DataTypePtr & data_type_, UInt64 window_size_ = std::numeric_limits<UInt64>::max())
: IAggregateFunctionDataHelper<Data, MovingImpl<T, LimitNumElements, Data>>({data_type_}, {}) : IAggregateFunctionDataHelper<Data, MovingImpl<T, LimitNumElements, Data>>({data_type_}, {}, createResultType(data_type_))
, window_size(window_size_) {} , window_size(window_size_) {}
String getName() const override { return Data::name; } String getName() const override { return Data::name; }
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeArray>(getReturnTypeElement()); } static DataTypePtr createResultType(const DataTypePtr & argument)
{
return std::make_shared<DataTypeArray>(getReturnTypeElement(argument));
}
void NO_SANITIZE_UNDEFINED add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override void NO_SANITIZE_UNDEFINED add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{ {
@ -183,14 +186,14 @@ public:
} }
private: private:
auto getReturnTypeElement() const static auto getReturnTypeElement(const DataTypePtr & argument)
{ {
if constexpr (!is_decimal<ResultT>) if constexpr (!is_decimal<ResultT>)
return std::make_shared<DataTypeNumber<ResultT>>(); return std::make_shared<DataTypeNumber<ResultT>>();
else else
{ {
using Res = DataTypeDecimal<ResultT>; using Res = DataTypeDecimal<ResultT>;
return std::make_shared<Res>(Res::maxPrecision(), getDecimalScale(*this->argument_types.at(0))); return std::make_shared<Res>(Res::maxPrecision(), getDecimalScale(*argument));
} }
} }
}; };

View File

@ -74,7 +74,7 @@ namespace
/// groupBitmap needs to know about the data type that was used to create bitmaps. /// groupBitmap needs to know about the data type that was used to create bitmaps.
/// We need to look inside the type of its argument to obtain it. /// We need to look inside the type of its argument to obtain it.
const DataTypeAggregateFunction & datatype_aggfunc = dynamic_cast<const DataTypeAggregateFunction &>(*argument_type_ptr); const DataTypeAggregateFunction & datatype_aggfunc = dynamic_cast<const DataTypeAggregateFunction &>(*argument_type_ptr);
AggregateFunctionPtr aggfunc = datatype_aggfunc.getFunction(); ConstAggregateFunctionPtr aggfunc = datatype_aggfunc.getFunction();
if (aggfunc->getName() != AggregateFunctionGroupBitmapData<UInt8>::name()) if (aggfunc->getName() != AggregateFunctionGroupBitmapData<UInt8>::name())
throw Exception( throw Exception(

View File

@ -19,13 +19,13 @@ class AggregateFunctionBitmap final : public IAggregateFunctionDataHelper<Data,
{ {
public: public:
explicit AggregateFunctionBitmap(const DataTypePtr & type) explicit AggregateFunctionBitmap(const DataTypePtr & type)
: IAggregateFunctionDataHelper<Data, AggregateFunctionBitmap<T, Data>>({type}, {}) : IAggregateFunctionDataHelper<Data, AggregateFunctionBitmap<T, Data>>({type}, {}, createResultType())
{ {
} }
String getName() const override { return Data::name(); } String getName() const override { return Data::name(); }
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeNumber<T>>(); } static DataTypePtr createResultType() { return std::make_shared<DataTypeNumber<T>>(); }
bool allocatesMemoryInArena() const override { return false; } bool allocatesMemoryInArena() const override { return false; }
@ -59,13 +59,13 @@ private:
static constexpr size_t STATE_VERSION_1_MIN_REVISION = 54455; static constexpr size_t STATE_VERSION_1_MIN_REVISION = 54455;
public: public:
explicit AggregateFunctionBitmapL2(const DataTypePtr & type) explicit AggregateFunctionBitmapL2(const DataTypePtr & type)
: IAggregateFunctionDataHelper<Data, AggregateFunctionBitmapL2<T, Data, Policy>>({type}, {}) : IAggregateFunctionDataHelper<Data, AggregateFunctionBitmapL2<T, Data, Policy>>({type}, {}, createResultType())
{ {
} }
String getName() const override { return Policy::name; } String getName() const override { return Policy::name; }
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeNumber<T>>(); } static DataTypePtr createResultType() { return std::make_shared<DataTypeNumber<T>>(); }
bool allocatesMemoryInArena() const override { return false; } bool allocatesMemoryInArena() const override { return false; }

View File

@ -26,8 +26,8 @@ class AggregateFunctionGroupUniqArrayDate : public AggregateFunctionGroupUniqArr
{ {
public: public:
explicit AggregateFunctionGroupUniqArrayDate(const DataTypePtr & argument_type, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max()) explicit AggregateFunctionGroupUniqArrayDate(const DataTypePtr & argument_type, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
: AggregateFunctionGroupUniqArray<DataTypeDate::FieldType, HasLimit>(argument_type, parameters_, max_elems_) {} : AggregateFunctionGroupUniqArray<DataTypeDate::FieldType, HasLimit>(argument_type, parameters_, createResultType(), max_elems_) {}
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeArray>(std::make_shared<DataTypeDate>()); } static DataTypePtr createResultType() { return std::make_shared<DataTypeArray>(std::make_shared<DataTypeDate>()); }
}; };
template <typename HasLimit> template <typename HasLimit>
@ -35,8 +35,8 @@ class AggregateFunctionGroupUniqArrayDateTime : public AggregateFunctionGroupUni
{ {
public: public:
explicit AggregateFunctionGroupUniqArrayDateTime(const DataTypePtr & argument_type, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max()) explicit AggregateFunctionGroupUniqArrayDateTime(const DataTypePtr & argument_type, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
: AggregateFunctionGroupUniqArray<DataTypeDateTime::FieldType, HasLimit>(argument_type, parameters_, max_elems_) {} : AggregateFunctionGroupUniqArray<DataTypeDateTime::FieldType, HasLimit>(argument_type, parameters_, createResultType(), max_elems_) {}
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeArray>(std::make_shared<DataTypeDateTime>()); } static DataTypePtr createResultType() { return std::make_shared<DataTypeArray>(std::make_shared<DataTypeDateTime>()); }
}; };
template <typename HasLimit, typename ... TArgs> template <typename HasLimit, typename ... TArgs>

View File

@ -50,15 +50,16 @@ private:
public: public:
AggregateFunctionGroupUniqArray(const DataTypePtr & argument_type, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max()) AggregateFunctionGroupUniqArray(const DataTypePtr & argument_type, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
: IAggregateFunctionDataHelper<AggregateFunctionGroupUniqArrayData<T>, : IAggregateFunctionDataHelper<AggregateFunctionGroupUniqArrayData<T>,
AggregateFunctionGroupUniqArray<T, LimitNumElems>>({argument_type}, parameters_), AggregateFunctionGroupUniqArray<T, LimitNumElems>>({argument_type}, parameters_, std::make_shared<DataTypeArray>(argument_type)),
max_elems(max_elems_) {} max_elems(max_elems_) {}
String getName() const override { return "groupUniqArray"; } AggregateFunctionGroupUniqArray(const DataTypePtr & argument_type, const Array & parameters_, const DataTypePtr & result_type_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
: IAggregateFunctionDataHelper<AggregateFunctionGroupUniqArrayData<T>,
AggregateFunctionGroupUniqArray<T, LimitNumElems>>({argument_type}, parameters_, result_type_),
max_elems(max_elems_) {}
DataTypePtr getReturnType() const override
{ String getName() const override { return "groupUniqArray"; }
return std::make_shared<DataTypeArray>(this->argument_types[0]);
}
bool allocatesMemoryInArena() const override { return false; } bool allocatesMemoryInArena() const override { return false; }
@ -153,17 +154,12 @@ class AggregateFunctionGroupUniqArrayGeneric
public: public:
AggregateFunctionGroupUniqArrayGeneric(const DataTypePtr & input_data_type_, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max()) AggregateFunctionGroupUniqArrayGeneric(const DataTypePtr & input_data_type_, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
: IAggregateFunctionDataHelper<AggregateFunctionGroupUniqArrayGenericData, AggregateFunctionGroupUniqArrayGeneric<is_plain_column, LimitNumElems>>({input_data_type_}, parameters_) : IAggregateFunctionDataHelper<AggregateFunctionGroupUniqArrayGenericData, AggregateFunctionGroupUniqArrayGeneric<is_plain_column, LimitNumElems>>({input_data_type_}, parameters_, std::make_shared<DataTypeArray>(input_data_type_))
, input_data_type(this->argument_types[0]) , input_data_type(this->argument_types[0])
, max_elems(max_elems_) {} , max_elems(max_elems_) {}
String getName() const override { return "groupUniqArray"; } String getName() const override { return "groupUniqArray"; }
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeArray>(input_data_type);
}
bool allocatesMemoryInArena() const override bool allocatesMemoryInArena() const override
{ {
return true; return true;

View File

@ -307,7 +307,7 @@ private:
public: public:
AggregateFunctionHistogram(UInt32 max_bins_, const DataTypes & arguments, const Array & params) AggregateFunctionHistogram(UInt32 max_bins_, const DataTypes & arguments, const Array & params)
: IAggregateFunctionDataHelper<AggregateFunctionHistogramData, AggregateFunctionHistogram<T>>(arguments, params) : IAggregateFunctionDataHelper<AggregateFunctionHistogramData, AggregateFunctionHistogram<T>>(arguments, params, createResultType())
, max_bins(max_bins_) , max_bins(max_bins_)
{ {
} }
@ -316,7 +316,7 @@ public:
{ {
return Data::structSize(max_bins); return Data::structSize(max_bins);
} }
DataTypePtr getReturnType() const override static DataTypePtr createResultType()
{ {
DataTypes types; DataTypes types;
auto mean = std::make_shared<DataTypeNumber<Data::Mean>>(); auto mean = std::make_shared<DataTypeNumber<Data::Mean>>();

View File

@ -448,7 +448,7 @@ AggregateFunctionPtr AggregateFunctionIf::getOwnNullAdapter(
/// Nullability of the last argument (condition) does not affect the nullability of the result (NULL is processed as false). /// Nullability of the last argument (condition) does not affect the nullability of the result (NULL is processed as false).
/// For other arguments it is as usual (at least one is NULL then the result is NULL if possible). /// For other arguments it is as usual (at least one is NULL then the result is NULL if possible).
bool return_type_is_nullable = !properties.returns_default_when_only_null && getReturnType()->canBeInsideNullable() bool return_type_is_nullable = !properties.returns_default_when_only_null && getResultType()->canBeInsideNullable()
&& std::any_of(arguments.begin(), arguments.end() - 1, [](const auto & element) { return element->isNullable(); }); && std::any_of(arguments.begin(), arguments.end() - 1, [](const auto & element) { return element->isNullable(); });
bool need_to_serialize_flag = return_type_is_nullable || properties.returns_default_when_only_null; bool need_to_serialize_flag = return_type_is_nullable || properties.returns_default_when_only_null;

View File

@ -36,7 +36,7 @@ private:
public: public:
AggregateFunctionIf(AggregateFunctionPtr nested, const DataTypes & types, const Array & params_) AggregateFunctionIf(AggregateFunctionPtr nested, const DataTypes & types, const Array & params_)
: IAggregateFunctionHelper<AggregateFunctionIf>(types, params_) : IAggregateFunctionHelper<AggregateFunctionIf>(types, params_, nested->getResultType())
, nested_func(nested), num_arguments(types.size()) , nested_func(nested), num_arguments(types.size())
{ {
if (num_arguments == 0) if (num_arguments == 0)
@ -51,11 +51,6 @@ public:
return nested_func->getName() + "If"; return nested_func->getName() + "If";
} }
DataTypePtr getReturnType() const override
{
return nested_func->getReturnType();
}
const IAggregateFunction & getBaseAggregateFunctionWithSameStateRepresentation() const override const IAggregateFunction & getBaseAggregateFunctionWithSameStateRepresentation() const override
{ {
return nested_func->getBaseAggregateFunctionWithSameStateRepresentation(); return nested_func->getBaseAggregateFunctionWithSameStateRepresentation();

View File

@ -177,11 +177,11 @@ public:
String getName() const override { return "intervalLengthSum"; } String getName() const override { return "intervalLengthSum"; }
explicit AggregateFunctionIntervalLengthSum(const DataTypes & arguments) explicit AggregateFunctionIntervalLengthSum(const DataTypes & arguments)
: IAggregateFunctionDataHelper<Data, AggregateFunctionIntervalLengthSum<T, Data>>(arguments, {}) : IAggregateFunctionDataHelper<Data, AggregateFunctionIntervalLengthSum<T, Data>>(arguments, {}, createResultType())
{ {
} }
DataTypePtr getReturnType() const override static DataTypePtr createResultType()
{ {
if constexpr (std::is_floating_point_v<T>) if constexpr (std::is_floating_point_v<T>)
return std::make_shared<DataTypeFloat64>(); return std::make_shared<DataTypeFloat64>();

View File

@ -309,7 +309,7 @@ public:
UInt64 batch_size_, UInt64 batch_size_,
const DataTypes & arguments_types, const DataTypes & arguments_types,
const Array & params) const Array & params)
: IAggregateFunctionDataHelper<Data, AggregateFunctionMLMethod<Data, Name>>(arguments_types, params) : IAggregateFunctionDataHelper<Data, AggregateFunctionMLMethod<Data, Name>>(arguments_types, params, createResultType())
, param_num(param_num_) , param_num(param_num_)
, learning_rate(learning_rate_) , learning_rate(learning_rate_)
, l2_reg_coef(l2_reg_coef_) , l2_reg_coef(l2_reg_coef_)
@ -319,8 +319,7 @@ public:
{ {
} }
/// This function is called when SELECT linearRegression(...) is called static DataTypePtr createResultType()
DataTypePtr getReturnType() const override
{ {
return std::make_shared<DataTypeArray>(std::make_shared<DataTypeFloat64>()); return std::make_shared<DataTypeArray>(std::make_shared<DataTypeFloat64>());
} }

View File

@ -133,7 +133,7 @@ private:
public: public:
explicit AggregateFunctionMannWhitney(const DataTypes & arguments, const Array & params) explicit AggregateFunctionMannWhitney(const DataTypes & arguments, const Array & params)
:IAggregateFunctionDataHelper<MannWhitneyData, AggregateFunctionMannWhitney> ({arguments}, {}) :IAggregateFunctionDataHelper<MannWhitneyData, AggregateFunctionMannWhitney> ({arguments}, {}, createResultType())
{ {
if (params.size() > 2) if (params.size() > 2)
throw Exception("Aggregate function " + getName() + " require two parameter or less", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); throw Exception("Aggregate function " + getName() + " require two parameter or less", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
@ -174,7 +174,7 @@ public:
bool allocatesMemoryInArena() const override { return true; } bool allocatesMemoryInArena() const override { return true; }
DataTypePtr getReturnType() const override static DataTypePtr createResultType()
{ {
DataTypes types DataTypes types
{ {

View File

@ -18,6 +18,7 @@
#include <Functions/FunctionHelpers.h> #include <Functions/FunctionHelpers.h>
#include <IO/ReadHelpers.h> #include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h> #include <IO/WriteHelpers.h>
#include "DataTypes/Serializations/ISerialization.h"
#include "base/types.h" #include "base/types.h"
#include <Common/Arena.h> #include <Common/Arena.h>
#include "AggregateFunctions/AggregateFunctionFactory.h" #include "AggregateFunctions/AggregateFunctionFactory.h"
@ -104,26 +105,32 @@ public:
return nested_func->getDefaultVersion(); return nested_func->getDefaultVersion();
} }
AggregateFunctionMap(AggregateFunctionPtr nested, const DataTypes & types) : Base(types, nested->getParameters()), nested_func(nested) AggregateFunctionMap(AggregateFunctionPtr nested, const DataTypes & types)
: Base(types, nested->getParameters(), std::make_shared<DataTypeMap>(DataTypes{getKeyType(types, nested), nested->getResultType()}))
, nested_func(nested)
{ {
if (types.empty()) key_type = getKeyType(types, nested_func);
throw Exception(
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function " + getName() + " requires at least one argument");
if (types.size() > 1)
throw Exception(
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function " + getName() + " requires only one map argument");
const auto * map_type = checkAndGetDataType<DataTypeMap>(types[0].get());
if (!map_type)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Aggregate function " + getName() + " requires map as argument");
key_type = map_type->getKeyType();
} }
String getName() const override { return nested_func->getName() + "Map"; } String getName() const override { return nested_func->getName() + "Map"; }
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeMap>(DataTypes{key_type, nested_func->getReturnType()}); } static DataTypePtr getKeyType(const DataTypes & types, const AggregateFunctionPtr & nested)
{
if (types.empty())
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Aggregate function {}Map requires at least one argument", nested->getName());
if (types.size() > 1)
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Aggregate function {}Map requires only one map argument", nested->getName());
const auto * map_type = checkAndGetDataType<DataTypeMap>(types[0].get());
if (!map_type)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Aggregate function {}Map requires map as argument", nested->getName());
return map_type->getKeyType();
}
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{ {

View File

@ -62,7 +62,8 @@ private:
public: public:
AggregateFunctionIntersectionsMax(AggregateFunctionIntersectionsKind kind_, const DataTypes & arguments) AggregateFunctionIntersectionsMax(AggregateFunctionIntersectionsKind kind_, const DataTypes & arguments)
: IAggregateFunctionDataHelper<MaxIntersectionsData<PointType>, AggregateFunctionIntersectionsMax<PointType>>(arguments, {}), kind(kind_) : IAggregateFunctionDataHelper<MaxIntersectionsData<PointType>, AggregateFunctionIntersectionsMax<PointType>>(arguments, {}, createResultType(kind_))
, kind(kind_)
{ {
if (!isNativeNumber(arguments[0])) if (!isNativeNumber(arguments[0]))
throw Exception{getName() + ": first argument must be represented by integer", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; throw Exception{getName() + ": first argument must be represented by integer", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
@ -81,9 +82,9 @@ public:
: "maxIntersectionsPosition"; : "maxIntersectionsPosition";
} }
DataTypePtr getReturnType() const override static DataTypePtr createResultType(AggregateFunctionIntersectionsKind kind_)
{ {
if (kind == AggregateFunctionIntersectionsKind::Count) if (kind_ == AggregateFunctionIntersectionsKind::Count)
return std::make_shared<DataTypeUInt64>(); return std::make_shared<DataTypeUInt64>();
else else
return std::make_shared<DataTypeNumber<PointType>>(); return std::make_shared<DataTypeNumber<PointType>>();

View File

@ -36,7 +36,7 @@ private:
public: public:
AggregateFunctionMeanZTest(const DataTypes & arguments, const Array & params) AggregateFunctionMeanZTest(const DataTypes & arguments, const Array & params)
: IAggregateFunctionDataHelper<Data, AggregateFunctionMeanZTest<Data>>({arguments}, params) : IAggregateFunctionDataHelper<Data, AggregateFunctionMeanZTest<Data>>({arguments}, params, createResultType())
{ {
pop_var_x = params.at(0).safeGet<Float64>(); pop_var_x = params.at(0).safeGet<Float64>();
pop_var_y = params.at(1).safeGet<Float64>(); pop_var_y = params.at(1).safeGet<Float64>();
@ -63,7 +63,7 @@ public:
return Data::name; return Data::name;
} }
DataTypePtr getReturnType() const override static DataTypePtr createResultType()
{ {
DataTypes types DataTypes types
{ {

View File

@ -30,7 +30,7 @@ private:
public: public:
AggregateFunctionMerge(const AggregateFunctionPtr & nested_, const DataTypePtr & argument, const Array & params_) AggregateFunctionMerge(const AggregateFunctionPtr & nested_, const DataTypePtr & argument, const Array & params_)
: IAggregateFunctionHelper<AggregateFunctionMerge>({argument}, params_) : IAggregateFunctionHelper<AggregateFunctionMerge>({argument}, params_, createResultType())
, nested_func(nested_) , nested_func(nested_)
{ {
const DataTypeAggregateFunction * data_type = typeid_cast<const DataTypeAggregateFunction *>(argument.get()); const DataTypeAggregateFunction * data_type = typeid_cast<const DataTypeAggregateFunction *>(argument.get());
@ -45,9 +45,9 @@ public:
return nested_func->getName() + "Merge"; return nested_func->getName() + "Merge";
} }
DataTypePtr getReturnType() const override DataTypePtr createResultType() const
{ {
return nested_func->getReturnType(); return nested_func->getResultType();
} }
const IAggregateFunction & getBaseAggregateFunctionWithSameStateRepresentation() const override const IAggregateFunction & getBaseAggregateFunctionWithSameStateRepresentation() const override

View File

@ -1219,7 +1219,7 @@ private:
public: public:
explicit AggregateFunctionsSingleValue(const DataTypePtr & type) explicit AggregateFunctionsSingleValue(const DataTypePtr & type)
: IAggregateFunctionDataHelper<Data, AggregateFunctionsSingleValue<Data>>({type}, {}) : IAggregateFunctionDataHelper<Data, AggregateFunctionsSingleValue<Data>>({type}, {}, createResultType())
, serialization(type->getDefaultSerialization()) , serialization(type->getDefaultSerialization())
{ {
if (StringRef(Data::name()) == StringRef("min") if (StringRef(Data::name()) == StringRef("min")
@ -1233,7 +1233,7 @@ public:
String getName() const override { return Data::name(); } String getName() const override { return Data::name(); }
DataTypePtr getReturnType() const override DataTypePtr createResultType() const
{ {
auto result_type = this->argument_types.at(0); auto result_type = this->argument_types.at(0);
if constexpr (Data::is_nullable) if constexpr (Data::is_nullable)

View File

@ -6,6 +6,7 @@
#include <AggregateFunctions/IAggregateFunction.h> #include <AggregateFunctions/IAggregateFunction.h>
#include <IO/ReadHelpers.h> #include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h> #include <IO/WriteHelpers.h>
#include "DataTypes/IDataType.h"
namespace DB namespace DB
@ -19,16 +20,16 @@ class AggregateFunctionNothing final : public IAggregateFunctionHelper<Aggregate
{ {
public: public:
AggregateFunctionNothing(const DataTypes & arguments, const Array & params) AggregateFunctionNothing(const DataTypes & arguments, const Array & params)
: IAggregateFunctionHelper<AggregateFunctionNothing>(arguments, params) {} : IAggregateFunctionHelper<AggregateFunctionNothing>(arguments, params, createResultType(arguments)) {}
String getName() const override String getName() const override
{ {
return "nothing"; return "nothing";
} }
DataTypePtr getReturnType() const override static DataTypePtr createResultType(const DataTypes & arguments)
{ {
return argument_types.empty() ? std::make_shared<DataTypeNullable>(std::make_shared<DataTypeNothing>()) : argument_types.front(); return arguments.empty() ? std::make_shared<DataTypeNullable>(std::make_shared<DataTypeNothing>()) : arguments.front();
} }
bool allocatesMemoryInArena() const override { return false; } bool allocatesMemoryInArena() const override { return false; }

View File

@ -87,7 +87,7 @@ public:
transformed_nested_function->getParameters()); transformed_nested_function->getParameters());
} }
bool return_type_is_nullable = !properties.returns_default_when_only_null && nested_function->getReturnType()->canBeInsideNullable(); bool return_type_is_nullable = !properties.returns_default_when_only_null && nested_function->getResultType()->canBeInsideNullable();
bool serialize_flag = return_type_is_nullable || properties.returns_default_when_only_null; bool serialize_flag = return_type_is_nullable || properties.returns_default_when_only_null;
if (arguments.size() == 1) if (arguments.size() == 1)

View File

@ -82,7 +82,8 @@ protected:
public: public:
AggregateFunctionNullBase(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params) AggregateFunctionNullBase(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)
: IAggregateFunctionHelper<Derived>(arguments, params), nested_function{nested_function_} : IAggregateFunctionHelper<Derived>(arguments, params, createResultType(nested_function_))
, nested_function{nested_function_}
{ {
if constexpr (result_is_nullable) if constexpr (result_is_nullable)
prefix_size = nested_function->alignOfData(); prefix_size = nested_function->alignOfData();
@ -96,11 +97,11 @@ public:
return nested_function->getName(); return nested_function->getName();
} }
DataTypePtr getReturnType() const override static DataTypePtr createResultType(const AggregateFunctionPtr & nested_function_)
{ {
return result_is_nullable return result_is_nullable
? makeNullable(nested_function->getReturnType()) ? makeNullable(nested_function_->getResultType())
: nested_function->getReturnType(); : nested_function_->getResultType();
} }
void create(AggregateDataPtr __restrict place) const override void create(AggregateDataPtr __restrict place) const override
@ -270,7 +271,7 @@ public:
{ {
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder); llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_type = toNativeType(b, this->getReturnType()); auto * return_type = toNativeType(b, this->getResultType());
llvm::Value * result = nullptr; llvm::Value * result = nullptr;

View File

@ -4,6 +4,7 @@
#include <Columns/ColumnNullable.h> #include <Columns/ColumnNullable.h>
#include <Columns/ColumnsCommon.h> #include <Columns/ColumnsCommon.h>
#include <Common/typeid_cast.h> #include <Common/typeid_cast.h>
#include "DataTypes/Serializations/ISerialization.h"
#include <DataTypes/DataTypeNullable.h> #include <DataTypes/DataTypeNullable.h>
#include <IO/ReadHelpers.h> #include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h> #include <IO/WriteHelpers.h>
@ -30,16 +31,14 @@ private:
AggregateFunctionPtr nested_function; AggregateFunctionPtr nested_function;
size_t size_of_data; size_t size_of_data;
DataTypePtr inner_type;
bool inner_nullable; bool inner_nullable;
public: public:
AggregateFunctionOrFill(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params) AggregateFunctionOrFill(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)
: IAggregateFunctionHelper<AggregateFunctionOrFill>{arguments, params} : IAggregateFunctionHelper<AggregateFunctionOrFill>{arguments, params, createResultType(nested_function_->getResultType())}
, nested_function{nested_function_} , nested_function{nested_function_}
, size_of_data {nested_function->sizeOfData()} , size_of_data {nested_function->sizeOfData()}
, inner_type {nested_function->getReturnType()} , inner_nullable {nested_function->getResultType()->isNullable()}
, inner_nullable {inner_type->isNullable()}
{ {
// nothing // nothing
} }
@ -246,22 +245,22 @@ public:
readChar(place[size_of_data], buf); readChar(place[size_of_data], buf);
} }
DataTypePtr getReturnType() const override static DataTypePtr createResultType(const DataTypePtr & inner_type_)
{ {
if constexpr (UseNull) if constexpr (UseNull)
{ {
// -OrNull // -OrNull
if (inner_nullable) if (inner_type_->isNullable())
return inner_type; return inner_type_;
return std::make_shared<DataTypeNullable>(inner_type); return std::make_shared<DataTypeNullable>(inner_type_);
} }
else else
{ {
// -OrDefault // -OrDefault
return inner_type; return inner_type_;
} }
} }

View File

@ -72,7 +72,7 @@ private:
public: public:
AggregateFunctionQuantile(const DataTypes & argument_types_, const Array & params) AggregateFunctionQuantile(const DataTypes & argument_types_, const Array & params)
: IAggregateFunctionDataHelper<Data, AggregateFunctionQuantile<Value, Data, Name, has_second_arg, FloatReturnType, returns_many>>( : IAggregateFunctionDataHelper<Data, AggregateFunctionQuantile<Value, Data, Name, has_second_arg, FloatReturnType, returns_many>>(
argument_types_, params) argument_types_, params, createResultType(argument_types_))
, levels(params, returns_many) , levels(params, returns_many)
, level(levels.levels[0]) , level(levels.levels[0])
, argument_type(this->argument_types[0]) , argument_type(this->argument_types[0])
@ -83,14 +83,14 @@ public:
String getName() const override { return Name::name; } String getName() const override { return Name::name; }
DataTypePtr getReturnType() const override static DataTypePtr createResultType(const DataTypes & argument_types_)
{ {
DataTypePtr res; DataTypePtr res;
if constexpr (returns_float) if constexpr (returns_float)
res = std::make_shared<DataTypeNumber<FloatReturnType>>(); res = std::make_shared<DataTypeNumber<FloatReturnType>>();
else else
res = argument_type; res = argument_types_[0];
if constexpr (returns_many) if constexpr (returns_many)
return std::make_shared<DataTypeArray>(res); return std::make_shared<DataTypeArray>(res);

View File

@ -51,7 +51,7 @@ class AggregateFunctionRankCorrelation :
{ {
public: public:
explicit AggregateFunctionRankCorrelation(const DataTypes & arguments) explicit AggregateFunctionRankCorrelation(const DataTypes & arguments)
:IAggregateFunctionDataHelper<RankCorrelationData, AggregateFunctionRankCorrelation> ({arguments}, {}) :IAggregateFunctionDataHelper<RankCorrelationData, AggregateFunctionRankCorrelation> ({arguments}, {}, std::make_shared<DataTypeNumber<Float64>>())
{} {}
String getName() const override String getName() const override
@ -61,11 +61,6 @@ public:
bool allocatesMemoryInArena() const override { return true; } bool allocatesMemoryInArena() const override { return true; }
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeNumber<Float64>>();
}
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{ {
Float64 new_x = columns[0]->getFloat64(row_num); Float64 new_x = columns[0]->getFloat64(row_num);

View File

@ -43,7 +43,7 @@ public:
size_t step_, size_t step_,
const DataTypes & arguments, const DataTypes & arguments,
const Array & params) const Array & params)
: IAggregateFunctionHelper<AggregateFunctionResample<Key>>{arguments, params} : IAggregateFunctionHelper<AggregateFunctionResample<Key>>{arguments, params, createResultType(nested_function_)}
, nested_function{nested_function_} , nested_function{nested_function_}
, last_col{arguments.size() - 1} , last_col{arguments.size() - 1}
, begin{begin_} , begin{begin_}
@ -190,9 +190,9 @@ public:
nested_function->deserialize(place + i * size_of_data, buf, version, arena); nested_function->deserialize(place + i * size_of_data, buf, version, arena);
} }
DataTypePtr getReturnType() const override static DataTypePtr createResultType(const AggregateFunctionPtr & nested_function_)
{ {
return std::make_shared<DataTypeArray>(nested_function->getReturnType()); return std::make_shared<DataTypeArray>(nested_function_->getResultType());
} }
template <bool merge> template <bool merge>

View File

@ -76,7 +76,7 @@ public:
} }
explicit AggregateFunctionRetention(const DataTypes & arguments) explicit AggregateFunctionRetention(const DataTypes & arguments)
: IAggregateFunctionDataHelper<AggregateFunctionRetentionData, AggregateFunctionRetention>(arguments, {}) : IAggregateFunctionDataHelper<AggregateFunctionRetentionData, AggregateFunctionRetention>(arguments, {}, std::make_shared<DataTypeArray>(std::make_shared<DataTypeUInt8>()))
{ {
for (const auto i : collections::range(0, arguments.size())) for (const auto i : collections::range(0, arguments.size()))
{ {
@ -90,12 +90,6 @@ public:
events_size = static_cast<UInt8>(arguments.size()); events_size = static_cast<UInt8>(arguments.size());
} }
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeArray>(std::make_shared<DataTypeUInt8>());
}
bool allocatesMemoryInArena() const override { return false; } bool allocatesMemoryInArena() const override { return false; }
void add(AggregateDataPtr __restrict place, const IColumn ** columns, const size_t row_num, Arena *) const override void add(AggregateDataPtr __restrict place, const IColumn ** columns, const size_t row_num, Arena *) const override

View File

@ -8,6 +8,7 @@
#include <base/range.h> #include <base/range.h>
#include <base/sort.h> #include <base/sort.h>
#include <Common/PODArray.h> #include <Common/PODArray.h>
#include "DataTypes/Serializations/ISerialization.h"
#include <IO/ReadHelpers.h> #include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h> #include <IO/WriteHelpers.h>
#include <bitset> #include <bitset>
@ -126,8 +127,8 @@ template <typename T, typename Data, typename Derived>
class AggregateFunctionSequenceBase : public IAggregateFunctionDataHelper<Data, Derived> class AggregateFunctionSequenceBase : public IAggregateFunctionDataHelper<Data, Derived>
{ {
public: public:
AggregateFunctionSequenceBase(const DataTypes & arguments, const Array & params, const String & pattern_) AggregateFunctionSequenceBase(const DataTypes & arguments, const Array & params, const String & pattern_, const DataTypePtr & result_type_)
: IAggregateFunctionDataHelper<Data, Derived>(arguments, params) : IAggregateFunctionDataHelper<Data, Derived>(arguments, params, result_type_)
, pattern(pattern_) , pattern(pattern_)
{ {
arg_count = arguments.size(); arg_count = arguments.size();
@ -617,14 +618,12 @@ class AggregateFunctionSequenceMatch final : public AggregateFunctionSequenceBas
{ {
public: public:
AggregateFunctionSequenceMatch(const DataTypes & arguments, const Array & params, const String & pattern_) AggregateFunctionSequenceMatch(const DataTypes & arguments, const Array & params, const String & pattern_)
: AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceMatch<T, Data>>(arguments, params, pattern_) {} : AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceMatch<T, Data>>(arguments, params, pattern_, std::make_shared<DataTypeUInt8>()) {}
using AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceMatch<T, Data>>::AggregateFunctionSequenceBase; using AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceMatch<T, Data>>::AggregateFunctionSequenceBase;
String getName() const override { return "sequenceMatch"; } String getName() const override { return "sequenceMatch"; }
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeUInt8>(); }
bool allocatesMemoryInArena() const override { return false; } bool allocatesMemoryInArena() const override { return false; }
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
@ -655,14 +654,12 @@ class AggregateFunctionSequenceCount final : public AggregateFunctionSequenceBas
{ {
public: public:
AggregateFunctionSequenceCount(const DataTypes & arguments, const Array & params, const String & pattern_) AggregateFunctionSequenceCount(const DataTypes & arguments, const Array & params, const String & pattern_)
: AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceCount<T, Data>>(arguments, params, pattern_) {} : AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceCount<T, Data>>(arguments, params, pattern_, std::make_shared<DataTypeUInt64>()) {}
using AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceCount<T, Data>>::AggregateFunctionSequenceBase; using AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceCount<T, Data>>::AggregateFunctionSequenceBase;
String getName() const override { return "sequenceCount"; } String getName() const override { return "sequenceCount"; }
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeUInt64>(); }
bool allocatesMemoryInArena() const override { return false; } bool allocatesMemoryInArena() const override { return false; }
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override

View File

@ -190,7 +190,7 @@ public:
SequenceDirection seq_direction_, SequenceDirection seq_direction_,
size_t min_required_args_, size_t min_required_args_,
UInt64 max_elems_ = std::numeric_limits<UInt64>::max()) UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
: IAggregateFunctionDataHelper<SequenceNextNodeGeneralData<Node>, Self>({data_type_}, parameters_) : IAggregateFunctionDataHelper<SequenceNextNodeGeneralData<Node>, Self>({data_type_}, parameters_, data_type_)
, seq_base_kind(seq_base_kind_) , seq_base_kind(seq_base_kind_)
, seq_direction(seq_direction_) , seq_direction(seq_direction_)
, min_required_args(min_required_args_) , min_required_args(min_required_args_)
@ -202,8 +202,6 @@ public:
String getName() const override { return "sequenceNextNode"; } String getName() const override { return "sequenceNextNode"; }
DataTypePtr getReturnType() const override { return data_type; }
bool haveSameStateRepresentationImpl(const IAggregateFunction & rhs) const override bool haveSameStateRepresentationImpl(const IAggregateFunction & rhs) const override
{ {
return this->getName() == rhs.getName() && this->haveEqualArgumentTypes(rhs); return this->getName() == rhs.getName() && this->haveEqualArgumentTypes(rhs);

View File

@ -99,7 +99,7 @@ public:
IAggregateFunctionDataHelper< IAggregateFunctionDataHelper<
AggregateFunctionSimpleLinearRegressionData<Ret>, AggregateFunctionSimpleLinearRegressionData<Ret>,
AggregateFunctionSimpleLinearRegression<X, Y, Ret> AggregateFunctionSimpleLinearRegression<X, Y, Ret>
> {arguments, params} > {arguments, params, createResultType()}
{ {
// notice: arguments has been checked before // notice: arguments has been checked before
} }
@ -140,7 +140,7 @@ public:
this->data(place).deserialize(buf); this->data(place).deserialize(buf);
} }
DataTypePtr getReturnType() const override static DataTypePtr createResultType()
{ {
DataTypes types DataTypes types
{ {

View File

@ -20,28 +20,28 @@ private:
public: public:
AggregateFunctionSimpleState(AggregateFunctionPtr nested_, const DataTypes & arguments_, const Array & params_) AggregateFunctionSimpleState(AggregateFunctionPtr nested_, const DataTypes & arguments_, const Array & params_)
: IAggregateFunctionHelper<AggregateFunctionSimpleState>(arguments_, params_) : IAggregateFunctionHelper<AggregateFunctionSimpleState>(arguments_, params_, createResultType(nested_, params_))
, nested_func(nested_) , nested_func(nested_)
{ {
} }
String getName() const override { return nested_func->getName() + "SimpleState"; } String getName() const override { return nested_func->getName() + "SimpleState"; }
DataTypePtr getReturnType() const override static DataTypePtr createResultType(const AggregateFunctionPtr & nested_, const Array & params_)
{ {
DataTypeCustomSimpleAggregateFunction::checkSupportedFunctions(nested_func); DataTypeCustomSimpleAggregateFunction::checkSupportedFunctions(nested_);
// Need to make a clone to avoid recursive reference. // Need to make a clone to avoid recursive reference.
auto storage_type_out = DataTypeFactory::instance().get(nested_func->getReturnType()->getName()); auto storage_type_out = DataTypeFactory::instance().get(nested_->getResultType()->getName());
// Need to make a new function with promoted argument types because SimpleAggregates requires arg_type = return_type. // Need to make a new function with promoted argument types because SimpleAggregates requires arg_type = return_type.
AggregateFunctionProperties properties; AggregateFunctionProperties properties;
auto function auto function
= AggregateFunctionFactory::instance().get(nested_func->getName(), {storage_type_out}, nested_func->getParameters(), properties); = AggregateFunctionFactory::instance().get(nested_->getName(), {storage_type_out}, nested_->getParameters(), properties);
// Need to make a clone because it'll be customized. // Need to make a clone because it'll be customized.
auto storage_type_arg = DataTypeFactory::instance().get(nested_func->getReturnType()->getName()); auto storage_type_arg = DataTypeFactory::instance().get(nested_->getResultType()->getName());
DataTypeCustomNamePtr custom_name DataTypeCustomNamePtr custom_name
= std::make_unique<DataTypeCustomSimpleAggregateFunction>(function, DataTypes{nested_func->getReturnType()}, parameters); = std::make_unique<DataTypeCustomSimpleAggregateFunction>(function, DataTypes{nested_->getResultType()}, params_);
storage_type_arg->setCustomization(std::make_unique<DataTypeCustomDesc>(std::move(custom_name), nullptr)); storage_type_arg->setCustomization(std::make_unique<DataTypeCustomDesc>(std::move(custom_name), nullptr));
return storage_type_arg; return storage_type_arg;
} }

View File

@ -261,7 +261,7 @@ private:
public: public:
AggregateFunctionSparkbar(const DataTypes & arguments, const Array & params) AggregateFunctionSparkbar(const DataTypes & arguments, const Array & params)
: IAggregateFunctionDataHelper<AggregateFunctionSparkbarData<X, Y>, AggregateFunctionSparkbar>( : IAggregateFunctionDataHelper<AggregateFunctionSparkbarData<X, Y>, AggregateFunctionSparkbar>(
arguments, params) arguments, params, std::make_shared<DataTypeString>())
{ {
width = params.at(0).safeGet<UInt64>(); width = params.at(0).safeGet<UInt64>();
if (params.size() == 3) if (params.size() == 3)
@ -283,11 +283,6 @@ public:
return "sparkbar"; return "sparkbar";
} }
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeString>();
}
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * /*arena*/) const override void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * /*arena*/) const override
{ {
X x = assert_cast<const ColumnVector<X> *>(columns[0])->getData()[row_num]; X x = assert_cast<const ColumnVector<X> *>(columns[0])->getData()[row_num];

View File

@ -23,7 +23,7 @@ private:
public: public:
AggregateFunctionState(AggregateFunctionPtr nested_, const DataTypes & arguments_, const Array & params_) AggregateFunctionState(AggregateFunctionPtr nested_, const DataTypes & arguments_, const Array & params_)
: IAggregateFunctionHelper<AggregateFunctionState>(arguments_, params_) : IAggregateFunctionHelper<AggregateFunctionState>(arguments_, params_, createResultType())
, nested_func(nested_) , nested_func(nested_)
{} {}
@ -32,7 +32,7 @@ public:
return nested_func->getName() + "State"; return nested_func->getName() + "State";
} }
DataTypePtr getReturnType() const override DataTypePtr createResultType() const
{ {
return getStateType(); return getStateType();
} }

View File

@ -115,15 +115,11 @@ class AggregateFunctionVariance final
{ {
public: public:
explicit AggregateFunctionVariance(const DataTypePtr & arg) explicit AggregateFunctionVariance(const DataTypePtr & arg)
: IAggregateFunctionDataHelper<AggregateFunctionVarianceData<T, Op>, AggregateFunctionVariance<T, Op>>({arg}, {}) {} : IAggregateFunctionDataHelper<AggregateFunctionVarianceData<T, Op>, AggregateFunctionVariance<T, Op>>({arg}, {}, std::make_shared<DataTypeFloat64>())
{}
String getName() const override { return Op::name; } String getName() const override { return Op::name; }
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeFloat64>();
}
bool allocatesMemoryInArena() const override { return false; } bool allocatesMemoryInArena() const override { return false; }
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
@ -368,15 +364,11 @@ class AggregateFunctionCovariance final
public: public:
explicit AggregateFunctionCovariance(const DataTypes & args) : IAggregateFunctionDataHelper< explicit AggregateFunctionCovariance(const DataTypes & args) : IAggregateFunctionDataHelper<
CovarianceData<T, U, Op, compute_marginal_moments>, CovarianceData<T, U, Op, compute_marginal_moments>,
AggregateFunctionCovariance<T, U, Op, compute_marginal_moments>>(args, {}) {} AggregateFunctionCovariance<T, U, Op, compute_marginal_moments>>(args, {}, std::make_shared<DataTypeFloat64>())
{}
String getName() const override { return Op::name; } String getName() const override { return Op::name; }
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeFloat64>();
}
bool allocatesMemoryInArena() const override { return false; } bool allocatesMemoryInArena() const override { return false; }
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override

View File

@ -81,12 +81,12 @@ public:
using ColVecResult = ColumnVector<ResultType>; using ColVecResult = ColumnVector<ResultType>;
explicit AggregateFunctionVarianceSimple(const DataTypes & argument_types_) explicit AggregateFunctionVarianceSimple(const DataTypes & argument_types_)
: IAggregateFunctionDataHelper<typename StatFunc::Data, AggregateFunctionVarianceSimple<StatFunc>>(argument_types_, {}) : IAggregateFunctionDataHelper<typename StatFunc::Data, AggregateFunctionVarianceSimple<StatFunc>>(argument_types_, {}, std::make_shared<DataTypeNumber<ResultType>>())
, src_scale(0) , src_scale(0)
{} {}
AggregateFunctionVarianceSimple(const IDataType & data_type, const DataTypes & argument_types_) AggregateFunctionVarianceSimple(const IDataType & data_type, const DataTypes & argument_types_)
: IAggregateFunctionDataHelper<typename StatFunc::Data, AggregateFunctionVarianceSimple<StatFunc>>(argument_types_, {}) : IAggregateFunctionDataHelper<typename StatFunc::Data, AggregateFunctionVarianceSimple<StatFunc>>(argument_types_, {}, std::make_shared<DataTypeNumber<ResultType>>())
, src_scale(getDecimalScale(data_type)) , src_scale(getDecimalScale(data_type))
{} {}
@ -117,11 +117,6 @@ public:
UNREACHABLE(); UNREACHABLE();
} }
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeNumber<ResultType>>();
}
bool allocatesMemoryInArena() const override { return false; } bool allocatesMemoryInArena() const override { return false; }
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override

View File

@ -411,23 +411,21 @@ public:
} }
explicit AggregateFunctionSum(const DataTypes & argument_types_) explicit AggregateFunctionSum(const DataTypes & argument_types_)
: IAggregateFunctionDataHelper<Data, AggregateFunctionSum<T, TResult, Data, Type>>(argument_types_, {}) : IAggregateFunctionDataHelper<Data, AggregateFunctionSum<T, TResult, Data, Type>>(argument_types_, {}, createResultType(0))
, scale(0)
{} {}
AggregateFunctionSum(const IDataType & data_type, const DataTypes & argument_types_) AggregateFunctionSum(const IDataType & data_type, const DataTypes & argument_types_)
: IAggregateFunctionDataHelper<Data, AggregateFunctionSum<T, TResult, Data, Type>>(argument_types_, {}) : IAggregateFunctionDataHelper<Data, AggregateFunctionSum<T, TResult, Data, Type>>(argument_types_, {}, createResultType(getDecimalScale(data_type)))
, scale(getDecimalScale(data_type))
{} {}
DataTypePtr getReturnType() const override static DataTypePtr createResultType(UInt32 scale_)
{ {
if constexpr (!is_decimal<T>) if constexpr (!is_decimal<T>)
return std::make_shared<DataTypeNumber<TResult>>(); return std::make_shared<DataTypeNumber<TResult>>();
else else
{ {
using DataType = DataTypeDecimal<TResult>; using DataType = DataTypeDecimal<TResult>;
return std::make_shared<DataType>(DataType::maxPrecision(), scale); return std::make_shared<DataType>(DataType::maxPrecision(), scale_);
} }
} }
@ -548,7 +546,7 @@ public:
for (const auto & argument_type : this->argument_types) for (const auto & argument_type : this->argument_types)
can_be_compiled &= canBeNativeType(*argument_type); can_be_compiled &= canBeNativeType(*argument_type);
auto return_type = getReturnType(); auto return_type = this->getResultType();
can_be_compiled &= canBeNativeType(*return_type); can_be_compiled &= canBeNativeType(*return_type);
return can_be_compiled; return can_be_compiled;
@ -558,7 +556,7 @@ public:
{ {
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder); llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_type = toNativeType(b, getReturnType()); auto * return_type = toNativeType(b, this->getResultType());
auto * aggregate_sum_ptr = aggregate_data_ptr; auto * aggregate_sum_ptr = aggregate_data_ptr;
b.CreateStore(llvm::Constant::getNullValue(return_type), aggregate_sum_ptr); b.CreateStore(llvm::Constant::getNullValue(return_type), aggregate_sum_ptr);
@ -568,7 +566,7 @@ public:
{ {
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder); llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_type = toNativeType(b, getReturnType()); auto * return_type = toNativeType(b, this->getResultType());
auto * sum_value_ptr = aggregate_data_ptr; auto * sum_value_ptr = aggregate_data_ptr;
auto * sum_value = b.CreateLoad(return_type, sum_value_ptr); auto * sum_value = b.CreateLoad(return_type, sum_value_ptr);
@ -586,7 +584,7 @@ public:
{ {
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder); llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_type = toNativeType(b, getReturnType()); auto * return_type = toNativeType(b, this->getResultType());
auto * sum_value_dst_ptr = aggregate_data_dst_ptr; auto * sum_value_dst_ptr = aggregate_data_dst_ptr;
auto * sum_value_dst = b.CreateLoad(return_type, sum_value_dst_ptr); auto * sum_value_dst = b.CreateLoad(return_type, sum_value_dst_ptr);
@ -602,7 +600,7 @@ public:
{ {
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder); llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_type = toNativeType(b, getReturnType()); auto * return_type = toNativeType(b, this->getResultType());
auto * sum_value_ptr = aggregate_data_ptr; auto * sum_value_ptr = aggregate_data_ptr;
return b.CreateLoad(return_type, sum_value_ptr); return b.CreateLoad(return_type, sum_value_ptr);
@ -611,8 +609,6 @@ public:
#endif #endif
private: private:
UInt32 scale;
static constexpr auto & castColumnToResult(IColumn & to) static constexpr auto & castColumnToResult(IColumn & to)
{ {
if constexpr (is_decimal<T>) if constexpr (is_decimal<T>)

View File

@ -14,12 +14,13 @@ public:
using Base = AggregateFunctionAvg<T>; using Base = AggregateFunctionAvg<T>;
explicit AggregateFunctionSumCount(const DataTypes & argument_types_, UInt32 num_scale_ = 0) explicit AggregateFunctionSumCount(const DataTypes & argument_types_, UInt32 num_scale_ = 0)
: Base(argument_types_, num_scale_), scale(num_scale_) {} : Base(argument_types_, createResultType(num_scale_), num_scale_)
{}
DataTypePtr getReturnType() const override static DataTypePtr createResultType(UInt32 num_scale_)
{ {
auto second_elem = std::make_shared<DataTypeUInt64>(); auto second_elem = std::make_shared<DataTypeUInt64>();
return std::make_shared<DataTypeTuple>(DataTypes{getReturnTypeFirstElement(), std::move(second_elem)}); return std::make_shared<DataTypeTuple>(DataTypes{getReturnTypeFirstElement(num_scale_), std::move(second_elem)});
} }
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const final void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const final
@ -43,9 +44,7 @@ public:
#endif #endif
private: private:
UInt32 scale; static auto getReturnTypeFirstElement(UInt32 num_scale_)
auto getReturnTypeFirstElement() const
{ {
using FieldType = AvgFieldType<T>; using FieldType = AvgFieldType<T>;
@ -54,7 +53,7 @@ private:
else else
{ {
using DataType = DataTypeDecimal<FieldType>; using DataType = DataTypeDecimal<FieldType>;
return std::make_shared<DataType>(DataType::maxPrecision(), scale); return std::make_shared<DataType>(DataType::maxPrecision(), num_scale_);
} }
} }
}; };

View File

@ -80,7 +80,7 @@ public:
AggregateFunctionMapBase(const DataTypePtr & keys_type_, AggregateFunctionMapBase(const DataTypePtr & keys_type_,
const DataTypes & values_types_, const DataTypes & argument_types_) const DataTypes & values_types_, const DataTypes & argument_types_)
: Base(argument_types_, {} /* parameters */) : Base(argument_types_, {} /* parameters */, createResultType(keys_type_, values_types_, getName()))
, keys_type(keys_type_) , keys_type(keys_type_)
, keys_serialization(keys_type->getDefaultSerialization()) , keys_serialization(keys_type->getDefaultSerialization())
, values_types(values_types_) , values_types(values_types_)
@ -117,19 +117,22 @@ public:
return 0; return 0;
} }
DataTypePtr getReturnType() const override static DataTypePtr createResultType(
const DataTypePtr & keys_type_,
const DataTypes & values_types_,
const String & name_)
{ {
DataTypes types; DataTypes types;
types.emplace_back(std::make_shared<DataTypeArray>(keys_type)); types.emplace_back(std::make_shared<DataTypeArray>(keys_type_));
for (const auto & value_type : values_types) for (const auto & value_type : values_types_)
{ {
if constexpr (std::is_same_v<Visitor, FieldVisitorSum>) if constexpr (std::is_same_v<Visitor, FieldVisitorSum>)
{ {
if (!value_type->isSummable()) if (!value_type->isSummable())
throw Exception{ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, throw Exception{ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Values for {} cannot be summed, passed type {}", "Values for {} cannot be summed, passed type {}",
getName(), value_type->getName()}; name_, value_type->getName()};
} }
DataTypePtr result_type; DataTypePtr result_type;
@ -139,7 +142,7 @@ public:
if (value_type->onlyNull()) if (value_type->onlyNull())
throw Exception{ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, throw Exception{ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Cannot calculate {} of type {}", "Cannot calculate {} of type {}",
getName(), value_type->getName()}; name_, value_type->getName()};
// Overflow, meaning that the returned type is the same as // Overflow, meaning that the returned type is the same as
// the input type. Nulls are skipped. // the input type. Nulls are skipped.
@ -153,7 +156,7 @@ public:
if (!value_type_without_nullable->canBePromoted()) if (!value_type_without_nullable->canBePromoted())
throw Exception{ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, throw Exception{ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Values for {} are expected to be Numeric, Float or Decimal, passed type {}", "Values for {} are expected to be Numeric, Float or Decimal, passed type {}",
getName(), value_type->getName()}; name_, value_type->getName()};
WhichDataType value_type_to_check(value_type_without_nullable); WhichDataType value_type_to_check(value_type_without_nullable);

View File

@ -46,7 +46,7 @@ private:
Float64 confidence_level; Float64 confidence_level;
public: public:
AggregateFunctionTTest(const DataTypes & arguments, const Array & params) AggregateFunctionTTest(const DataTypes & arguments, const Array & params)
: IAggregateFunctionDataHelper<Data, AggregateFunctionTTest<Data>>({arguments}, params) : IAggregateFunctionDataHelper<Data, AggregateFunctionTTest<Data>>({arguments}, params, createResultType(!params.empty()))
{ {
if (!params.empty()) if (!params.empty())
{ {
@ -71,9 +71,9 @@ public:
return Data::name; return Data::name;
} }
DataTypePtr getReturnType() const override static DataTypePtr createResultType(bool need_confidence_interval_)
{ {
if (need_confidence_interval) if (need_confidence_interval_)
{ {
DataTypes types DataTypes types
{ {

View File

@ -31,15 +31,33 @@ namespace
template <bool is_weighted> template <bool is_weighted>
class AggregateFunctionTopKDate : public AggregateFunctionTopK<DataTypeDate::FieldType, is_weighted> class AggregateFunctionTopKDate : public AggregateFunctionTopK<DataTypeDate::FieldType, is_weighted>
{ {
public:
using AggregateFunctionTopK<DataTypeDate::FieldType, is_weighted>::AggregateFunctionTopK; using AggregateFunctionTopK<DataTypeDate::FieldType, is_weighted>::AggregateFunctionTopK;
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeArray>(std::make_shared<DataTypeDate>()); }
AggregateFunctionTopKDate(UInt64 threshold_, UInt64 load_factor, const DataTypes & argument_types_, const Array & params)
: AggregateFunctionTopK<DataTypeDate::FieldType, is_weighted>(
threshold_,
load_factor,
argument_types_,
params,
std::make_shared<DataTypeArray>(std::make_shared<DataTypeDate>()))
{}
}; };
template <bool is_weighted> template <bool is_weighted>
class AggregateFunctionTopKDateTime : public AggregateFunctionTopK<DataTypeDateTime::FieldType, is_weighted> class AggregateFunctionTopKDateTime : public AggregateFunctionTopK<DataTypeDateTime::FieldType, is_weighted>
{ {
public:
using AggregateFunctionTopK<DataTypeDateTime::FieldType, is_weighted>::AggregateFunctionTopK; using AggregateFunctionTopK<DataTypeDateTime::FieldType, is_weighted>::AggregateFunctionTopK;
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeArray>(std::make_shared<DataTypeDateTime>()); }
AggregateFunctionTopKDateTime(UInt64 threshold_, UInt64 load_factor, const DataTypes & argument_types_, const Array & params)
: AggregateFunctionTopK<DataTypeDateTime::FieldType, is_weighted>(
threshold_,
load_factor,
argument_types_,
params,
std::make_shared<DataTypeArray>(std::make_shared<DataTypeDateTime>()))
{}
}; };

View File

@ -11,6 +11,7 @@
#include <Common/SpaceSaving.h> #include <Common/SpaceSaving.h>
#include <Common/assert_cast.h> #include <Common/assert_cast.h>
#include "DataTypes/Serializations/ISerialization.h"
#include <AggregateFunctions/IAggregateFunction.h> #include <AggregateFunctions/IAggregateFunction.h>
@ -40,14 +41,20 @@ protected:
public: public:
AggregateFunctionTopK(UInt64 threshold_, UInt64 load_factor, const DataTypes & argument_types_, const Array & params) AggregateFunctionTopK(UInt64 threshold_, UInt64 load_factor, const DataTypes & argument_types_, const Array & params)
: IAggregateFunctionDataHelper<AggregateFunctionTopKData<T>, AggregateFunctionTopK<T, is_weighted>>(argument_types_, params) : IAggregateFunctionDataHelper<AggregateFunctionTopKData<T>, AggregateFunctionTopK<T, is_weighted>>(argument_types_, params, createResultType(argument_types_))
, threshold(threshold_), reserved(load_factor * threshold) {} , threshold(threshold_), reserved(load_factor * threshold)
{}
AggregateFunctionTopK(UInt64 threshold_, UInt64 load_factor, const DataTypes & argument_types_, const Array & params, const DataTypePtr & result_type_)
: IAggregateFunctionDataHelper<AggregateFunctionTopKData<T>, AggregateFunctionTopK<T, is_weighted>>(argument_types_, params, result_type_)
, threshold(threshold_), reserved(load_factor * threshold)
{}
String getName() const override { return is_weighted ? "topKWeighted" : "topK"; } String getName() const override { return is_weighted ? "topKWeighted" : "topK"; }
DataTypePtr getReturnType() const override static DataTypePtr createResultType(const DataTypes & argument_types_)
{ {
return std::make_shared<DataTypeArray>(this->argument_types[0]); return std::make_shared<DataTypeArray>(argument_types_[0]);
} }
bool allocatesMemoryInArena() const override { return false; } bool allocatesMemoryInArena() const override { return false; }
@ -126,21 +133,20 @@ private:
UInt64 threshold; UInt64 threshold;
UInt64 reserved; UInt64 reserved;
DataTypePtr & input_data_type;
static void deserializeAndInsert(StringRef str, IColumn & data_to); static void deserializeAndInsert(StringRef str, IColumn & data_to);
public: public:
AggregateFunctionTopKGeneric( AggregateFunctionTopKGeneric(
UInt64 threshold_, UInt64 load_factor, const DataTypes & argument_types_, const Array & params) UInt64 threshold_, UInt64 load_factor, const DataTypes & argument_types_, const Array & params)
: IAggregateFunctionDataHelper<AggregateFunctionTopKGenericData, AggregateFunctionTopKGeneric<is_plain_column, is_weighted>>(argument_types_, params) : IAggregateFunctionDataHelper<AggregateFunctionTopKGenericData, AggregateFunctionTopKGeneric<is_plain_column, is_weighted>>(argument_types_, params, createResultType(argument_types_))
, threshold(threshold_), reserved(load_factor * threshold), input_data_type(this->argument_types[0]) {} , threshold(threshold_), reserved(load_factor * threshold) {}
String getName() const override { return is_weighted ? "topKWeighted" : "topK"; } String getName() const override { return is_weighted ? "topKWeighted" : "topK"; }
DataTypePtr getReturnType() const override static DataTypePtr createResultType(const DataTypes & argument_types_)
{ {
return std::make_shared<DataTypeArray>(input_data_type); return std::make_shared<DataTypeArray>(argument_types_[0]);
} }
bool allocatesMemoryInArena() const override bool allocatesMemoryInArena() const override

View File

@ -358,17 +358,12 @@ private:
public: public:
explicit AggregateFunctionUniq(const DataTypes & argument_types_) explicit AggregateFunctionUniq(const DataTypes & argument_types_)
: IAggregateFunctionDataHelper<Data, AggregateFunctionUniq<T, Data>>(argument_types_, {}) : IAggregateFunctionDataHelper<Data, AggregateFunctionUniq<T, Data>>(argument_types_, {}, std::make_shared<DataTypeUInt64>())
{ {
} }
String getName() const override { return Data::getName(); } String getName() const override { return Data::getName(); }
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeUInt64>();
}
bool allocatesMemoryInArena() const override { return false; } bool allocatesMemoryInArena() const override { return false; }
/// ALWAYS_INLINE is required to have better code layout for uniqHLL12 function /// ALWAYS_INLINE is required to have better code layout for uniqHLL12 function
@ -462,7 +457,7 @@ private:
public: public:
explicit AggregateFunctionUniqVariadic(const DataTypes & arguments) explicit AggregateFunctionUniqVariadic(const DataTypes & arguments)
: IAggregateFunctionDataHelper<Data, AggregateFunctionUniqVariadic<Data>>(arguments, {}) : IAggregateFunctionDataHelper<Data, AggregateFunctionUniqVariadic<Data>>(arguments, {}, std::make_shared<DataTypeUInt64>())
{ {
if (argument_is_tuple) if (argument_is_tuple)
num_args = typeid_cast<const DataTypeTuple &>(*arguments[0]).getElements().size(); num_args = typeid_cast<const DataTypeTuple &>(*arguments[0]).getElements().size();
@ -472,11 +467,6 @@ public:
String getName() const override { return Data::getName(); } String getName() const override { return Data::getName(); }
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeUInt64>();
}
bool allocatesMemoryInArena() const override { return false; } bool allocatesMemoryInArena() const override { return false; }
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override

View File

@ -126,7 +126,8 @@ class AggregateFunctionUniqCombined final
{ {
public: public:
AggregateFunctionUniqCombined(const DataTypes & argument_types_, const Array & params_) AggregateFunctionUniqCombined(const DataTypes & argument_types_, const Array & params_)
: IAggregateFunctionDataHelper<AggregateFunctionUniqCombinedData<T, K, HashValueType>, AggregateFunctionUniqCombined<T, K, HashValueType>>(argument_types_, params_) {} : IAggregateFunctionDataHelper<AggregateFunctionUniqCombinedData<T, K, HashValueType>, AggregateFunctionUniqCombined<T, K, HashValueType>>(argument_types_, params_, std::make_shared<DataTypeUInt64>())
{}
String getName() const override String getName() const override
{ {
@ -136,11 +137,6 @@ public:
return "uniqCombined"; return "uniqCombined";
} }
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeUInt64>();
}
bool allocatesMemoryInArena() const override { return false; } bool allocatesMemoryInArena() const override { return false; }
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
@ -192,7 +188,7 @@ private:
public: public:
explicit AggregateFunctionUniqCombinedVariadic(const DataTypes & arguments, const Array & params) explicit AggregateFunctionUniqCombinedVariadic(const DataTypes & arguments, const Array & params)
: IAggregateFunctionDataHelper<AggregateFunctionUniqCombinedData<UInt64, K, HashValueType>, : IAggregateFunctionDataHelper<AggregateFunctionUniqCombinedData<UInt64, K, HashValueType>,
AggregateFunctionUniqCombinedVariadic<is_exact, argument_is_tuple, K, HashValueType>>(arguments, params) AggregateFunctionUniqCombinedVariadic<is_exact, argument_is_tuple, K, HashValueType>>(arguments, params, std::make_shared<DataTypeUInt64>())
{ {
if (argument_is_tuple) if (argument_is_tuple)
num_args = typeid_cast<const DataTypeTuple &>(*arguments[0]).getElements().size(); num_args = typeid_cast<const DataTypeTuple &>(*arguments[0]).getElements().size();
@ -208,11 +204,6 @@ public:
return "uniqCombined"; return "uniqCombined";
} }
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeUInt64>();
}
bool allocatesMemoryInArena() const override { return false; } bool allocatesMemoryInArena() const override { return false; }
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override

View File

@ -174,7 +174,7 @@ private:
public: public:
AggregateFunctionUniqUpTo(UInt8 threshold_, const DataTypes & argument_types_, const Array & params_) AggregateFunctionUniqUpTo(UInt8 threshold_, const DataTypes & argument_types_, const Array & params_)
: IAggregateFunctionDataHelper<AggregateFunctionUniqUpToData<T>, AggregateFunctionUniqUpTo<T>>(argument_types_, params_) : IAggregateFunctionDataHelper<AggregateFunctionUniqUpToData<T>, AggregateFunctionUniqUpTo<T>>(argument_types_, params_, std::make_shared<DataTypeUInt64>())
, threshold(threshold_) , threshold(threshold_)
{ {
} }
@ -186,11 +186,6 @@ public:
String getName() const override { return "uniqUpTo"; } String getName() const override { return "uniqUpTo"; }
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeUInt64>();
}
bool allocatesMemoryInArena() const override { return false; } bool allocatesMemoryInArena() const override { return false; }
/// ALWAYS_INLINE is required to have better code layout for uniqUpTo function /// ALWAYS_INLINE is required to have better code layout for uniqUpTo function
@ -235,7 +230,7 @@ private:
public: public:
AggregateFunctionUniqUpToVariadic(const DataTypes & arguments, const Array & params, UInt8 threshold_) AggregateFunctionUniqUpToVariadic(const DataTypes & arguments, const Array & params, UInt8 threshold_)
: IAggregateFunctionDataHelper<AggregateFunctionUniqUpToData<UInt64>, AggregateFunctionUniqUpToVariadic<is_exact, argument_is_tuple>>(arguments, params) : IAggregateFunctionDataHelper<AggregateFunctionUniqUpToData<UInt64>, AggregateFunctionUniqUpToVariadic<is_exact, argument_is_tuple>>(arguments, params, std::make_shared<DataTypeUInt64>())
, threshold(threshold_) , threshold(threshold_)
{ {
if (argument_is_tuple) if (argument_is_tuple)
@ -251,11 +246,6 @@ public:
String getName() const override { return "uniqUpTo"; } String getName() const override { return "uniqUpTo"; }
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeUInt64>();
}
bool allocatesMemoryInArena() const override { return false; } bool allocatesMemoryInArena() const override { return false; }
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override

View File

@ -221,7 +221,7 @@ public:
} }
AggregateFunctionWindowFunnel(const DataTypes & arguments, const Array & params) AggregateFunctionWindowFunnel(const DataTypes & arguments, const Array & params)
: IAggregateFunctionDataHelper<Data, AggregateFunctionWindowFunnel<T, Data>>(arguments, params) : IAggregateFunctionDataHelper<Data, AggregateFunctionWindowFunnel<T, Data>>(arguments, params, std::make_shared<DataTypeUInt8>())
{ {
events_size = arguments.size() - 1; events_size = arguments.size() - 1;
window = params.at(0).safeGet<UInt64>(); window = params.at(0).safeGet<UInt64>();
@ -245,11 +245,6 @@ public:
} }
} }
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeUInt8>();
}
bool allocatesMemoryInArena() const override { return false; } bool allocatesMemoryInArena() const override { return false; }
void add(AggregateDataPtr __restrict place, const IColumn ** columns, const size_t row_num, Arena *) const override void add(AggregateDataPtr __restrict place, const IColumn ** columns, const size_t row_num, Arena *) const override

View File

@ -118,7 +118,7 @@ class AggregateFunctionCrossTab : public IAggregateFunctionDataHelper<Data, Aggr
{ {
public: public:
explicit AggregateFunctionCrossTab(const DataTypes & arguments) explicit AggregateFunctionCrossTab(const DataTypes & arguments)
: IAggregateFunctionDataHelper<Data, AggregateFunctionCrossTab<Data>>({arguments}, {}) : IAggregateFunctionDataHelper<Data, AggregateFunctionCrossTab<Data>>({arguments}, {}, createResultType())
{ {
} }
@ -132,7 +132,7 @@ public:
return false; return false;
} }
DataTypePtr getReturnType() const override static DataTypePtr createResultType()
{ {
return std::make_shared<DataTypeNumber<Float64>>(); return std::make_shared<DataTypeNumber<Float64>>();
} }

View File

@ -10,6 +10,7 @@
#include <base/types.h> #include <base/types.h>
#include <Common/Exception.h> #include <Common/Exception.h>
#include <Common/ThreadPool.h> #include <Common/ThreadPool.h>
#include <Core/IResolvedFunction.h>
#include "config.h" #include "config.h"
@ -48,7 +49,9 @@ using AggregateDataPtr = char *;
using ConstAggregateDataPtr = const char *; using ConstAggregateDataPtr = const char *;
class IAggregateFunction; class IAggregateFunction;
using AggregateFunctionPtr = std::shared_ptr<const IAggregateFunction>; using AggregateFunctionPtr = std::shared_ptr<IAggregateFunction>;
using ConstAggregateFunctionPtr = std::shared_ptr<const IAggregateFunction>;
struct AggregateFunctionProperties; struct AggregateFunctionProperties;
/** Aggregate functions interface. /** Aggregate functions interface.
@ -59,18 +62,18 @@ struct AggregateFunctionProperties;
* (which can be created in some memory pool), * (which can be created in some memory pool),
* and IAggregateFunction is the external interface for manipulating them. * and IAggregateFunction is the external interface for manipulating them.
*/ */
class IAggregateFunction : public std::enable_shared_from_this<IAggregateFunction> class IAggregateFunction : public std::enable_shared_from_this<IAggregateFunction>, public IResolvedFunction
{ {
public: public:
IAggregateFunction(const DataTypes & argument_types_, const Array & parameters_) IAggregateFunction(const DataTypes & argument_types_, const Array & parameters_, const DataTypePtr & result_type_)
: argument_types(argument_types_), parameters(parameters_) {} : result_type(result_type_)
, argument_types(argument_types_)
, parameters(parameters_)
{}
/// Get main function name. /// Get main function name.
virtual String getName() const = 0; virtual String getName() const = 0;
/// Get the result type.
virtual DataTypePtr getReturnType() const = 0;
/// Get the data type of internal state. By default it is AggregateFunction(name(params), argument_types...). /// Get the data type of internal state. By default it is AggregateFunction(name(params), argument_types...).
virtual DataTypePtr getStateType() const; virtual DataTypePtr getStateType() const;
@ -102,7 +105,7 @@ public:
virtual size_t getDefaultVersion() const { return 0; } virtual size_t getDefaultVersion() const { return 0; }
virtual ~IAggregateFunction() = default; ~IAggregateFunction() override = default;
/** Data manipulating functions. */ /** Data manipulating functions. */
@ -348,8 +351,9 @@ public:
*/ */
virtual AggregateFunctionPtr getNestedFunction() const { return {}; } virtual AggregateFunctionPtr getNestedFunction() const { return {}; }
const DataTypes & getArgumentTypes() const { return argument_types; } const DataTypePtr & getResultType() const override { return result_type; }
const Array & getParameters() const { return parameters; } const DataTypes & getArgumentTypes() const override { return argument_types; }
const Array & getParameters() const override { return parameters; }
// Any aggregate function can be calculated over a window, but there are some // Any aggregate function can be calculated over a window, but there are some
// window functions such as rank() that require a different interface, e.g. // window functions such as rank() that require a different interface, e.g.
@ -398,6 +402,7 @@ public:
#endif #endif
protected: protected:
DataTypePtr result_type;
DataTypes argument_types; DataTypes argument_types;
Array parameters; Array parameters;
}; };
@ -414,8 +419,8 @@ private:
} }
public: public:
IAggregateFunctionHelper(const DataTypes & argument_types_, const Array & parameters_) IAggregateFunctionHelper(const DataTypes & argument_types_, const Array & parameters_, const DataTypePtr & result_type_)
: IAggregateFunction(argument_types_, parameters_) {} : IAggregateFunction(argument_types_, parameters_, result_type_) {}
AddFunc getAddressOfAddFunction() const override { return &addFree; } AddFunc getAddressOfAddFunction() const override { return &addFree; }
@ -695,15 +700,15 @@ public:
// Derived class can `override` this to flag that DateTime64 is not supported. // Derived class can `override` this to flag that DateTime64 is not supported.
static constexpr bool DateTime64Supported = true; static constexpr bool DateTime64Supported = true;
IAggregateFunctionDataHelper(const DataTypes & argument_types_, const Array & parameters_) IAggregateFunctionDataHelper(const DataTypes & argument_types_, const Array & parameters_, const DataTypePtr & result_type_)
: IAggregateFunctionHelper<Derived>(argument_types_, parameters_) : IAggregateFunctionHelper<Derived>(argument_types_, parameters_, result_type_)
{ {
/// To prevent derived classes changing the destroy() without updating hasTrivialDestructor() to match it /// To prevent derived classes changing the destroy() without updating hasTrivialDestructor() to match it
/// Enforce that either both of them are changed or none are /// Enforce that either both of them are changed or none are
constexpr bool declares_destroy_and_hasTrivialDestructor = constexpr bool declares_destroy_and_has_trivial_destructor =
std::is_same_v<decltype(&IAggregateFunctionDataHelper::destroy), decltype(&Derived::destroy)> == std::is_same_v<decltype(&IAggregateFunctionDataHelper::destroy), decltype(&Derived::destroy)> ==
std::is_same_v<decltype(&IAggregateFunctionDataHelper::hasTrivialDestructor), decltype(&Derived::hasTrivialDestructor)>; std::is_same_v<decltype(&IAggregateFunctionDataHelper::hasTrivialDestructor), decltype(&Derived::hasTrivialDestructor)>;
static_assert(declares_destroy_and_hasTrivialDestructor, static_assert(declares_destroy_and_has_trivial_destructor,
"destroy() and hasTrivialDestructor() methods of an aggregate function must be either both overridden or not"); "destroy() and hasTrivialDestructor() methods of an aggregate function must be either both overridden or not");
} }

View File

@ -2,6 +2,7 @@
#include <Common/SipHash.h> #include <Common/SipHash.h>
#include <Common/FieldVisitorToString.h> #include <Common/FieldVisitorToString.h>
#include "Core/ColumnsWithTypeAndName.h"
#include <IO/WriteBufferFromString.h> #include <IO/WriteBufferFromString.h>
#include <IO/Operators.h> #include <IO/Operators.h>
@ -25,25 +26,54 @@ FunctionNode::FunctionNode(String function_name_)
children[arguments_child_index] = std::make_shared<ListNode>(); children[arguments_child_index] = std::make_shared<ListNode>();
} }
void FunctionNode::resolveAsFunction(FunctionOverloadResolverPtr function_value, DataTypePtr result_type_value) ColumnsWithTypeAndName FunctionNode::getArgumentTypes() const
{ {
aggregate_function = nullptr; ColumnsWithTypeAndName argument_types;
for (const auto & arg : getArguments().getNodes())
{
ColumnWithTypeAndName argument;
argument.type = arg->getResultType();
argument_types.push_back(argument);
}
return argument_types;
}
FunctionBasePtr FunctionNode::getFunction() const
{
return std::dynamic_pointer_cast<IFunctionBase>(function);
}
AggregateFunctionPtr FunctionNode::getAggregateFunction() const
{
return std::dynamic_pointer_cast<IAggregateFunction>(function);
}
bool FunctionNode::isAggregateFunction() const
{
return typeid_cast<AggregateFunctionPtr>(function) != nullptr && !isWindowFunction();
}
bool FunctionNode::isOrdinaryFunction() const
{
return typeid_cast<FunctionBasePtr>(function) != nullptr;
}
void FunctionNode::resolveAsFunction(FunctionBasePtr function_value)
{
function_name = function_value->getName();
function = std::move(function_value); function = std::move(function_value);
result_type = std::move(result_type_value);
function_name = function->getName();
} }
void FunctionNode::resolveAsAggregateFunction(AggregateFunctionPtr aggregate_function_value, DataTypePtr result_type_value) void FunctionNode::resolveAsAggregateFunction(AggregateFunctionPtr aggregate_function_value)
{ {
function = nullptr; function_name = aggregate_function_value->getName();
aggregate_function = std::move(aggregate_function_value); function = std::move(aggregate_function_value);
result_type = std::move(result_type_value);
function_name = aggregate_function->getName();
} }
void FunctionNode::resolveAsWindowFunction(AggregateFunctionPtr window_function_value, DataTypePtr result_type_value) void FunctionNode::resolveAsWindowFunction(AggregateFunctionPtr window_function_value)
{ {
resolveAsAggregateFunction(window_function_value, result_type_value); resolveAsAggregateFunction(window_function_value);
} }
void FunctionNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const void FunctionNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const
@ -63,8 +93,8 @@ void FunctionNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state
buffer << ", function_type: " << function_type; buffer << ", function_type: " << function_type;
if (result_type) if (function)
buffer << ", result_type: " + result_type->getName(); buffer << ", result_type: " + function->getResultType()->getName();
const auto & parameters = getParameters(); const auto & parameters = getParameters();
if (!parameters.getNodes().empty()) if (!parameters.getNodes().empty())
@ -95,12 +125,14 @@ bool FunctionNode::isEqualImpl(const IQueryTreeNode & rhs) const
isOrdinaryFunction() != rhs_typed.isOrdinaryFunction() || isOrdinaryFunction() != rhs_typed.isOrdinaryFunction() ||
isWindowFunction() != rhs_typed.isWindowFunction()) isWindowFunction() != rhs_typed.isWindowFunction())
return false; return false;
auto lhs_result_type = getResultType();
auto rhs_result_type = rhs.getResultType();
if (result_type && rhs_typed.result_type && !result_type->equals(*rhs_typed.getResultType())) if (lhs_result_type && rhs_result_type && !lhs_result_type->equals(*rhs_result_type))
return false; return false;
else if (result_type && !rhs_typed.result_type) else if (lhs_result_type && !rhs_result_type)
return false; return false;
else if (!result_type && rhs_typed.result_type) else if (!lhs_result_type && rhs_result_type)
return false; return false;
return true; return true;
@ -114,7 +146,7 @@ void FunctionNode::updateTreeHashImpl(HashState & hash_state) const
hash_state.update(isAggregateFunction()); hash_state.update(isAggregateFunction());
hash_state.update(isWindowFunction()); hash_state.update(isWindowFunction());
if (result_type) if (auto result_type = getResultType())
{ {
auto result_type_name = result_type->getName(); auto result_type_name = result_type->getName();
hash_state.update(result_type_name.size()); hash_state.update(result_type_name.size());
@ -130,8 +162,6 @@ QueryTreeNodePtr FunctionNode::cloneImpl() const
* because ordinary functions or aggregate functions must be stateless. * because ordinary functions or aggregate functions must be stateless.
*/ */
result_function->function = function; result_function->function = function;
result_function->aggregate_function = aggregate_function;
result_function->result_type = result_type;
return result_function; return result_function;
} }

View File

@ -1,8 +1,12 @@
#pragma once #pragma once
#include <memory>
#include <Core/IResolvedFunction.h>
#include <Analyzer/IQueryTreeNode.h> #include <Analyzer/IQueryTreeNode.h>
#include <Analyzer/ListNode.h> #include <Analyzer/ListNode.h>
#include <Analyzer/ConstantValue.h> #include <Analyzer/ConstantValue.h>
#include <Common/typeid_cast.h>
#include "Core/ColumnsWithTypeAndName.h"
namespace DB namespace DB
{ {
@ -10,8 +14,11 @@ namespace DB
class IFunctionOverloadResolver; class IFunctionOverloadResolver;
using FunctionOverloadResolverPtr = std::shared_ptr<IFunctionOverloadResolver>; using FunctionOverloadResolverPtr = std::shared_ptr<IFunctionOverloadResolver>;
class IFunctionBase;
using FunctionBasePtr = std::shared_ptr<IFunctionBase>;
class IAggregateFunction; class IAggregateFunction;
using AggregateFunctionPtr = std::shared_ptr<const IAggregateFunction>; using AggregateFunctionPtr = std::shared_ptr<IAggregateFunction>;
/** Function node represents function in query tree. /** Function node represents function in query tree.
* Function syntax: function_name(parameter_1, ...)(argument_1, ...). * Function syntax: function_name(parameter_1, ...)(argument_1, ...).
@ -96,6 +103,8 @@ public:
return children[arguments_child_index]; return children[arguments_child_index];
} }
ColumnsWithTypeAndName getArgumentTypes() const;
/// Returns true if function node has window, false otherwise /// Returns true if function node has window, false otherwise
bool hasWindow() const bool hasWindow() const
{ {
@ -124,24 +133,18 @@ public:
/** Get non aggregate function. /** Get non aggregate function.
* If function is not resolved nullptr returned. * If function is not resolved nullptr returned.
*/ */
const FunctionOverloadResolverPtr & getFunction() const FunctionBasePtr getFunction() const;
{
return function;
}
/** Get aggregate function. /** Get aggregate function.
* If function is not resolved nullptr returned. * If function is not resolved nullptr returned.
* If function is resolved as non aggregate function nullptr returned. * If function is resolved as non aggregate function nullptr returned.
*/ */
const AggregateFunctionPtr & getAggregateFunction() const AggregateFunctionPtr getAggregateFunction() const;
{
return aggregate_function;
}
/// Is function node resolved /// Is function node resolved
bool isResolved() const bool isResolved() const
{ {
return result_type != nullptr && (function != nullptr || aggregate_function != nullptr); return function != nullptr;
} }
/// Is function node window function /// Is function node window function
@ -151,16 +154,10 @@ public:
} }
/// Is function node aggregate function /// Is function node aggregate function
bool isAggregateFunction() const bool isAggregateFunction() const;
{
return aggregate_function != nullptr && !isWindowFunction();
}
/// Is function node ordinary function /// Is function node ordinary function
bool isOrdinaryFunction() const bool isOrdinaryFunction() const;
{
return function != nullptr;
}
/** Resolve function node as non aggregate function. /** Resolve function node as non aggregate function.
* It is important that function name is updated with resolved function name. * It is important that function name is updated with resolved function name.
@ -168,19 +165,19 @@ public:
* Assume we have `multiIf` function with single condition, it can be converted to `if` function. * Assume we have `multiIf` function with single condition, it can be converted to `if` function.
* Function name must be updated accordingly. * Function name must be updated accordingly.
*/ */
void resolveAsFunction(FunctionOverloadResolverPtr function_value, DataTypePtr result_type_value); void resolveAsFunction(FunctionBasePtr function_value);
/** Resolve function node as aggregate function. /** Resolve function node as aggregate function.
* It is important that function name is updated with resolved function name. * It is important that function name is updated with resolved function name.
* Main motivation for this is query tree optimizations. * Main motivation for this is query tree optimizations.
*/ */
void resolveAsAggregateFunction(AggregateFunctionPtr aggregate_function_value, DataTypePtr result_type_value); void resolveAsAggregateFunction(AggregateFunctionPtr aggregate_function_value);
/** Resolve function node as window function. /** Resolve function node as window function.
* It is important that function name is updated with resolved function name. * It is important that function name is updated with resolved function name.
* Main motivation for this is query tree optimizations. * Main motivation for this is query tree optimizations.
*/ */
void resolveAsWindowFunction(AggregateFunctionPtr window_function_value, DataTypePtr result_type_value); void resolveAsWindowFunction(AggregateFunctionPtr window_function_value);
QueryTreeNodeType getNodeType() const override QueryTreeNodeType getNodeType() const override
{ {
@ -189,7 +186,7 @@ public:
DataTypePtr getResultType() const override DataTypePtr getResultType() const override
{ {
return result_type; return function->getResultType();
} }
void dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const override; void dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const override;
@ -205,9 +202,7 @@ protected:
private: private:
String function_name; String function_name;
FunctionOverloadResolverPtr function; IResolvedFunctionPtr function;
AggregateFunctionPtr aggregate_function;
DataTypePtr result_type;
static constexpr size_t parameters_child_index = 0; static constexpr size_t parameters_child_index = 0;
static constexpr size_t arguments_child_index = 1; static constexpr size_t arguments_child_index = 1;

View File

@ -147,7 +147,6 @@ public:
private: private:
static inline void resolveAggregateFunctionNode(FunctionNode & function_node, const String & aggregate_function_name) static inline void resolveAggregateFunctionNode(FunctionNode & function_node, const String & aggregate_function_name)
{ {
auto function_result_type = function_node.getResultType();
auto function_aggregate_function = function_node.getAggregateFunction(); auto function_aggregate_function = function_node.getAggregateFunction();
AggregateFunctionProperties properties; AggregateFunctionProperties properties;
@ -156,7 +155,7 @@ private:
function_aggregate_function->getParameters(), function_aggregate_function->getParameters(),
properties); properties);
function_node.resolveAsAggregateFunction(std::move(aggregate_function), std::move(function_result_type)); function_node.resolveAsAggregateFunction(std::move(aggregate_function));
} }
}; };

View File

@ -69,7 +69,7 @@ public:
auto result_type = function_node->getResultType(); auto result_type = function_node->getResultType();
AggregateFunctionProperties properties; AggregateFunctionProperties properties;
auto aggregate_function = AggregateFunctionFactory::instance().get("count", {}, {}, properties); auto aggregate_function = AggregateFunctionFactory::instance().get("count", {}, {}, properties);
function_node->resolveAsAggregateFunction(std::move(aggregate_function), std::move(result_type)); function_node->resolveAsAggregateFunction(std::move(aggregate_function));
function_node->getArguments().getNodes().clear(); function_node->getArguments().getNodes().clear();
} }
}; };

View File

@ -9,6 +9,7 @@
#include <Analyzer/InDepthQueryTreeVisitor.h> #include <Analyzer/InDepthQueryTreeVisitor.h>
#include <Analyzer/FunctionNode.h> #include <Analyzer/FunctionNode.h>
#include "Core/ColumnWithTypeAndName.h"
namespace DB namespace DB
{ {
@ -138,7 +139,6 @@ public:
static inline void resolveAggregateOrWindowFunctionNode(FunctionNode & function_node, const String & aggregate_function_name) static inline void resolveAggregateOrWindowFunctionNode(FunctionNode & function_node, const String & aggregate_function_name)
{ {
auto function_result_type = function_node.getResultType();
auto function_aggregate_function = function_node.getAggregateFunction(); auto function_aggregate_function = function_node.getAggregateFunction();
AggregateFunctionProperties properties; AggregateFunctionProperties properties;
@ -148,16 +148,15 @@ public:
properties); properties);
if (function_node.isAggregateFunction()) if (function_node.isAggregateFunction())
function_node.resolveAsAggregateFunction(std::move(aggregate_function), std::move(function_result_type)); function_node.resolveAsAggregateFunction(std::move(aggregate_function));
else if (function_node.isWindowFunction()) else if (function_node.isWindowFunction())
function_node.resolveAsWindowFunction(std::move(aggregate_function), std::move(function_result_type)); function_node.resolveAsWindowFunction(std::move(aggregate_function));
} }
inline void resolveOrdinaryFunctionNode(FunctionNode & function_node, const String & function_name) const inline void resolveOrdinaryFunctionNode(FunctionNode & function_node, const String & function_name) const
{ {
auto function_result_type = function_node.getResultType();
auto function = FunctionFactory::instance().get(function_name, context); auto function = FunctionFactory::instance().get(function_name, context);
function_node.resolveAsFunction(function, std::move(function_result_type)); function_node.resolveAsFunction(function->build(function_node.getArgumentTypes()));
} }
private: private:

View File

@ -78,11 +78,11 @@ public:
column.name += ".size0"; column.name += ".size0";
column.type = std::make_shared<DataTypeUInt64>(); column.type = std::make_shared<DataTypeUInt64>();
resolveOrdinaryFunctionNode(*function_node, "equals");
function_arguments_nodes.clear(); function_arguments_nodes.clear();
function_arguments_nodes.push_back(std::make_shared<ColumnNode>(column, column_source)); function_arguments_nodes.push_back(std::make_shared<ColumnNode>(column, column_source));
function_arguments_nodes.push_back(std::make_shared<ConstantNode>(static_cast<UInt64>(0))); function_arguments_nodes.push_back(std::make_shared<ConstantNode>(static_cast<UInt64>(0)));
resolveOrdinaryFunctionNode(*function_node, "equals");
} }
else if (function_name == "notEmpty") else if (function_name == "notEmpty")
{ {
@ -90,11 +90,11 @@ public:
column.name += ".size0"; column.name += ".size0";
column.type = std::make_shared<DataTypeUInt64>(); column.type = std::make_shared<DataTypeUInt64>();
resolveOrdinaryFunctionNode(*function_node, "notEquals");
function_arguments_nodes.clear(); function_arguments_nodes.clear();
function_arguments_nodes.push_back(std::make_shared<ColumnNode>(column, column_source)); function_arguments_nodes.push_back(std::make_shared<ColumnNode>(column, column_source));
function_arguments_nodes.push_back(std::make_shared<ConstantNode>(static_cast<UInt64>(0))); function_arguments_nodes.push_back(std::make_shared<ConstantNode>(static_cast<UInt64>(0)));
resolveOrdinaryFunctionNode(*function_node, "notEquals");
} }
} }
else if (column_type.isNullable()) else if (column_type.isNullable())
@ -112,9 +112,9 @@ public:
column.name += ".null"; column.name += ".null";
column.type = std::make_shared<DataTypeUInt8>(); column.type = std::make_shared<DataTypeUInt8>();
resolveOrdinaryFunctionNode(*function_node, "not");
function_arguments_nodes = {std::make_shared<ColumnNode>(column, column_source)}; function_arguments_nodes = {std::make_shared<ColumnNode>(column, column_source)};
resolveOrdinaryFunctionNode(*function_node, "not");
} }
} }
else if (column_type.isMap()) else if (column_type.isMap())
@ -182,9 +182,9 @@ public:
column.type = data_type_map.getKeyType(); column.type = data_type_map.getKeyType();
auto has_function_argument = std::make_shared<ColumnNode>(column, column_source); auto has_function_argument = std::make_shared<ColumnNode>(column, column_source);
resolveOrdinaryFunctionNode(*function_node, "has");
function_arguments_nodes[0] = std::move(has_function_argument); function_arguments_nodes[0] = std::move(has_function_argument);
resolveOrdinaryFunctionNode(*function_node, "has");
} }
} }
} }
@ -192,9 +192,8 @@ public:
private: private:
inline void resolveOrdinaryFunctionNode(FunctionNode & function_node, const String & function_name) const inline void resolveOrdinaryFunctionNode(FunctionNode & function_node, const String & function_name) const
{ {
auto function_result_type = function_node.getResultType();
auto function = FunctionFactory::instance().get(function_name, context); auto function = FunctionFactory::instance().get(function_name, context);
function_node.resolveAsFunction(function, std::move(function_result_type)); function_node.resolveAsFunction(function->build(function_node.getArgumentTypes()));
} }
ContextPtr & context; ContextPtr & context;

View File

@ -59,14 +59,13 @@ private:
std::unordered_set<String> names_to_collect; std::unordered_set<String> names_to_collect;
}; };
QueryTreeNodePtr createResolvedFunction(const ContextPtr & context, const String & name, const DataTypePtr & result_type, QueryTreeNodes arguments) QueryTreeNodePtr createResolvedFunction(const ContextPtr & context, const String & name, QueryTreeNodes arguments)
{ {
auto function_node = std::make_shared<FunctionNode>(name); auto function_node = std::make_shared<FunctionNode>(name);
auto function = FunctionFactory::instance().get(name, context); auto function = FunctionFactory::instance().get(name, context);
function_node->resolveAsFunction(std::move(function), result_type);
function_node->getArguments().getNodes() = std::move(arguments); function_node->getArguments().getNodes() = std::move(arguments);
function_node->resolveAsFunction(function->build(function_node->getArgumentTypes()));
return function_node; return function_node;
} }
@ -74,11 +73,6 @@ FunctionNodePtr createResolvedAggregateFunction(const String & name, const Query
{ {
auto function_node = std::make_shared<FunctionNode>(name); auto function_node = std::make_shared<FunctionNode>(name);
AggregateFunctionProperties properties;
auto aggregate_function = AggregateFunctionFactory::instance().get(name, {argument->getResultType()}, parameters, properties);
function_node->resolveAsAggregateFunction(aggregate_function, aggregate_function->getReturnType());
function_node->getArguments().getNodes() = { argument };
if (!parameters.empty()) if (!parameters.empty())
{ {
QueryTreeNodes parameter_nodes; QueryTreeNodes parameter_nodes;
@ -86,18 +80,27 @@ FunctionNodePtr createResolvedAggregateFunction(const String & name, const Query
parameter_nodes.emplace_back(std::make_shared<ConstantNode>(param)); parameter_nodes.emplace_back(std::make_shared<ConstantNode>(param));
function_node->getParameters().getNodes() = std::move(parameter_nodes); function_node->getParameters().getNodes() = std::move(parameter_nodes);
} }
function_node->getArguments().getNodes() = { argument };
AggregateFunctionProperties properties;
auto aggregate_function = AggregateFunctionFactory::instance().get(
name,
{ argument->getResultType() },
parameters,
properties);
function_node->resolveAsAggregateFunction(aggregate_function);
return function_node; return function_node;
} }
QueryTreeNodePtr createTupleElementFunction(const ContextPtr & context, const DataTypePtr & result_type, QueryTreeNodePtr argument, UInt64 index) QueryTreeNodePtr createTupleElementFunction(const ContextPtr & context, QueryTreeNodePtr argument, UInt64 index)
{ {
return createResolvedFunction(context, "tupleElement", result_type, {std::move(argument), std::make_shared<ConstantNode>(index)}); return createResolvedFunction(context, "tupleElement", {argument, std::make_shared<ConstantNode>(index)});
} }
QueryTreeNodePtr createArrayElementFunction(const ContextPtr & context, const DataTypePtr & result_type, QueryTreeNodePtr argument, UInt64 index) QueryTreeNodePtr createArrayElementFunction(const ContextPtr & context, QueryTreeNodePtr argument, UInt64 index)
{ {
return createResolvedFunction(context, "arrayElement", result_type, {std::move(argument), std::make_shared<ConstantNode>(index)}); return createResolvedFunction(context, "arrayElement", {argument, std::make_shared<ConstantNode>(index)});
} }
void replaceWithSumCount(QueryTreeNodePtr & node, const FunctionNodePtr & sum_count_node, ContextPtr context) void replaceWithSumCount(QueryTreeNodePtr & node, const FunctionNodePtr & sum_count_node, ContextPtr context)
@ -115,20 +118,20 @@ void replaceWithSumCount(QueryTreeNodePtr & node, const FunctionNodePtr & sum_co
if (function_name == "sum") if (function_name == "sum")
{ {
assert(node->getResultType()->equals(*sum_count_result_type->getElement(0))); assert(node->getResultType()->equals(*sum_count_result_type->getElement(0)));
node = createTupleElementFunction(context, node->getResultType(), sum_count_node, 1); node = createTupleElementFunction(context, sum_count_node, 1);
} }
else if (function_name == "count") else if (function_name == "count")
{ {
assert(node->getResultType()->equals(*sum_count_result_type->getElement(1))); assert(node->getResultType()->equals(*sum_count_result_type->getElement(1)));
node = createTupleElementFunction(context, node->getResultType(), sum_count_node, 2); node = createTupleElementFunction(context, sum_count_node, 2);
} }
else if (function_name == "avg") else if (function_name == "avg")
{ {
auto sum_result = createTupleElementFunction(context, sum_count_result_type->getElement(0), sum_count_node, 1); auto sum_result = createTupleElementFunction(context, sum_count_node, 1);
auto count_result = createTupleElementFunction(context, sum_count_result_type->getElement(1), sum_count_node, 2); auto count_result = createTupleElementFunction(context, sum_count_node, 2);
/// To avoid integer division by zero /// To avoid integer division by zero
auto count_float_result = createResolvedFunction(context, "toFloat64", std::make_shared<DataTypeFloat64>(), {count_result}); auto count_float_result = createResolvedFunction(context, "toFloat64", {count_result});
node = createResolvedFunction(context, "divide", node->getResultType(), {sum_result, count_float_result}); node = createResolvedFunction(context, "divide", {sum_result, count_float_result});
} }
else else
{ {
@ -238,7 +241,7 @@ void tryFuseQuantiles(QueryTreeNodePtr query_tree_node, ContextPtr context)
for (size_t i = 0; i < nodes_set.size(); ++i) for (size_t i = 0; i < nodes_set.size(); ++i)
{ {
size_t array_index = i + 1; size_t array_index = i + 1;
*nodes[i] = createArrayElementFunction(context, result_array_type->getNestedType(), quantiles_node, array_index); *nodes[i] = createArrayElementFunction(context, quantiles_node, array_index);
} }
} }
} }

View File

@ -55,8 +55,8 @@ public:
return; return;
auto multi_if_function = std::make_shared<FunctionNode>("multiIf"); auto multi_if_function = std::make_shared<FunctionNode>("multiIf");
multi_if_function->resolveAsFunction(multi_if_function_ptr, std::make_shared<DataTypeUInt8>());
multi_if_function->getArguments().getNodes() = std::move(multi_if_arguments); multi_if_function->getArguments().getNodes() = std::move(multi_if_arguments);
multi_if_function->resolveAsFunction(multi_if_function_ptr->build(multi_if_function->getArgumentTypes()));
node = std::move(multi_if_function); node = std::move(multi_if_function);
} }

View File

@ -27,7 +27,7 @@ public:
return; return;
auto result_type = function_node->getResultType(); auto result_type = function_node->getResultType();
function_node->resolveAsFunction(if_function_ptr, std::move(result_type)); function_node->resolveAsFunction(if_function_ptr->build(function_node->getArgumentTypes()));
} }
private: private:

View File

@ -48,12 +48,10 @@ public:
private: private:
static inline void resolveAsCountAggregateFunction(FunctionNode & function_node) static inline void resolveAsCountAggregateFunction(FunctionNode & function_node)
{ {
auto function_result_type = function_node.getResultType();
AggregateFunctionProperties properties; AggregateFunctionProperties properties;
auto aggregate_function = AggregateFunctionFactory::instance().get("count", {}, {}, properties); auto aggregate_function = AggregateFunctionFactory::instance().get("count", {}, {}, properties);
function_node.resolveAsAggregateFunction(std::move(aggregate_function), std::move(function_result_type)); function_node.resolveAsAggregateFunction(std::move(aggregate_function));
} }
}; };

View File

@ -4287,7 +4287,7 @@ ProjectionNames QueryAnalyzer::resolveFunction(QueryTreeNodePtr & node, Identifi
bool force_grouping_standard_compatibility = scope.context->getSettingsRef().force_grouping_standard_compatibility; bool force_grouping_standard_compatibility = scope.context->getSettingsRef().force_grouping_standard_compatibility;
auto grouping_function = std::make_shared<FunctionGrouping>(force_grouping_standard_compatibility); auto grouping_function = std::make_shared<FunctionGrouping>(force_grouping_standard_compatibility);
auto grouping_function_adaptor = std::make_shared<FunctionToOverloadResolverAdaptor>(std::move(grouping_function)); auto grouping_function_adaptor = std::make_shared<FunctionToOverloadResolverAdaptor>(std::move(grouping_function));
function_node.resolveAsFunction(std::move(grouping_function_adaptor), std::make_shared<DataTypeUInt64>()); function_node.resolveAsFunction(grouping_function_adaptor->build({}));
return result_projection_names; return result_projection_names;
} }
} }
@ -4307,7 +4307,7 @@ ProjectionNames QueryAnalyzer::resolveFunction(QueryTreeNodePtr & node, Identifi
AggregateFunctionProperties properties; AggregateFunctionProperties properties;
auto aggregate_function = AggregateFunctionFactory::instance().get(function_name, argument_types, parameters, properties); auto aggregate_function = AggregateFunctionFactory::instance().get(function_name, argument_types, parameters, properties);
function_node.resolveAsWindowFunction(aggregate_function, aggregate_function->getReturnType()); function_node.resolveAsWindowFunction(aggregate_function);
bool window_node_is_identifier = function_node.getWindowNode()->getNodeType() == QueryTreeNodeType::IDENTIFIER; bool window_node_is_identifier = function_node.getWindowNode()->getNodeType() == QueryTreeNodeType::IDENTIFIER;
ProjectionName window_projection_name = resolveWindow(function_node.getWindowNode(), scope); ProjectionName window_projection_name = resolveWindow(function_node.getWindowNode(), scope);
@ -4361,7 +4361,7 @@ ProjectionNames QueryAnalyzer::resolveFunction(QueryTreeNodePtr & node, Identifi
AggregateFunctionProperties properties; AggregateFunctionProperties properties;
auto aggregate_function = AggregateFunctionFactory::instance().get(function_name, argument_types, parameters, properties); auto aggregate_function = AggregateFunctionFactory::instance().get(function_name, argument_types, parameters, properties);
function_node.resolveAsAggregateFunction(aggregate_function, aggregate_function->getReturnType()); function_node.resolveAsAggregateFunction(aggregate_function);
return result_projection_names; return result_projection_names;
} }
@ -4538,6 +4538,8 @@ ProjectionNames QueryAnalyzer::resolveFunction(QueryTreeNodePtr & node, Identifi
constant_value = std::make_shared<ConstantValue>(std::move(column_constant_value), result_type); constant_value = std::make_shared<ConstantValue>(std::move(column_constant_value), result_type);
} }
} }
function_node.resolveAsFunction(std::move(function_base));
} }
catch (Exception & e) catch (Exception & e)
{ {
@ -4545,8 +4547,6 @@ ProjectionNames QueryAnalyzer::resolveFunction(QueryTreeNodePtr & node, Identifi
throw; throw;
} }
function_node.resolveAsFunction(std::move(function), std::move(result_type));
if (constant_value) if (constant_value)
node = std::make_shared<ConstantNode>(std::move(constant_value), node); node = std::make_shared<ConstantNode>(std::move(constant_value), node);

View File

@ -13,6 +13,7 @@
#include <Analyzer/InDepthQueryTreeVisitor.h> #include <Analyzer/InDepthQueryTreeVisitor.h>
#include <Analyzer/ConstantNode.h> #include <Analyzer/ConstantNode.h>
#include <Analyzer/FunctionNode.h> #include <Analyzer/FunctionNode.h>
#include "Core/ColumnsWithTypeAndName.h"
namespace DB namespace DB
{ {
@ -117,11 +118,12 @@ public:
not_function_result_type = makeNullable(not_function_result_type); not_function_result_type = makeNullable(not_function_result_type);
auto not_function = std::make_shared<FunctionNode>("not"); auto not_function = std::make_shared<FunctionNode>("not");
not_function->resolveAsFunction(FunctionFactory::instance().get("not", context), std::move(not_function_result_type));
auto & not_function_arguments = not_function->getArguments().getNodes(); auto & not_function_arguments = not_function->getArguments().getNodes();
not_function_arguments.push_back(std::move(nested_if_function_arguments_nodes[0])); not_function_arguments.push_back(std::move(nested_if_function_arguments_nodes[0]));
not_function->resolveAsFunction(FunctionFactory::instance().get("not", context)->build(not_function->getArgumentTypes()));
function_node_arguments_nodes[0] = std::move(not_function); function_node_arguments_nodes[0] = std::move(not_function);
function_node_arguments_nodes.resize(1); function_node_arguments_nodes.resize(1);
@ -139,8 +141,7 @@ private:
function_node.getAggregateFunction()->getParameters(), function_node.getAggregateFunction()->getParameters(),
properties); properties);
auto function_result_type = function_node.getResultType(); function_node.resolveAsAggregateFunction(std::move(aggregate_function));
function_node.resolveAsAggregateFunction(std::move(aggregate_function), std::move(function_result_type));
} }
ContextPtr & context; ContextPtr & context;

View File

@ -76,7 +76,7 @@ public:
properties); properties);
auto function_result_type = function_node->getResultType(); auto function_result_type = function_node->getResultType();
function_node->resolveAsAggregateFunction(std::move(aggregate_function), std::move(function_result_type)); function_node->resolveAsAggregateFunction(std::move(aggregate_function));
} }
}; };

View File

@ -28,7 +28,7 @@ namespace ErrorCodes
} }
static String getTypeString(const AggregateFunctionPtr & func, std::optional<size_t> version = std::nullopt) static String getTypeString(const ConstAggregateFunctionPtr & func, std::optional<size_t> version = std::nullopt)
{ {
WriteBufferFromOwnString stream; WriteBufferFromOwnString stream;
@ -62,18 +62,18 @@ static String getTypeString(const AggregateFunctionPtr & func, std::optional<siz
} }
ColumnAggregateFunction::ColumnAggregateFunction(const AggregateFunctionPtr & func_, std::optional<size_t> version_) ColumnAggregateFunction::ColumnAggregateFunction(const ConstAggregateFunctionPtr & func_, std::optional<size_t> version_)
: func(func_), type_string(getTypeString(func, version_)), version(version_) : func(func_), type_string(getTypeString(func, version_)), version(version_)
{ {
} }
ColumnAggregateFunction::ColumnAggregateFunction(const AggregateFunctionPtr & func_, const ConstArenas & arenas_) ColumnAggregateFunction::ColumnAggregateFunction(const ConstAggregateFunctionPtr & func_, const ConstArenas & arenas_)
: foreign_arenas(arenas_), func(func_), type_string(getTypeString(func)) : foreign_arenas(arenas_), func(func_), type_string(getTypeString(func))
{ {
} }
void ColumnAggregateFunction::set(const AggregateFunctionPtr & func_, size_t version_) void ColumnAggregateFunction::set(const ConstAggregateFunctionPtr & func_, size_t version_)
{ {
func = func_; func = func_;
version = version_; version = version_;
@ -146,7 +146,7 @@ MutableColumnPtr ColumnAggregateFunction::convertToValues(MutableColumnPtr colum
/// insertResultInto may invalidate states, so we must unshare ownership of them /// insertResultInto may invalidate states, so we must unshare ownership of them
column_aggregate_func.ensureOwnership(); column_aggregate_func.ensureOwnership();
MutableColumnPtr res = func->getReturnType()->createColumn(); MutableColumnPtr res = func->getResultType()->createColumn();
res->reserve(data.size()); res->reserve(data.size());
/// If there are references to states in final column, we must hold their ownership /// If there are references to states in final column, we must hold their ownership

View File

@ -70,7 +70,7 @@ private:
ArenaPtr my_arena; ArenaPtr my_arena;
/// Used for destroying states and for finalization of values. /// Used for destroying states and for finalization of values.
AggregateFunctionPtr func; ConstAggregateFunctionPtr func;
/// Source column. Used (holds source from destruction), /// Source column. Used (holds source from destruction),
/// if this column has been constructed from another and uses all or part of its values. /// if this column has been constructed from another and uses all or part of its values.
@ -92,9 +92,9 @@ private:
/// Create a new column that has another column as a source. /// Create a new column that has another column as a source.
MutablePtr createView() const; MutablePtr createView() const;
explicit ColumnAggregateFunction(const AggregateFunctionPtr & func_, std::optional<size_t> version_ = std::nullopt); explicit ColumnAggregateFunction(const ConstAggregateFunctionPtr & func_, std::optional<size_t> version_ = std::nullopt);
ColumnAggregateFunction(const AggregateFunctionPtr & func_, const ConstArenas & arenas_); ColumnAggregateFunction(const ConstAggregateFunctionPtr & func_, const ConstArenas & arenas_);
ColumnAggregateFunction(const ColumnAggregateFunction & src_); ColumnAggregateFunction(const ColumnAggregateFunction & src_);
@ -103,10 +103,10 @@ private:
public: public:
~ColumnAggregateFunction() override; ~ColumnAggregateFunction() override;
void set(const AggregateFunctionPtr & func_, size_t version_); void set(const ConstAggregateFunctionPtr & func_, size_t version_);
AggregateFunctionPtr getAggregateFunction() { return func; } ConstAggregateFunctionPtr getAggregateFunction() { return func; }
AggregateFunctionPtr getAggregateFunction() const { return func; } ConstAggregateFunctionPtr getAggregateFunction() const { return func; }
/// If we have another column as a source (owner of data), copy all data to ourself and reset source. /// If we have another column as a source (owner of data), copy all data to ourself and reset source.
/// This is needed before inserting new elements, because we must own these elements (to destroy them in destructor), /// This is needed before inserting new elements, because we must own these elements (to destroy them in destructor),

View File

@ -0,0 +1,29 @@
#pragma once
#include <memory>
#include <vector>
namespace DB
{
class IDataType;
using DataTypePtr = std::shared_ptr<const IDataType>;
using DataTypes = std::vector<DataTypePtr>;
struct Array;
class IResolvedFunction
{
public:
virtual const DataTypePtr & getResultType() const = 0;
virtual const DataTypes & getArgumentTypes() const = 0;
virtual const Array & getParameters() const = 0;
virtual ~IResolvedFunction() = default;
};
using IResolvedFunctionPtr = std::shared_ptr<IResolvedFunction>;
}

View File

@ -19,7 +19,7 @@ namespace DB
class DataTypeAggregateFunction final : public IDataType class DataTypeAggregateFunction final : public IDataType
{ {
private: private:
AggregateFunctionPtr function; ConstAggregateFunctionPtr function;
DataTypes argument_types; DataTypes argument_types;
Array parameters; Array parameters;
mutable std::optional<size_t> version; mutable std::optional<size_t> version;
@ -30,9 +30,9 @@ private:
public: public:
static constexpr bool is_parametric = true; static constexpr bool is_parametric = true;
DataTypeAggregateFunction(const AggregateFunctionPtr & function_, const DataTypes & argument_types_, DataTypeAggregateFunction(ConstAggregateFunctionPtr function_, const DataTypes & argument_types_,
const Array & parameters_, std::optional<size_t> version_ = std::nullopt) const Array & parameters_, std::optional<size_t> version_ = std::nullopt)
: function(function_) : function(std::move(function_))
, argument_types(argument_types_) , argument_types(argument_types_)
, parameters(parameters_) , parameters(parameters_)
, version(version_) , version(version_)
@ -40,7 +40,7 @@ public:
} }
String getFunctionName() const { return function->getName(); } String getFunctionName() const { return function->getName(); }
AggregateFunctionPtr getFunction() const { return function; } ConstAggregateFunctionPtr getFunction() const { return function; }
String doGetName() const override; String doGetName() const override;
String getNameWithoutVersion() const; String getNameWithoutVersion() const;
@ -51,7 +51,7 @@ public:
bool canBeInsideNullable() const override { return false; } bool canBeInsideNullable() const override { return false; }
DataTypePtr getReturnType() const { return function->getReturnType(); } DataTypePtr getReturnType() const { return function->getResultType(); }
DataTypePtr getReturnTypeToPredict() const { return function->getReturnTypeToPredict(); } DataTypePtr getReturnTypeToPredict() const { return function->getReturnTypeToPredict(); }
DataTypes getArgumentsDataTypes() const { return argument_types; } DataTypes getArgumentsDataTypes() const { return argument_types; }

View File

@ -131,9 +131,9 @@ static std::pair<DataTypePtr, DataTypeCustomDescPtr> create(const ASTPtr & argum
DataTypePtr storage_type = DataTypeFactory::instance().get(argument_types[0]->getName()); DataTypePtr storage_type = DataTypeFactory::instance().get(argument_types[0]->getName());
if (!function->getReturnType()->equals(*removeLowCardinality(storage_type))) if (!function->getResultType()->equals(*removeLowCardinality(storage_type)))
{ {
throw Exception("Incompatible data types between aggregate function '" + function->getName() + "' which returns " + function->getReturnType()->getName() + " and column storage type " + storage_type->getName(), throw Exception("Incompatible data types between aggregate function '" + function->getName() + "' which returns " + function->getResultType()->getName() + " and column storage type " + storage_type->getName(),
ErrorCodes::BAD_ARGUMENTS); ErrorCodes::BAD_ARGUMENTS);
} }

View File

@ -108,14 +108,14 @@ void SerializationAggregateFunction::deserializeBinaryBulk(IColumn & column, Rea
} }
} }
static String serializeToString(const AggregateFunctionPtr & function, const IColumn & column, size_t row_num, size_t version) static String serializeToString(const ConstAggregateFunctionPtr & function, const IColumn & column, size_t row_num, size_t version)
{ {
WriteBufferFromOwnString buffer; WriteBufferFromOwnString buffer;
function->serialize(assert_cast<const ColumnAggregateFunction &>(column).getData()[row_num], buffer, version); function->serialize(assert_cast<const ColumnAggregateFunction &>(column).getData()[row_num], buffer, version);
return buffer.str(); return buffer.str();
} }
static void deserializeFromString(const AggregateFunctionPtr & function, IColumn & column, const String & s, size_t version) static void deserializeFromString(const ConstAggregateFunctionPtr & function, IColumn & column, const String & s, size_t version)
{ {
ColumnAggregateFunction & column_concrete = assert_cast<ColumnAggregateFunction &>(column); ColumnAggregateFunction & column_concrete = assert_cast<ColumnAggregateFunction &>(column);

View File

@ -11,14 +11,14 @@ namespace DB
class SerializationAggregateFunction final : public ISerialization class SerializationAggregateFunction final : public ISerialization
{ {
private: private:
AggregateFunctionPtr function; ConstAggregateFunctionPtr function;
String type_name; String type_name;
size_t version; size_t version;
public: public:
static constexpr bool is_parametric = true; static constexpr bool is_parametric = true;
SerializationAggregateFunction(const AggregateFunctionPtr & function_, String type_name_, size_t version_) SerializationAggregateFunction(const ConstAggregateFunctionPtr & function_, String type_name_, size_t version_)
: function(function_), type_name(std::move(type_name_)), version(version_) {} : function(function_), type_name(std::move(type_name_)), version(version_) {}
/// NOTE These two functions for serializing single values are incompatible with the functions below. /// NOTE These two functions for serializing single values are incompatible with the functions below.

View File

@ -1736,7 +1736,7 @@ namespace
} }
const std::shared_ptr<const DataTypeAggregateFunction> aggregate_function_data_type; const std::shared_ptr<const DataTypeAggregateFunction> aggregate_function_data_type;
const AggregateFunctionPtr aggregate_function; ConstAggregateFunctionPtr aggregate_function;
String text_buffer; String text_buffer;
}; };

View File

@ -895,7 +895,7 @@ class FunctionBinaryArithmetic : public IFunction
const ColumnAggregateFunction & column = typeid_cast<const ColumnAggregateFunction &>( const ColumnAggregateFunction & column = typeid_cast<const ColumnAggregateFunction &>(
agg_state_is_const ? assert_cast<const ColumnConst &>(agg_state_column).getDataColumn() : agg_state_column); agg_state_is_const ? assert_cast<const ColumnConst &>(agg_state_column).getDataColumn() : agg_state_column);
AggregateFunctionPtr function = column.getAggregateFunction(); ConstAggregateFunctionPtr function = column.getAggregateFunction();
size_t size = agg_state_is_const ? 1 : input_rows_count; size_t size = agg_state_is_const ? 1 : input_rows_count;
@ -960,7 +960,7 @@ class FunctionBinaryArithmetic : public IFunction
const ColumnAggregateFunction & rhs = typeid_cast<const ColumnAggregateFunction &>( const ColumnAggregateFunction & rhs = typeid_cast<const ColumnAggregateFunction &>(
rhs_is_const ? assert_cast<const ColumnConst &>(rhs_column).getDataColumn() : rhs_column); rhs_is_const ? assert_cast<const ColumnConst &>(rhs_column).getDataColumn() : rhs_column);
AggregateFunctionPtr function = lhs.getAggregateFunction(); ConstAggregateFunctionPtr function = lhs.getAggregateFunction();
size_t size = (lhs_is_const && rhs_is_const) ? 1 : input_rows_count; size_t size = (lhs_is_const && rhs_is_const) ? 1 : input_rows_count;

View File

@ -3,6 +3,8 @@
#include <Core/ColumnNumbers.h> #include <Core/ColumnNumbers.h>
#include <Core/ColumnsWithTypeAndName.h> #include <Core/ColumnsWithTypeAndName.h>
#include <Core/Names.h> #include <Core/Names.h>
#include <Core/IResolvedFunction.h>
#include <Common/Exception.h>
#include <DataTypes/IDataType.h> #include <DataTypes/IDataType.h>
#include "config.h" #include "config.h"
@ -122,11 +124,11 @@ using Values = std::vector<llvm::Value *>;
/** Function with known arguments and return type (when the specific overload was chosen). /** Function with known arguments and return type (when the specific overload was chosen).
* It is also the point where all function-specific properties are known. * It is also the point where all function-specific properties are known.
*/ */
class IFunctionBase class IFunctionBase : public IResolvedFunction
{ {
public: public:
virtual ~IFunctionBase() = default; ~IFunctionBase() override = default;
virtual ColumnPtr execute( /// NOLINT virtual ColumnPtr execute( /// NOLINT
const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count, bool dry_run = false) const const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count, bool dry_run = false) const
@ -137,8 +139,10 @@ public:
/// Get the main function name. /// Get the main function name.
virtual String getName() const = 0; virtual String getName() const = 0;
virtual const DataTypes & getArgumentTypes() const = 0; const Array & getParameters() const final
virtual const DataTypePtr & getResultType() const = 0; {
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "IFunctionBase doesn't support getParameters method");
}
/// Do preparations and return executable. /// Do preparations and return executable.
/// sample_columns should contain data types of arguments and values of constants, if relevant. /// sample_columns should contain data types of arguments and values of constants, if relevant.

View File

@ -51,6 +51,8 @@ public:
const DataTypes & getArgumentTypes() const override { return arguments; } const DataTypes & getArgumentTypes() const override { return arguments; }
const DataTypePtr & getResultType() const override { return result_type; } const DataTypePtr & getResultType() const override { return result_type; }
const FunctionPtr & getFunction() const { return function; }
#if USE_EMBEDDED_COMPILER #if USE_EMBEDDED_COMPILER
bool isCompilable() const override { return function->isCompilable(getArgumentTypes()); } bool isCompilable() const override { return function->isCompilable(getArgumentTypes()); }

View File

@ -104,7 +104,7 @@ DataTypePtr FunctionArrayReduce::getReturnTypeImpl(const ColumnsWithTypeAndName
aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name, argument_types, params_row, properties); aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name, argument_types, params_row, properties);
} }
return aggregate_function->getReturnType(); return aggregate_function->getResultType();
} }

View File

@ -122,7 +122,7 @@ DataTypePtr FunctionArrayReduceInRanges::getReturnTypeImpl(const ColumnsWithType
aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name, argument_types, params_row, properties); aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name, argument_types, params_row, properties);
} }
return std::make_shared<DataTypeArray>(aggregate_function->getReturnType()); return std::make_shared<DataTypeArray>(aggregate_function->getResultType());
} }

View File

@ -87,7 +87,7 @@ DataTypePtr FunctionInitializeAggregation::getReturnTypeImpl(const ColumnsWithTy
aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name, argument_types, params_row, properties); aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name, argument_types, params_row, properties);
} }
return aggregate_function->getReturnType(); return aggregate_function->getResultType();
} }

View File

@ -91,7 +91,7 @@ public:
if (arguments.size() == 2) if (arguments.size() == 2)
column_with_groups = arguments[1].column; column_with_groups = arguments[1].column;
AggregateFunctionPtr aggregate_function_ptr = column_with_states->getAggregateFunction(); ConstAggregateFunctionPtr aggregate_function_ptr = column_with_states->getAggregateFunction();
const IAggregateFunction & agg_func = *aggregate_function_ptr; const IAggregateFunction & agg_func = *aggregate_function_ptr;
AlignedBuffer place(agg_func.sizeOfData(), agg_func.alignOfData()); AlignedBuffer place(agg_func.sizeOfData(), agg_func.alignOfData());
@ -99,7 +99,7 @@ public:
/// Will pass empty arena if agg_func does not allocate memory in arena /// Will pass empty arena if agg_func does not allocate memory in arena
std::unique_ptr<Arena> arena = agg_func.allocatesMemoryInArena() ? std::make_unique<Arena>() : nullptr; std::unique_ptr<Arena> arena = agg_func.allocatesMemoryInArena() ? std::make_unique<Arena>() : nullptr;
auto result_column_ptr = agg_func.getReturnType()->createColumn(); auto result_column_ptr = agg_func.getResultType()->createColumn();
IColumn & result_column = *result_column_ptr; IColumn & result_column = *result_column_ptr;
result_column.reserve(column_with_states->size()); result_column.reserve(column_with_states->size());

View File

@ -47,8 +47,6 @@ void ActionsDAG::Node::toTree(JSONBuilder::JSONMap & map) const
if (function_base) if (function_base)
map.add("Function", function_base->getName()); map.add("Function", function_base->getName());
else if (function_builder)
map.add("Function", function_builder->getName());
if (type == ActionType::FUNCTION) if (type == ActionType::FUNCTION)
map.add("Compiled", is_function_compiled); map.add("Compiled", is_function_compiled);
@ -166,7 +164,6 @@ const ActionsDAG::Node & ActionsDAG::addFunction(
Node node; Node node;
node.type = ActionType::FUNCTION; node.type = ActionType::FUNCTION;
node.function_builder = function;
node.children = std::move(children); node.children = std::move(children);
bool all_const = true; bool all_const = true;
@ -238,6 +235,86 @@ const ActionsDAG::Node & ActionsDAG::addFunction(
return addNode(std::move(node)); return addNode(std::move(node));
} }
const ActionsDAG::Node & ActionsDAG::addFunction(
const FunctionBasePtr & function_base,
NodeRawConstPtrs children,
std::string result_name)
{
size_t num_arguments = children.size();
Node node;
node.type = ActionType::FUNCTION;
node.children = std::move(children);
bool all_const = true;
ColumnsWithTypeAndName arguments(num_arguments);
for (size_t i = 0; i < num_arguments; ++i)
{
const auto & child = *node.children[i];
ColumnWithTypeAndName argument;
argument.column = child.column;
argument.type = child.result_type;
argument.name = child.result_name;
if (!argument.column || !isColumnConst(*argument.column))
all_const = false;
arguments[i] = std::move(argument);
}
node.function_base = function_base;
node.result_type = node.function_base->getResultType();
node.function = node.function_base->prepare(arguments);
node.is_deterministic = node.function_base->isDeterministic();
/// If all arguments are constants, and function is suitable to be executed in 'prepare' stage - execute function.
if (node.function_base->isSuitableForConstantFolding())
{
ColumnPtr column;
if (all_const)
{
size_t num_rows = arguments.empty() ? 0 : arguments.front().column->size();
column = node.function->execute(arguments, node.result_type, num_rows, true);
}
else
{
column = node.function_base->getConstantResultForNonConstArguments(arguments, node.result_type);
}
/// If the result is not a constant, just in case, we will consider the result as unknown.
if (column && isColumnConst(*column))
{
/// All constant (literal) columns in block are added with size 1.
/// But if there was no columns in block before executing a function, the result has size 0.
/// Change the size to 1.
if (column->empty())
column = column->cloneResized(1);
node.column = std::move(column);
}
}
if (result_name.empty())
{
result_name = function_base->getName() + "(";
for (size_t i = 0; i < num_arguments; ++i)
{
if (i)
result_name += ", ";
result_name += node.children[i]->result_name;
}
result_name += ")";
}
node.result_name = std::move(result_name);
return addNode(std::move(node));
}
const ActionsDAG::Node & ActionsDAG::findInOutputs(const std::string & name) const const ActionsDAG::Node & ActionsDAG::findInOutputs(const std::string & name) const
{ {
if (const auto * node = tryFindInOutputs(name)) if (const auto * node = tryFindInOutputs(name))
@ -1927,8 +2004,7 @@ ActionsDAGPtr ActionsDAG::cloneActionsForFilterPushDown(
FunctionOverloadResolverPtr func_builder_cast = CastInternalOverloadResolver<CastType::nonAccurate>::createImpl(); FunctionOverloadResolverPtr func_builder_cast = CastInternalOverloadResolver<CastType::nonAccurate>::createImpl();
predicate->function_builder = func_builder_cast; predicate->function_base = func_builder_cast->build(arguments);
predicate->function_base = predicate->function_builder->build(arguments);
predicate->function = predicate->function_base->prepare(arguments); predicate->function = predicate->function_base->prepare(arguments);
} }
} }
@ -1939,7 +2015,9 @@ ActionsDAGPtr ActionsDAG::cloneActionsForFilterPushDown(
predicate->children.swap(new_children); predicate->children.swap(new_children);
auto arguments = prepareFunctionArguments(predicate->children); auto arguments = prepareFunctionArguments(predicate->children);
predicate->function_base = predicate->function_builder->build(arguments); FunctionOverloadResolverPtr func_builder_and = std::make_unique<FunctionToOverloadResolverAdaptor>(std::make_shared<FunctionAnd>());
predicate->function_base = func_builder_and->build(arguments);
predicate->function = predicate->function_base->prepare(arguments); predicate->function = predicate->function_base->prepare(arguments);
} }
} }

View File

@ -74,7 +74,6 @@ public:
std::string result_name; std::string result_name;
DataTypePtr result_type; DataTypePtr result_type;
FunctionOverloadResolverPtr function_builder;
/// Can be used to get function signature or properties like monotonicity. /// Can be used to get function signature or properties like monotonicity.
FunctionBasePtr function_base; FunctionBasePtr function_base;
/// Prepared function which is used in function execution. /// Prepared function which is used in function execution.
@ -139,6 +138,10 @@ public:
const FunctionOverloadResolverPtr & function, const FunctionOverloadResolverPtr & function,
NodeRawConstPtrs children, NodeRawConstPtrs children,
std::string result_name); std::string result_name);
const Node & addFunction(
const FunctionBasePtr & function_base,
NodeRawConstPtrs children,
std::string result_name);
/// Find first column by name in output nodes. This search is linear. /// Find first column by name in output nodes. This search is linear.
const Node & findInOutputs(const std::string & name) const; const Node & findInOutputs(const std::string & name) const;

View File

@ -53,7 +53,7 @@ void AggregateDescription::explain(WriteBuffer & out, size_t indent) const
out << type->getName(); out << type->getName();
} }
out << ") → " << function->getReturnType()->getName() << "\n"; out << ") → " << function->getResultType()->getName() << "\n";
} }
else else
out << prefix << " Function: nullptr\n"; out << prefix << " Function: nullptr\n";
@ -109,7 +109,7 @@ void AggregateDescription::explain(JSONBuilder::JSONMap & map) const
args_array->add(type->getName()); args_array->add(type->getName());
function_map->add("Argument Types", std::move(args_array)); function_map->add("Argument Types", std::move(args_array));
function_map->add("Result Type", function->getReturnType()->getName()); function_map->add("Result Type", function->getResultType()->getName());
map.add("Function", std::move(function_map)); map.add("Function", std::move(function_map));
} }

View File

@ -45,7 +45,7 @@ OutputBlockColumns prepareOutputBlockColumns(
} }
else else
{ {
final_aggregate_columns[i] = aggregate_functions[i]->getReturnType()->createColumn(); final_aggregate_columns[i] = aggregate_functions[i]->getResultType()->createColumn();
final_aggregate_columns[i]->reserve(rows); final_aggregate_columns[i]->reserve(rows);
if (aggregate_functions[i]->isState()) if (aggregate_functions[i]->isState())

View File

@ -433,7 +433,7 @@ Block Aggregator::Params::getHeader(
{ {
auto & elem = res.getByName(aggregate.column_name); auto & elem = res.getByName(aggregate.column_name);
elem.type = aggregate.function->getReturnType(); elem.type = aggregate.function->getResultType();
elem.column = elem.type->createColumn(); elem.column = elem.type->createColumn();
} }
} }
@ -452,7 +452,7 @@ Block Aggregator::Params::getHeader(
DataTypePtr type; DataTypePtr type;
if (final) if (final)
type = aggregate.function->getReturnType(); type = aggregate.function->getResultType();
else else
type = std::make_shared<DataTypeAggregateFunction>(aggregate.function, argument_types, aggregate.parameters); type = std::make_shared<DataTypeAggregateFunction>(aggregate.function, argument_types, aggregate.parameters);

View File

@ -423,7 +423,7 @@ void ExpressionAnalyzer::analyzeAggregation(ActionsDAGPtr & temp_actions)
aggregated_columns = temp_actions->getNamesAndTypesList(); aggregated_columns = temp_actions->getNamesAndTypesList();
for (const auto & desc : aggregate_descriptions) for (const auto & desc : aggregate_descriptions)
aggregated_columns.emplace_back(desc.column_name, desc.function->getReturnType()); aggregated_columns.emplace_back(desc.column_name, desc.function->getResultType());
} }
@ -2021,7 +2021,7 @@ ExpressionAnalysisResult::ExpressionAnalysisResult(
for (const auto & f : w.window_functions) for (const auto & f : w.window_functions)
{ {
query_analyzer.columns_after_window.push_back( query_analyzer.columns_after_window.push_back(
{f.column_name, f.aggregate_function->getReturnType()}); {f.column_name, f.aggregate_function->getResultType()});
} }
} }

View File

@ -403,7 +403,7 @@ static void compileInsertAggregatesIntoResultColumns(llvm::Module & module, cons
std::vector<ColumnDataPlaceholder> columns(functions.size()); std::vector<ColumnDataPlaceholder> columns(functions.size());
for (size_t i = 0; i < functions.size(); ++i) for (size_t i = 0; i < functions.size(); ++i)
{ {
auto return_type = functions[i].function->getReturnType(); auto return_type = functions[i].function->getResultType();
auto * data = b.CreateLoad(column_type, b.CreateConstInBoundsGEP1_64(column_type, columns_arg, i)); auto * data = b.CreateLoad(column_type, b.CreateConstInBoundsGEP1_64(column_type, columns_arg, i));
auto * column_data_type = toNativeType(b, removeNullable(return_type)); auto * column_data_type = toNativeType(b, removeNullable(return_type));

View File

@ -365,8 +365,8 @@ void Planner::buildQueryPlanIfNeeded()
{ {
auto function_node = std::make_shared<FunctionNode>("and"); auto function_node = std::make_shared<FunctionNode>("and");
auto and_function = FunctionFactory::instance().get("and", query_context); auto and_function = FunctionFactory::instance().get("and", query_context);
function_node->resolveAsFunction(std::move(and_function), std::make_shared<DataTypeUInt8>());
function_node->getArguments().getNodes() = {query_node.getPrewhere(), query_node.getWhere()}; function_node->getArguments().getNodes() = {query_node.getPrewhere(), query_node.getWhere()};
function_node->resolveAsFunction(and_function->build(function_node->getArgumentTypes()));
query_node.getWhere() = std::move(function_node); query_node.getWhere() = std::move(function_node);
query_node.getPrewhere() = {}; query_node.getPrewhere() = {};
} }

View File

@ -121,7 +121,7 @@ public:
return node; return node;
} }
const ActionsDAG::Node * addFunctionIfNecessary(const std::string & node_name, ActionsDAG::NodeRawConstPtrs children, FunctionOverloadResolverPtr function) const ActionsDAG::Node * addFunctionIfNecessary(const std::string & node_name, ActionsDAG::NodeRawConstPtrs children, FunctionBasePtr function)
{ {
auto it = node_name_to_node.find(node_name); auto it = node_name_to_node.find(node_name);
if (it != node_name_to_node.end()) if (it != node_name_to_node.end())
@ -325,7 +325,7 @@ PlannerActionsVisitorImpl::NodeNameAndNodeMinLevel PlannerActionsVisitorImpl::vi
lambda_actions, captured_column_names, lambda_arguments_names_and_types, result_type, lambda_expression_node_name); lambda_actions, captured_column_names, lambda_arguments_names_and_types, result_type, lambda_expression_node_name);
actions_stack.pop_back(); actions_stack.pop_back();
actions_stack[level].addFunctionIfNecessary(lambda_node_name, std::move(lambda_children), std::move(function_capture)); actions_stack[level].addFunctionIfNecessary(lambda_node_name, std::move(lambda_children), function_capture->build({}));
size_t actions_stack_size = actions_stack.size(); size_t actions_stack_size = actions_stack.size();
for (size_t i = level + 1; i < actions_stack_size; ++i) for (size_t i = level + 1; i < actions_stack_size; ++i)

View File

@ -101,14 +101,14 @@ public:
{ {
auto grouping_ordinary_function = std::make_shared<FunctionGroupingOrdinary>(arguments_indexes, force_grouping_standard_compatibility); auto grouping_ordinary_function = std::make_shared<FunctionGroupingOrdinary>(arguments_indexes, force_grouping_standard_compatibility);
auto grouping_ordinary_function_adaptor = std::make_shared<FunctionToOverloadResolverAdaptor>(std::move(grouping_ordinary_function)); auto grouping_ordinary_function_adaptor = std::make_shared<FunctionToOverloadResolverAdaptor>(std::move(grouping_ordinary_function));
function_node->resolveAsFunction(std::move(grouping_ordinary_function_adaptor), std::make_shared<DataTypeUInt64>()); function_node->resolveAsFunction(grouping_ordinary_function_adaptor->build({}));
break; break;
} }
case GroupByKind::ROLLUP: case GroupByKind::ROLLUP:
{ {
auto grouping_rollup_function = std::make_shared<FunctionGroupingForRollup>(arguments_indexes, aggregation_keys_size, force_grouping_standard_compatibility); auto grouping_rollup_function = std::make_shared<FunctionGroupingForRollup>(arguments_indexes, aggregation_keys_size, force_grouping_standard_compatibility);
auto grouping_rollup_function_adaptor = std::make_shared<FunctionToOverloadResolverAdaptor>(std::move(grouping_rollup_function)); auto grouping_rollup_function_adaptor = std::make_shared<FunctionToOverloadResolverAdaptor>(std::move(grouping_rollup_function));
function_node->resolveAsFunction(std::move(grouping_rollup_function_adaptor), std::make_shared<DataTypeUInt64>()); function_node->resolveAsFunction(grouping_rollup_function_adaptor->build({}));
function_node->getArguments().getNodes().push_back(std::move(grouping_set_argument_column)); function_node->getArguments().getNodes().push_back(std::move(grouping_set_argument_column));
break; break;
} }
@ -116,7 +116,7 @@ public:
{ {
auto grouping_cube_function = std::make_shared<FunctionGroupingForCube>(arguments_indexes, aggregation_keys_size, force_grouping_standard_compatibility); auto grouping_cube_function = std::make_shared<FunctionGroupingForCube>(arguments_indexes, aggregation_keys_size, force_grouping_standard_compatibility);
auto grouping_cube_function_adaptor = std::make_shared<FunctionToOverloadResolverAdaptor>(std::move(grouping_cube_function)); auto grouping_cube_function_adaptor = std::make_shared<FunctionToOverloadResolverAdaptor>(std::move(grouping_cube_function));
function_node->resolveAsFunction(std::move(grouping_cube_function_adaptor), std::make_shared<DataTypeUInt64>()); function_node->resolveAsFunction(grouping_cube_function_adaptor->build({}));
function_node->getArguments().getNodes().push_back(std::move(grouping_set_argument_column)); function_node->getArguments().getNodes().push_back(std::move(grouping_set_argument_column));
break; break;
} }
@ -124,7 +124,7 @@ public:
{ {
auto grouping_grouping_sets_function = std::make_shared<FunctionGroupingForGroupingSets>(arguments_indexes, grouping_sets_keys_indices, force_grouping_standard_compatibility); auto grouping_grouping_sets_function = std::make_shared<FunctionGroupingForGroupingSets>(arguments_indexes, grouping_sets_keys_indices, force_grouping_standard_compatibility);
auto grouping_grouping_sets_function_adaptor = std::make_shared<FunctionToOverloadResolverAdaptor>(std::move(grouping_grouping_sets_function)); auto grouping_grouping_sets_function_adaptor = std::make_shared<FunctionToOverloadResolverAdaptor>(std::move(grouping_grouping_sets_function));
function_node->resolveAsFunction(std::move(grouping_grouping_sets_function_adaptor), std::make_shared<DataTypeUInt64>()); function_node->resolveAsFunction(grouping_grouping_sets_function_adaptor->build({}));
function_node->getArguments().getNodes().push_back(std::move(grouping_set_argument_column)); function_node->getArguments().getNodes().push_back(std::move(grouping_set_argument_column));
break; break;
} }

View File

@ -65,7 +65,7 @@ std::optional<AggregationAnalysisResult> analyzeAggregation(QueryTreeNodePtr & q
ColumnsWithTypeAndName aggregates_columns; ColumnsWithTypeAndName aggregates_columns;
aggregates_columns.reserve(aggregates_descriptions.size()); aggregates_columns.reserve(aggregates_descriptions.size());
for (auto & aggregate_description : aggregates_descriptions) for (auto & aggregate_description : aggregates_descriptions)
aggregates_columns.emplace_back(nullptr, aggregate_description.function->getReturnType(), aggregate_description.column_name); aggregates_columns.emplace_back(nullptr, aggregate_description.function->getResultType(), aggregate_description.column_name);
Names aggregation_keys; Names aggregation_keys;
@ -284,7 +284,7 @@ std::optional<WindowAnalysisResult> analyzeWindow(QueryTreeNodePtr & query_tree,
for (auto & window_description : window_descriptions) for (auto & window_description : window_descriptions)
for (auto & window_function : window_description.window_functions) for (auto & window_function : window_description.window_functions)
window_functions_additional_columns.emplace_back(nullptr, window_function.aggregate_function->getReturnType(), window_function.column_name); window_functions_additional_columns.emplace_back(nullptr, window_function.aggregate_function->getResultType(), window_function.column_name);
auto before_window_step = std::make_unique<ActionsChainStep>(before_window_actions, auto before_window_step = std::make_unique<ActionsChainStep>(before_window_actions,
ActionsChainStep::AvailableOutputColumnsStrategy::ALL_NODES, ActionsChainStep::AvailableOutputColumnsStrategy::ALL_NODES,

Some files were not shown because too many files have changed in this diff Show More