Refactor FunctionNode

This commit is contained in:
Dmitry Novik 2022-11-28 15:02:59 +00:00
parent b6eddbac0d
commit 2c70dbc76a
106 changed files with 636 additions and 558 deletions

View File

@ -49,14 +49,16 @@ private:
public:
AggregateFunctionThrow(const DataTypes & argument_types_, const Array & parameters_, Float64 throw_probability_)
: IAggregateFunctionDataHelper(argument_types_, parameters_), throw_probability(throw_probability_) {}
: IAggregateFunctionDataHelper(argument_types_, parameters_, createResultType())
, throw_probability(throw_probability_)
{}
String getName() const override
{
return "aggThrow";
}
DataTypePtr getReturnType() const override
static DataTypePtr createResultType()
{
return std::make_shared<DataTypeUInt8>();
}

View File

@ -37,10 +37,10 @@ class AggregateFunctionAnalysisOfVariance final : public IAggregateFunctionDataH
{
public:
explicit AggregateFunctionAnalysisOfVariance(const DataTypes & arguments, const Array & params)
: IAggregateFunctionDataHelper(arguments, params)
: IAggregateFunctionDataHelper(arguments, params, createResultType())
{}
DataTypePtr getReturnType() const override
DataTypePtr createResultType() const
{
DataTypes types {std::make_shared<DataTypeNumber<Float64>>(), std::make_shared<DataTypeNumber<Float64>>() };
Strings names {"f_statistic", "p_value"};

View File

@ -38,7 +38,6 @@ template <typename Data>
class AggregateFunctionArgMinMax final : public IAggregateFunctionDataHelper<Data, AggregateFunctionArgMinMax<Data>>
{
private:
const DataTypePtr & type_res;
const DataTypePtr & type_val;
const SerializationPtr serialization_res;
const SerializationPtr serialization_val;
@ -47,10 +46,9 @@ private:
public:
AggregateFunctionArgMinMax(const DataTypePtr & type_res_, const DataTypePtr & type_val_)
: Base({type_res_, type_val_}, {})
, type_res(this->argument_types[0])
: Base({type_res_, type_val_}, {}, type_res_)
, type_val(this->argument_types[1])
, serialization_res(type_res->getDefaultSerialization())
, serialization_res(type_res_->getDefaultSerialization())
, serialization_val(type_val->getDefaultSerialization())
{
if (!type_val->isComparable())
@ -63,11 +61,6 @@ public:
return StringRef(Data::ValueData_t::name()) == StringRef("min") ? "argMin" : "argMax";
}
DataTypePtr getReturnType() const override
{
return type_res;
}
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{
if (this->data(place).value.changeIfBetter(*columns[1], row_num, arena))

View File

@ -30,7 +30,7 @@ private:
public:
AggregateFunctionArray(AggregateFunctionPtr nested_, const DataTypes & arguments, const Array & params_)
: IAggregateFunctionHelper<AggregateFunctionArray>(arguments, params_)
: IAggregateFunctionHelper<AggregateFunctionArray>(arguments, params_, createResultType())
, nested_func(nested_), num_arguments(arguments.size())
{
assert(parameters == nested_func->getParameters());
@ -44,9 +44,9 @@ public:
return nested_func->getName() + "Array";
}
DataTypePtr getReturnType() const override
DataTypePtr createResultType() const
{
return nested_func->getReturnType();
return nested_func->getResultType();
}
const IAggregateFunction & getBaseAggregateFunctionWithSameStateRepresentation() const override

View File

@ -11,6 +11,9 @@
#include <AggregateFunctions/AggregateFunctionSum.h>
#include <Core/DecimalFunctions.h>
#include "Core/IResolvedFunction.h"
#include "DataTypes/IDataType.h"
#include "DataTypes/Serializations/ISerialization.h"
#include "config.h"
#if USE_EMBEDDED_COMPILER
@ -83,10 +86,20 @@ public:
using Fraction = AvgFraction<Numerator, Denominator>;
explicit AggregateFunctionAvgBase(const DataTypes & argument_types_,
UInt32 num_scale_ = 0, UInt32 denom_scale_ = 0)
: Base(argument_types_, {}), num_scale(num_scale_), denom_scale(denom_scale_) {}
UInt32 num_scale_ = 0, UInt32 denom_scale_ = 0)
: Base(argument_types_, {}, createResultType())
, num_scale(num_scale_)
, denom_scale(denom_scale_)
{}
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeNumber<Float64>>(); }
AggregateFunctionAvgBase(const DataTypes & argument_types_, const DataTypePtr & result_type_,
UInt32 num_scale_ = 0, UInt32 denom_scale_ = 0)
: Base(argument_types_, {}, result_type_)
, num_scale(num_scale_)
, denom_scale(denom_scale_)
{}
DataTypePtr createResultType() const { return std::make_shared<DataTypeNumber<Float64>>(); }
bool allocatesMemoryInArena() const override { return false; }
@ -135,7 +148,7 @@ public:
for (const auto & argument : this->argument_types)
can_be_compiled &= canBeNativeType(*argument);
auto return_type = getReturnType();
auto return_type = this->getResultType();
can_be_compiled &= canBeNativeType(*return_type);
return can_be_compiled;

View File

@ -97,11 +97,12 @@ class AggregateFunctionBitwise final : public IAggregateFunctionDataHelper<Data,
{
public:
explicit AggregateFunctionBitwise(const DataTypePtr & type)
: IAggregateFunctionDataHelper<Data, AggregateFunctionBitwise<T, Data>>({type}, {}) {}
: IAggregateFunctionDataHelper<Data, AggregateFunctionBitwise<T, Data>>({type}, {}, createResultType())
{}
String getName() const override { return Data::name(); }
DataTypePtr getReturnType() const override
static DataTypePtr createResultType()
{
return std::make_shared<DataTypeNumber<T>>();
}
@ -137,7 +138,7 @@ public:
bool isCompilable() const override
{
auto return_type = getReturnType();
auto return_type = this->getResultType();
return canBeNativeType(*return_type);
}
@ -151,7 +152,7 @@ public:
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_type = toNativeType(b, getReturnType());
auto * return_type = toNativeType(b, this->getResultType());
auto * value_ptr = aggregate_data_ptr;
auto * value = b.CreateLoad(return_type, value_ptr);
@ -166,7 +167,7 @@ public:
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_type = toNativeType(b, getReturnType());
auto * return_type = toNativeType(b, this->getResultType());
auto * value_dst_ptr = aggregate_data_dst_ptr;
auto * value_dst = b.CreateLoad(return_type, value_dst_ptr);
@ -183,7 +184,7 @@ public:
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_type = toNativeType(b, getReturnType());
auto * return_type = toNativeType(b, this->getResultType());
auto * value_ptr = aggregate_data_ptr;
return b.CreateLoad(return_type, value_ptr);

View File

@ -112,7 +112,7 @@ public:
}
explicit AggregateFunctionBoundingRatio(const DataTypes & arguments)
: IAggregateFunctionDataHelper<AggregateFunctionBoundingRatioData, AggregateFunctionBoundingRatio>(arguments, {})
: IAggregateFunctionDataHelper<AggregateFunctionBoundingRatioData, AggregateFunctionBoundingRatio>(arguments, {}, std::make_shared<DataTypeFloat64>())
{
const auto * x_arg = arguments.at(0).get();
const auto * y_arg = arguments.at(1).get();
@ -122,11 +122,6 @@ public:
ErrorCodes::BAD_ARGUMENTS);
}
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeFloat64>();
}
bool allocatesMemoryInArena() const override { return false; }
void add(AggregateDataPtr __restrict place, const IColumn ** columns, const size_t row_num, Arena *) const override

View File

@ -46,9 +46,9 @@ private:
}
public:
AggregateFunctionCategoricalIV(const DataTypes & arguments_, const Array & params_) :
IAggregateFunctionHelper<AggregateFunctionCategoricalIV>{arguments_, params_},
category_count{arguments_.size() - 1}
AggregateFunctionCategoricalIV(const DataTypes & arguments_, const Array & params_)
: IAggregateFunctionHelper<AggregateFunctionCategoricalIV>{arguments_, params_, createResultType()}
, category_count{arguments_.size() - 1}
{
// notice: argument types has been checked before
}
@ -121,7 +121,7 @@ public:
buf.readStrict(place, sizeOfData());
}
DataTypePtr getReturnType() const override
static DataTypePtr createResultType()
{
return std::make_shared<DataTypeArray>(
std::make_shared<DataTypeNumber<Float64>>());

View File

@ -39,11 +39,13 @@ namespace ErrorCodes
class AggregateFunctionCount final : public IAggregateFunctionDataHelper<AggregateFunctionCountData, AggregateFunctionCount>
{
public:
explicit AggregateFunctionCount(const DataTypes & argument_types_) : IAggregateFunctionDataHelper(argument_types_, {}) {}
explicit AggregateFunctionCount(const DataTypes & argument_types_)
: IAggregateFunctionDataHelper(argument_types_, {}, createResultType())
{}
String getName() const override { return "count"; }
DataTypePtr getReturnType() const override
static DataTypePtr createResultType()
{
return std::make_shared<DataTypeUInt64>();
}
@ -167,7 +169,7 @@ public:
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_type = toNativeType(b, getReturnType());
auto * return_type = toNativeType(b, this->getResultType());
auto * count_value_ptr = aggregate_data_ptr;
auto * count_value = b.CreateLoad(return_type, count_value_ptr);
@ -180,7 +182,7 @@ public:
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_type = toNativeType(b, getReturnType());
auto * return_type = toNativeType(b, this->getResultType());
auto * count_value_dst_ptr = aggregate_data_dst_ptr;
auto * count_value_dst = b.CreateLoad(return_type, count_value_dst_ptr);
@ -197,7 +199,7 @@ public:
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_type = toNativeType(b, getReturnType());
auto * return_type = toNativeType(b, this->getResultType());
auto * count_value_ptr = aggregate_data_ptr;
return b.CreateLoad(return_type, count_value_ptr);
@ -214,7 +216,7 @@ class AggregateFunctionCountNotNullUnary final
{
public:
AggregateFunctionCountNotNullUnary(const DataTypePtr & argument, const Array & params)
: IAggregateFunctionDataHelper<AggregateFunctionCountData, AggregateFunctionCountNotNullUnary>({argument}, params)
: IAggregateFunctionDataHelper<AggregateFunctionCountData, AggregateFunctionCountNotNullUnary>({argument}, params, createResultType())
{
if (!argument->isNullable())
throw Exception("Logical error: not Nullable data type passed to AggregateFunctionCountNotNullUnary", ErrorCodes::LOGICAL_ERROR);
@ -222,7 +224,7 @@ public:
String getName() const override { return "count"; }
DataTypePtr getReturnType() const override
static DataTypePtr createResultType()
{
return std::make_shared<DataTypeUInt64>();
}
@ -311,7 +313,7 @@ public:
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_type = toNativeType(b, getReturnType());
auto * return_type = toNativeType(b, this->getResultType());
auto * is_null_value = b.CreateExtractValue(values[0], {1});
auto * increment_value = b.CreateSelect(is_null_value, llvm::ConstantInt::get(return_type, 0), llvm::ConstantInt::get(return_type, 1));
@ -327,7 +329,7 @@ public:
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_type = toNativeType(b, getReturnType());
auto * return_type = toNativeType(b, this->getResultType());
auto * count_value_dst_ptr = aggregate_data_dst_ptr;
auto * count_value_dst = b.CreateLoad(return_type, count_value_dst_ptr);
@ -344,7 +346,7 @@ public:
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_type = toNativeType(b, getReturnType());
auto * return_type = toNativeType(b, this->getResultType());
auto * count_value_ptr = aggregate_data_ptr;
return b.CreateLoad(return_type, count_value_ptr);

View File

@ -31,7 +31,7 @@ class AggregationFunctionDeltaSum final
{
public:
AggregationFunctionDeltaSum(const DataTypes & arguments, const Array & params)
: IAggregateFunctionDataHelper<AggregationFunctionDeltaSumData<T>, AggregationFunctionDeltaSum<T>>{arguments, params}
: IAggregateFunctionDataHelper<AggregationFunctionDeltaSumData<T>, AggregationFunctionDeltaSum<T>>{arguments, params, createResultType()}
{}
AggregationFunctionDeltaSum()
@ -40,7 +40,7 @@ public:
String getName() const override { return "deltaSum"; }
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeNumber<T>>(); }
static DataTypePtr createResultType() { return std::make_shared<DataTypeNumber<T>>(); }
bool allocatesMemoryInArena() const override { return false; }

View File

@ -38,7 +38,7 @@ public:
: IAggregateFunctionDataHelper<
AggregationFunctionDeltaSumTimestampData<ValueType, TimestampType>,
AggregationFunctionDeltaSumTimestamp<ValueType, TimestampType>
>{arguments, params}
>{arguments, params, createResultType()}
{}
AggregationFunctionDeltaSumTimestamp()
@ -52,7 +52,7 @@ public:
String getName() const override { return "deltaSumTimestamp"; }
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeNumber<ValueType>>(); }
static DataTypePtr createResultType() { return std::make_shared<DataTypeNumber<ValueType>>(); }
void NO_SANITIZE_UNDEFINED ALWAYS_INLINE add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
{

View File

@ -168,7 +168,7 @@ private:
public:
AggregateFunctionDistinct(AggregateFunctionPtr nested_func_, const DataTypes & arguments, const Array & params_)
: IAggregateFunctionDataHelper<Data, AggregateFunctionDistinct>(arguments, params_)
: IAggregateFunctionDataHelper<Data, AggregateFunctionDistinct>(arguments, params_, nested_func_->getResultType())
, nested_func(nested_func_)
, arguments_num(arguments.size())
{
@ -255,11 +255,6 @@ public:
return nested_func->getName() + "Distinct";
}
DataTypePtr getReturnType() const override
{
return nested_func->getReturnType();
}
bool allocatesMemoryInArena() const override
{
return true;

View File

@ -92,14 +92,14 @@ private:
public:
explicit AggregateFunctionEntropy(const DataTypes & argument_types_)
: IAggregateFunctionDataHelper<EntropyData<Value>, AggregateFunctionEntropy<Value>>(argument_types_, {})
: IAggregateFunctionDataHelper<EntropyData<Value>, AggregateFunctionEntropy<Value>>(argument_types_, {}, createResultType())
, num_args(argument_types_.size())
{
}
String getName() const override { return "entropy"; }
DataTypePtr getReturnType() const override
static DataTypePtr createResultType()
{
return std::make_shared<DataTypeNumber<Float64>>();
}

View File

@ -29,7 +29,7 @@ private:
public:
AggregateFunctionExponentialMovingAverage(const DataTypes & argument_types_, const Array & params)
: IAggregateFunctionDataHelper<ExponentiallySmoothedAverage, AggregateFunctionExponentialMovingAverage>(argument_types_, params)
: IAggregateFunctionDataHelper<ExponentiallySmoothedAverage, AggregateFunctionExponentialMovingAverage>(argument_types_, params, createResultType())
{
if (params.size() != 1)
throw Exception{"Aggregate function " + getName() + " requires exactly one parameter: half decay time.",
@ -43,7 +43,7 @@ public:
return "exponentialMovingAverage";
}
DataTypePtr getReturnType() const override
static DataTypePtr createResultType()
{
return std::make_shared<DataTypeNumber<Float64>>();
}

View File

@ -107,7 +107,7 @@ private:
public:
AggregateFunctionForEach(AggregateFunctionPtr nested_, const DataTypes & arguments, const Array & params_)
: IAggregateFunctionDataHelper<AggregateFunctionForEachData, AggregateFunctionForEach>(arguments, params_)
: IAggregateFunctionDataHelper<AggregateFunctionForEachData, AggregateFunctionForEach>(arguments, params_, nested_->getResultType())
, nested_func(nested_), num_arguments(arguments.size())
{
nested_size_of_data = nested_func->sizeOfData();
@ -125,11 +125,6 @@ public:
return nested_func->getName() + "ForEach";
}
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeArray>(nested_func->getReturnType());
}
bool isVersioned() const override
{
return nested_func->isVersioned();

View File

@ -121,7 +121,7 @@ public:
explicit GroupArrayNumericImpl(
const DataTypePtr & data_type_, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max(), UInt64 seed_ = 123456)
: IAggregateFunctionDataHelper<GroupArrayNumericData<T, Trait::sampler != Sampler::NONE>, GroupArrayNumericImpl<T, Trait>>(
{data_type_}, parameters_)
{data_type_}, parameters_, std::make_shared<DataTypeArray>(data_type_))
, max_elems(max_elems_)
, seed(seed_)
{
@ -129,8 +129,6 @@ public:
String getName() const override { return getNameByTrait<Trait>(); }
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeArray>(this->argument_types[0]); }
void insert(Data & a, const T & v, Arena * arena) const
{
++a.total_values;
@ -423,7 +421,7 @@ class GroupArrayGeneralImpl final
public:
GroupArrayGeneralImpl(const DataTypePtr & data_type_, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max(), UInt64 seed_ = 123456)
: IAggregateFunctionDataHelper<GroupArrayGeneralData<Node, Trait::sampler != Sampler::NONE>, GroupArrayGeneralImpl<Node, Trait>>(
{data_type_}, parameters_)
{data_type_}, parameters_, std::make_shared<DataTypeArray>(data_type_))
, data_type(this->argument_types[0])
, max_elems(max_elems_)
, seed(seed_)
@ -432,8 +430,6 @@ public:
String getName() const override { return getNameByTrait<Trait>(); }
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeArray>(data_type); }
void insert(Data & a, const Node * v, Arena * arena) const
{
++a.total_values;
@ -697,7 +693,7 @@ class GroupArrayGeneralListImpl final
public:
GroupArrayGeneralListImpl(const DataTypePtr & data_type_, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
: IAggregateFunctionDataHelper<GroupArrayGeneralListData<Node>, GroupArrayGeneralListImpl<Node, Trait>>({data_type_}, parameters_)
: IAggregateFunctionDataHelper<GroupArrayGeneralListData<Node>, GroupArrayGeneralListImpl<Node, Trait>>({data_type_}, parameters_, std::make_shared<DataTypeArray>(data_type_))
, data_type(this->argument_types[0])
, max_elems(max_elems_)
{
@ -705,8 +701,6 @@ public:
String getName() const override { return getNameByTrait<Trait>(); }
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeArray>(data_type); }
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{
if (limit_num_elems && data(place).elems >= max_elems)

View File

@ -64,7 +64,7 @@ private:
public:
AggregateFunctionGroupArrayInsertAtGeneric(const DataTypes & arguments, const Array & params)
: IAggregateFunctionDataHelper<AggregateFunctionGroupArrayInsertAtDataGeneric, AggregateFunctionGroupArrayInsertAtGeneric>(arguments, params)
: IAggregateFunctionDataHelper<AggregateFunctionGroupArrayInsertAtDataGeneric, AggregateFunctionGroupArrayInsertAtGeneric>(arguments, params, std::make_shared<DataTypeArray>(arguments[0]))
, type(argument_types[0])
, serialization(type->getDefaultSerialization())
{
@ -101,11 +101,6 @@ public:
String getName() const override { return "groupArrayInsertAt"; }
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeArray>(type);
}
bool allocatesMemoryInArena() const override { return false; }
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override

View File

@ -93,12 +93,15 @@ public:
using ColumnResult = ColumnVectorOrDecimal<ResultT>;
explicit MovingImpl(const DataTypePtr & data_type_, UInt64 window_size_ = std::numeric_limits<UInt64>::max())
: IAggregateFunctionDataHelper<Data, MovingImpl<T, LimitNumElements, Data>>({data_type_}, {})
: IAggregateFunctionDataHelper<Data, MovingImpl<T, LimitNumElements, Data>>({data_type_}, {}, createResultType(data_type_))
, window_size(window_size_) {}
String getName() const override { return Data::name; }
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeArray>(getReturnTypeElement()); }
static DataTypePtr createResultType(const DataTypePtr & argument)
{
return std::make_shared<DataTypeArray>(getReturnTypeElement(argument));
}
void NO_SANITIZE_UNDEFINED add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{
@ -183,14 +186,14 @@ public:
}
private:
auto getReturnTypeElement() const
static auto getReturnTypeElement(const DataTypePtr & argument)
{
if constexpr (!is_decimal<ResultT>)
return std::make_shared<DataTypeNumber<ResultT>>();
else
{
using Res = DataTypeDecimal<ResultT>;
return std::make_shared<Res>(Res::maxPrecision(), getDecimalScale(*this->argument_types.at(0)));
return std::make_shared<Res>(Res::maxPrecision(), getDecimalScale(*argument));
}
}
};

View File

@ -74,7 +74,7 @@ namespace
/// groupBitmap needs to know about the data type that was used to create bitmaps.
/// We need to look inside the type of its argument to obtain it.
const DataTypeAggregateFunction & datatype_aggfunc = dynamic_cast<const DataTypeAggregateFunction &>(*argument_type_ptr);
AggregateFunctionPtr aggfunc = datatype_aggfunc.getFunction();
ConstAggregateFunctionPtr aggfunc = datatype_aggfunc.getFunction();
if (aggfunc->getName() != AggregateFunctionGroupBitmapData<UInt8>::name())
throw Exception(

View File

@ -19,13 +19,13 @@ class AggregateFunctionBitmap final : public IAggregateFunctionDataHelper<Data,
{
public:
explicit AggregateFunctionBitmap(const DataTypePtr & type)
: IAggregateFunctionDataHelper<Data, AggregateFunctionBitmap<T, Data>>({type}, {})
: IAggregateFunctionDataHelper<Data, AggregateFunctionBitmap<T, Data>>({type}, {}, createResultType())
{
}
String getName() const override { return Data::name(); }
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeNumber<T>>(); }
static DataTypePtr createResultType() { return std::make_shared<DataTypeNumber<T>>(); }
bool allocatesMemoryInArena() const override { return false; }
@ -59,13 +59,13 @@ private:
static constexpr size_t STATE_VERSION_1_MIN_REVISION = 54455;
public:
explicit AggregateFunctionBitmapL2(const DataTypePtr & type)
: IAggregateFunctionDataHelper<Data, AggregateFunctionBitmapL2<T, Data, Policy>>({type}, {})
: IAggregateFunctionDataHelper<Data, AggregateFunctionBitmapL2<T, Data, Policy>>({type}, {}, createResultType())
{
}
String getName() const override { return Policy::name; }
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeNumber<T>>(); }
static DataTypePtr createResultType() { return std::make_shared<DataTypeNumber<T>>(); }
bool allocatesMemoryInArena() const override { return false; }

View File

@ -26,8 +26,8 @@ class AggregateFunctionGroupUniqArrayDate : public AggregateFunctionGroupUniqArr
{
public:
explicit AggregateFunctionGroupUniqArrayDate(const DataTypePtr & argument_type, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
: AggregateFunctionGroupUniqArray<DataTypeDate::FieldType, HasLimit>(argument_type, parameters_, max_elems_) {}
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeArray>(std::make_shared<DataTypeDate>()); }
: AggregateFunctionGroupUniqArray<DataTypeDate::FieldType, HasLimit>(argument_type, parameters_, createResultType(), max_elems_) {}
static DataTypePtr createResultType() { return std::make_shared<DataTypeArray>(std::make_shared<DataTypeDate>()); }
};
template <typename HasLimit>
@ -35,8 +35,8 @@ class AggregateFunctionGroupUniqArrayDateTime : public AggregateFunctionGroupUni
{
public:
explicit AggregateFunctionGroupUniqArrayDateTime(const DataTypePtr & argument_type, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
: AggregateFunctionGroupUniqArray<DataTypeDateTime::FieldType, HasLimit>(argument_type, parameters_, max_elems_) {}
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeArray>(std::make_shared<DataTypeDateTime>()); }
: AggregateFunctionGroupUniqArray<DataTypeDateTime::FieldType, HasLimit>(argument_type, parameters_, createResultType(), max_elems_) {}
static DataTypePtr createResultType() { return std::make_shared<DataTypeArray>(std::make_shared<DataTypeDateTime>()); }
};
template <typename HasLimit, typename ... TArgs>

View File

@ -50,15 +50,16 @@ private:
public:
AggregateFunctionGroupUniqArray(const DataTypePtr & argument_type, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
: IAggregateFunctionDataHelper<AggregateFunctionGroupUniqArrayData<T>,
AggregateFunctionGroupUniqArray<T, LimitNumElems>>({argument_type}, parameters_),
AggregateFunctionGroupUniqArray<T, LimitNumElems>>({argument_type}, parameters_, std::make_shared<DataTypeArray>(argument_type)),
max_elems(max_elems_) {}
String getName() const override { return "groupUniqArray"; }
AggregateFunctionGroupUniqArray(const DataTypePtr & argument_type, const Array & parameters_, const DataTypePtr & result_type_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
: IAggregateFunctionDataHelper<AggregateFunctionGroupUniqArrayData<T>,
AggregateFunctionGroupUniqArray<T, LimitNumElems>>({argument_type}, parameters_, result_type_),
max_elems(max_elems_) {}
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeArray>(this->argument_types[0]);
}
String getName() const override { return "groupUniqArray"; }
bool allocatesMemoryInArena() const override { return false; }
@ -153,17 +154,12 @@ class AggregateFunctionGroupUniqArrayGeneric
public:
AggregateFunctionGroupUniqArrayGeneric(const DataTypePtr & input_data_type_, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
: IAggregateFunctionDataHelper<AggregateFunctionGroupUniqArrayGenericData, AggregateFunctionGroupUniqArrayGeneric<is_plain_column, LimitNumElems>>({input_data_type_}, parameters_)
: IAggregateFunctionDataHelper<AggregateFunctionGroupUniqArrayGenericData, AggregateFunctionGroupUniqArrayGeneric<is_plain_column, LimitNumElems>>({input_data_type_}, parameters_, std::make_shared<DataTypeArray>(input_data_type_))
, input_data_type(this->argument_types[0])
, max_elems(max_elems_) {}
String getName() const override { return "groupUniqArray"; }
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeArray>(input_data_type);
}
bool allocatesMemoryInArena() const override
{
return true;

View File

@ -307,7 +307,7 @@ private:
public:
AggregateFunctionHistogram(UInt32 max_bins_, const DataTypes & arguments, const Array & params)
: IAggregateFunctionDataHelper<AggregateFunctionHistogramData, AggregateFunctionHistogram<T>>(arguments, params)
: IAggregateFunctionDataHelper<AggregateFunctionHistogramData, AggregateFunctionHistogram<T>>(arguments, params, createResultType())
, max_bins(max_bins_)
{
}
@ -316,7 +316,7 @@ public:
{
return Data::structSize(max_bins);
}
DataTypePtr getReturnType() const override
static DataTypePtr createResultType()
{
DataTypes types;
auto mean = std::make_shared<DataTypeNumber<Data::Mean>>();

View File

@ -448,7 +448,7 @@ AggregateFunctionPtr AggregateFunctionIf::getOwnNullAdapter(
/// Nullability of the last argument (condition) does not affect the nullability of the result (NULL is processed as false).
/// For other arguments it is as usual (at least one is NULL then the result is NULL if possible).
bool return_type_is_nullable = !properties.returns_default_when_only_null && getReturnType()->canBeInsideNullable()
bool return_type_is_nullable = !properties.returns_default_when_only_null && getResultType()->canBeInsideNullable()
&& std::any_of(arguments.begin(), arguments.end() - 1, [](const auto & element) { return element->isNullable(); });
bool need_to_serialize_flag = return_type_is_nullable || properties.returns_default_when_only_null;

View File

@ -36,7 +36,7 @@ private:
public:
AggregateFunctionIf(AggregateFunctionPtr nested, const DataTypes & types, const Array & params_)
: IAggregateFunctionHelper<AggregateFunctionIf>(types, params_)
: IAggregateFunctionHelper<AggregateFunctionIf>(types, params_, nested->getResultType())
, nested_func(nested), num_arguments(types.size())
{
if (num_arguments == 0)
@ -51,11 +51,6 @@ public:
return nested_func->getName() + "If";
}
DataTypePtr getReturnType() const override
{
return nested_func->getReturnType();
}
const IAggregateFunction & getBaseAggregateFunctionWithSameStateRepresentation() const override
{
return nested_func->getBaseAggregateFunctionWithSameStateRepresentation();

View File

@ -177,11 +177,11 @@ public:
String getName() const override { return "intervalLengthSum"; }
explicit AggregateFunctionIntervalLengthSum(const DataTypes & arguments)
: IAggregateFunctionDataHelper<Data, AggregateFunctionIntervalLengthSum<T, Data>>(arguments, {})
: IAggregateFunctionDataHelper<Data, AggregateFunctionIntervalLengthSum<T, Data>>(arguments, {}, createResultType())
{
}
DataTypePtr getReturnType() const override
static DataTypePtr createResultType()
{
if constexpr (std::is_floating_point_v<T>)
return std::make_shared<DataTypeFloat64>();

View File

@ -309,7 +309,7 @@ public:
UInt64 batch_size_,
const DataTypes & arguments_types,
const Array & params)
: IAggregateFunctionDataHelper<Data, AggregateFunctionMLMethod<Data, Name>>(arguments_types, params)
: IAggregateFunctionDataHelper<Data, AggregateFunctionMLMethod<Data, Name>>(arguments_types, params, createResultType())
, param_num(param_num_)
, learning_rate(learning_rate_)
, l2_reg_coef(l2_reg_coef_)
@ -319,8 +319,7 @@ public:
{
}
/// This function is called when SELECT linearRegression(...) is called
DataTypePtr getReturnType() const override
static DataTypePtr createResultType()
{
return std::make_shared<DataTypeArray>(std::make_shared<DataTypeFloat64>());
}

View File

@ -133,7 +133,7 @@ private:
public:
explicit AggregateFunctionMannWhitney(const DataTypes & arguments, const Array & params)
:IAggregateFunctionDataHelper<MannWhitneyData, AggregateFunctionMannWhitney> ({arguments}, {})
:IAggregateFunctionDataHelper<MannWhitneyData, AggregateFunctionMannWhitney> ({arguments}, {}, createResultType())
{
if (params.size() > 2)
throw Exception("Aggregate function " + getName() + " require two parameter or less", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
@ -174,7 +174,7 @@ public:
bool allocatesMemoryInArena() const override { return true; }
DataTypePtr getReturnType() const override
static DataTypePtr createResultType()
{
DataTypes types
{

View File

@ -18,6 +18,7 @@
#include <Functions/FunctionHelpers.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include "DataTypes/Serializations/ISerialization.h"
#include "base/types.h"
#include <Common/Arena.h>
#include "AggregateFunctions/AggregateFunctionFactory.h"
@ -104,26 +105,32 @@ public:
return nested_func->getDefaultVersion();
}
AggregateFunctionMap(AggregateFunctionPtr nested, const DataTypes & types) : Base(types, nested->getParameters()), nested_func(nested)
AggregateFunctionMap(AggregateFunctionPtr nested, const DataTypes & types)
: Base(types, nested->getParameters(), std::make_shared<DataTypeMap>(DataTypes{getKeyType(types, nested), nested->getResultType()}))
, nested_func(nested)
{
if (types.empty())
throw Exception(
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function " + getName() + " requires at least one argument");
if (types.size() > 1)
throw Exception(
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function " + getName() + " requires only one map argument");
const auto * map_type = checkAndGetDataType<DataTypeMap>(types[0].get());
if (!map_type)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Aggregate function " + getName() + " requires map as argument");
key_type = map_type->getKeyType();
key_type = getKeyType(types, nested_func);
}
String getName() const override { return nested_func->getName() + "Map"; }
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeMap>(DataTypes{key_type, nested_func->getReturnType()}); }
static DataTypePtr getKeyType(const DataTypes & types, const AggregateFunctionPtr & nested)
{
if (types.empty())
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Aggregate function {}Map requires at least one argument", nested->getName());
if (types.size() > 1)
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Aggregate function {}Map requires only one map argument", nested->getName());
const auto * map_type = checkAndGetDataType<DataTypeMap>(types[0].get());
if (!map_type)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Aggregate function {}Map requires map as argument", nested->getName());
return map_type->getKeyType();
}
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{

View File

@ -62,7 +62,8 @@ private:
public:
AggregateFunctionIntersectionsMax(AggregateFunctionIntersectionsKind kind_, const DataTypes & arguments)
: IAggregateFunctionDataHelper<MaxIntersectionsData<PointType>, AggregateFunctionIntersectionsMax<PointType>>(arguments, {}), kind(kind_)
: IAggregateFunctionDataHelper<MaxIntersectionsData<PointType>, AggregateFunctionIntersectionsMax<PointType>>(arguments, {}, createResultType(kind_))
, kind(kind_)
{
if (!isNativeNumber(arguments[0]))
throw Exception{getName() + ": first argument must be represented by integer", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
@ -81,9 +82,9 @@ public:
: "maxIntersectionsPosition";
}
DataTypePtr getReturnType() const override
static DataTypePtr createResultType(AggregateFunctionIntersectionsKind kind_)
{
if (kind == AggregateFunctionIntersectionsKind::Count)
if (kind_ == AggregateFunctionIntersectionsKind::Count)
return std::make_shared<DataTypeUInt64>();
else
return std::make_shared<DataTypeNumber<PointType>>();

View File

@ -36,7 +36,7 @@ private:
public:
AggregateFunctionMeanZTest(const DataTypes & arguments, const Array & params)
: IAggregateFunctionDataHelper<Data, AggregateFunctionMeanZTest<Data>>({arguments}, params)
: IAggregateFunctionDataHelper<Data, AggregateFunctionMeanZTest<Data>>({arguments}, params, createResultType())
{
pop_var_x = params.at(0).safeGet<Float64>();
pop_var_y = params.at(1).safeGet<Float64>();
@ -63,7 +63,7 @@ public:
return Data::name;
}
DataTypePtr getReturnType() const override
static DataTypePtr createResultType()
{
DataTypes types
{

View File

@ -30,7 +30,7 @@ private:
public:
AggregateFunctionMerge(const AggregateFunctionPtr & nested_, const DataTypePtr & argument, const Array & params_)
: IAggregateFunctionHelper<AggregateFunctionMerge>({argument}, params_)
: IAggregateFunctionHelper<AggregateFunctionMerge>({argument}, params_, createResultType())
, nested_func(nested_)
{
const DataTypeAggregateFunction * data_type = typeid_cast<const DataTypeAggregateFunction *>(argument.get());
@ -45,9 +45,9 @@ public:
return nested_func->getName() + "Merge";
}
DataTypePtr getReturnType() const override
DataTypePtr createResultType() const
{
return nested_func->getReturnType();
return nested_func->getResultType();
}
const IAggregateFunction & getBaseAggregateFunctionWithSameStateRepresentation() const override

View File

@ -1219,7 +1219,7 @@ private:
public:
explicit AggregateFunctionsSingleValue(const DataTypePtr & type)
: IAggregateFunctionDataHelper<Data, AggregateFunctionsSingleValue<Data>>({type}, {})
: IAggregateFunctionDataHelper<Data, AggregateFunctionsSingleValue<Data>>({type}, {}, createResultType())
, serialization(type->getDefaultSerialization())
{
if (StringRef(Data::name()) == StringRef("min")
@ -1233,7 +1233,7 @@ public:
String getName() const override { return Data::name(); }
DataTypePtr getReturnType() const override
DataTypePtr createResultType() const
{
auto result_type = this->argument_types.at(0);
if constexpr (Data::is_nullable)

View File

@ -6,6 +6,7 @@
#include <AggregateFunctions/IAggregateFunction.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include "DataTypes/IDataType.h"
namespace DB
@ -19,16 +20,16 @@ class AggregateFunctionNothing final : public IAggregateFunctionHelper<Aggregate
{
public:
AggregateFunctionNothing(const DataTypes & arguments, const Array & params)
: IAggregateFunctionHelper<AggregateFunctionNothing>(arguments, params) {}
: IAggregateFunctionHelper<AggregateFunctionNothing>(arguments, params, createResultType(arguments)) {}
String getName() const override
{
return "nothing";
}
DataTypePtr getReturnType() const override
static DataTypePtr createResultType(const DataTypes & arguments)
{
return argument_types.empty() ? std::make_shared<DataTypeNullable>(std::make_shared<DataTypeNothing>()) : argument_types.front();
return arguments.empty() ? std::make_shared<DataTypeNullable>(std::make_shared<DataTypeNothing>()) : arguments.front();
}
bool allocatesMemoryInArena() const override { return false; }

View File

@ -87,7 +87,7 @@ public:
transformed_nested_function->getParameters());
}
bool return_type_is_nullable = !properties.returns_default_when_only_null && nested_function->getReturnType()->canBeInsideNullable();
bool return_type_is_nullable = !properties.returns_default_when_only_null && nested_function->getResultType()->canBeInsideNullable();
bool serialize_flag = return_type_is_nullable || properties.returns_default_when_only_null;
if (arguments.size() == 1)

View File

@ -82,7 +82,8 @@ protected:
public:
AggregateFunctionNullBase(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)
: IAggregateFunctionHelper<Derived>(arguments, params), nested_function{nested_function_}
: IAggregateFunctionHelper<Derived>(arguments, params, createResultType(nested_function_))
, nested_function{nested_function_}
{
if constexpr (result_is_nullable)
prefix_size = nested_function->alignOfData();
@ -96,11 +97,11 @@ public:
return nested_function->getName();
}
DataTypePtr getReturnType() const override
static DataTypePtr createResultType(const AggregateFunctionPtr & nested_function_)
{
return result_is_nullable
? makeNullable(nested_function->getReturnType())
: nested_function->getReturnType();
? makeNullable(nested_function_->getResultType())
: nested_function_->getResultType();
}
void create(AggregateDataPtr __restrict place) const override
@ -270,7 +271,7 @@ public:
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_type = toNativeType(b, this->getReturnType());
auto * return_type = toNativeType(b, this->getResultType());
llvm::Value * result = nullptr;

View File

@ -4,6 +4,7 @@
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnsCommon.h>
#include <Common/typeid_cast.h>
#include "DataTypes/Serializations/ISerialization.h"
#include <DataTypes/DataTypeNullable.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
@ -30,16 +31,14 @@ private:
AggregateFunctionPtr nested_function;
size_t size_of_data;
DataTypePtr inner_type;
bool inner_nullable;
public:
AggregateFunctionOrFill(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)
: IAggregateFunctionHelper<AggregateFunctionOrFill>{arguments, params}
: IAggregateFunctionHelper<AggregateFunctionOrFill>{arguments, params, createResultType(nested_function_->getResultType())}
, nested_function{nested_function_}
, size_of_data {nested_function->sizeOfData()}
, inner_type {nested_function->getReturnType()}
, inner_nullable {inner_type->isNullable()}
, inner_nullable {nested_function->getResultType()->isNullable()}
{
// nothing
}
@ -246,22 +245,22 @@ public:
readChar(place[size_of_data], buf);
}
DataTypePtr getReturnType() const override
static DataTypePtr createResultType(const DataTypePtr & inner_type_)
{
if constexpr (UseNull)
{
// -OrNull
if (inner_nullable)
return inner_type;
if (inner_type_->isNullable())
return inner_type_;
return std::make_shared<DataTypeNullable>(inner_type);
return std::make_shared<DataTypeNullable>(inner_type_);
}
else
{
// -OrDefault
return inner_type;
return inner_type_;
}
}

View File

@ -72,7 +72,7 @@ private:
public:
AggregateFunctionQuantile(const DataTypes & argument_types_, const Array & params)
: IAggregateFunctionDataHelper<Data, AggregateFunctionQuantile<Value, Data, Name, has_second_arg, FloatReturnType, returns_many>>(
argument_types_, params)
argument_types_, params, createResultType(argument_types_))
, levels(params, returns_many)
, level(levels.levels[0])
, argument_type(this->argument_types[0])
@ -83,14 +83,14 @@ public:
String getName() const override { return Name::name; }
DataTypePtr getReturnType() const override
static DataTypePtr createResultType(const DataTypes & argument_types_)
{
DataTypePtr res;
if constexpr (returns_float)
res = std::make_shared<DataTypeNumber<FloatReturnType>>();
else
res = argument_type;
res = argument_types_[0];
if constexpr (returns_many)
return std::make_shared<DataTypeArray>(res);

View File

@ -51,7 +51,7 @@ class AggregateFunctionRankCorrelation :
{
public:
explicit AggregateFunctionRankCorrelation(const DataTypes & arguments)
:IAggregateFunctionDataHelper<RankCorrelationData, AggregateFunctionRankCorrelation> ({arguments}, {})
:IAggregateFunctionDataHelper<RankCorrelationData, AggregateFunctionRankCorrelation> ({arguments}, {}, std::make_shared<DataTypeNumber<Float64>>())
{}
String getName() const override
@ -61,11 +61,6 @@ public:
bool allocatesMemoryInArena() const override { return true; }
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeNumber<Float64>>();
}
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{
Float64 new_x = columns[0]->getFloat64(row_num);

View File

@ -43,7 +43,7 @@ public:
size_t step_,
const DataTypes & arguments,
const Array & params)
: IAggregateFunctionHelper<AggregateFunctionResample<Key>>{arguments, params}
: IAggregateFunctionHelper<AggregateFunctionResample<Key>>{arguments, params, createResultType(nested_function_)}
, nested_function{nested_function_}
, last_col{arguments.size() - 1}
, begin{begin_}
@ -190,9 +190,9 @@ public:
nested_function->deserialize(place + i * size_of_data, buf, version, arena);
}
DataTypePtr getReturnType() const override
static DataTypePtr createResultType(const AggregateFunctionPtr & nested_function_)
{
return std::make_shared<DataTypeArray>(nested_function->getReturnType());
return std::make_shared<DataTypeArray>(nested_function_->getResultType());
}
template <bool merge>

View File

@ -76,7 +76,7 @@ public:
}
explicit AggregateFunctionRetention(const DataTypes & arguments)
: IAggregateFunctionDataHelper<AggregateFunctionRetentionData, AggregateFunctionRetention>(arguments, {})
: IAggregateFunctionDataHelper<AggregateFunctionRetentionData, AggregateFunctionRetention>(arguments, {}, std::make_shared<DataTypeArray>(std::make_shared<DataTypeUInt8>()))
{
for (const auto i : collections::range(0, arguments.size()))
{
@ -90,12 +90,6 @@ public:
events_size = static_cast<UInt8>(arguments.size());
}
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeArray>(std::make_shared<DataTypeUInt8>());
}
bool allocatesMemoryInArena() const override { return false; }
void add(AggregateDataPtr __restrict place, const IColumn ** columns, const size_t row_num, Arena *) const override

View File

@ -8,6 +8,7 @@
#include <base/range.h>
#include <base/sort.h>
#include <Common/PODArray.h>
#include "DataTypes/Serializations/ISerialization.h"
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <bitset>
@ -126,8 +127,8 @@ template <typename T, typename Data, typename Derived>
class AggregateFunctionSequenceBase : public IAggregateFunctionDataHelper<Data, Derived>
{
public:
AggregateFunctionSequenceBase(const DataTypes & arguments, const Array & params, const String & pattern_)
: IAggregateFunctionDataHelper<Data, Derived>(arguments, params)
AggregateFunctionSequenceBase(const DataTypes & arguments, const Array & params, const String & pattern_, const DataTypePtr & result_type_)
: IAggregateFunctionDataHelper<Data, Derived>(arguments, params, result_type_)
, pattern(pattern_)
{
arg_count = arguments.size();
@ -617,14 +618,12 @@ class AggregateFunctionSequenceMatch final : public AggregateFunctionSequenceBas
{
public:
AggregateFunctionSequenceMatch(const DataTypes & arguments, const Array & params, const String & pattern_)
: AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceMatch<T, Data>>(arguments, params, pattern_) {}
: AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceMatch<T, Data>>(arguments, params, pattern_, std::make_shared<DataTypeUInt8>()) {}
using AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceMatch<T, Data>>::AggregateFunctionSequenceBase;
String getName() const override { return "sequenceMatch"; }
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeUInt8>(); }
bool allocatesMemoryInArena() const override { return false; }
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
@ -655,14 +654,12 @@ class AggregateFunctionSequenceCount final : public AggregateFunctionSequenceBas
{
public:
AggregateFunctionSequenceCount(const DataTypes & arguments, const Array & params, const String & pattern_)
: AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceCount<T, Data>>(arguments, params, pattern_) {}
: AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceCount<T, Data>>(arguments, params, pattern_, std::make_shared<DataTypeUInt64>()) {}
using AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceCount<T, Data>>::AggregateFunctionSequenceBase;
String getName() const override { return "sequenceCount"; }
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeUInt64>(); }
bool allocatesMemoryInArena() const override { return false; }
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override

View File

@ -190,7 +190,7 @@ public:
SequenceDirection seq_direction_,
size_t min_required_args_,
UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
: IAggregateFunctionDataHelper<SequenceNextNodeGeneralData<Node>, Self>({data_type_}, parameters_)
: IAggregateFunctionDataHelper<SequenceNextNodeGeneralData<Node>, Self>({data_type_}, parameters_, data_type_)
, seq_base_kind(seq_base_kind_)
, seq_direction(seq_direction_)
, min_required_args(min_required_args_)
@ -202,8 +202,6 @@ public:
String getName() const override { return "sequenceNextNode"; }
DataTypePtr getReturnType() const override { return data_type; }
bool haveSameStateRepresentationImpl(const IAggregateFunction & rhs) const override
{
return this->getName() == rhs.getName() && this->haveEqualArgumentTypes(rhs);

View File

@ -99,7 +99,7 @@ public:
IAggregateFunctionDataHelper<
AggregateFunctionSimpleLinearRegressionData<Ret>,
AggregateFunctionSimpleLinearRegression<X, Y, Ret>
> {arguments, params}
> {arguments, params, createResultType()}
{
// notice: arguments has been checked before
}
@ -140,7 +140,7 @@ public:
this->data(place).deserialize(buf);
}
DataTypePtr getReturnType() const override
static DataTypePtr createResultType()
{
DataTypes types
{

View File

@ -20,28 +20,28 @@ private:
public:
AggregateFunctionSimpleState(AggregateFunctionPtr nested_, const DataTypes & arguments_, const Array & params_)
: IAggregateFunctionHelper<AggregateFunctionSimpleState>(arguments_, params_)
: IAggregateFunctionHelper<AggregateFunctionSimpleState>(arguments_, params_, createResultType(nested_, params_))
, nested_func(nested_)
{
}
String getName() const override { return nested_func->getName() + "SimpleState"; }
DataTypePtr getReturnType() const override
static DataTypePtr createResultType(const AggregateFunctionPtr & nested_, const Array & params_)
{
DataTypeCustomSimpleAggregateFunction::checkSupportedFunctions(nested_func);
DataTypeCustomSimpleAggregateFunction::checkSupportedFunctions(nested_);
// Need to make a clone to avoid recursive reference.
auto storage_type_out = DataTypeFactory::instance().get(nested_func->getReturnType()->getName());
auto storage_type_out = DataTypeFactory::instance().get(nested_->getResultType()->getName());
// Need to make a new function with promoted argument types because SimpleAggregates requires arg_type = return_type.
AggregateFunctionProperties properties;
auto function
= AggregateFunctionFactory::instance().get(nested_func->getName(), {storage_type_out}, nested_func->getParameters(), properties);
= AggregateFunctionFactory::instance().get(nested_->getName(), {storage_type_out}, nested_->getParameters(), properties);
// Need to make a clone because it'll be customized.
auto storage_type_arg = DataTypeFactory::instance().get(nested_func->getReturnType()->getName());
auto storage_type_arg = DataTypeFactory::instance().get(nested_->getResultType()->getName());
DataTypeCustomNamePtr custom_name
= std::make_unique<DataTypeCustomSimpleAggregateFunction>(function, DataTypes{nested_func->getReturnType()}, parameters);
= std::make_unique<DataTypeCustomSimpleAggregateFunction>(function, DataTypes{nested_->getResultType()}, params_);
storage_type_arg->setCustomization(std::make_unique<DataTypeCustomDesc>(std::move(custom_name), nullptr));
return storage_type_arg;
}

View File

@ -261,7 +261,7 @@ private:
public:
AggregateFunctionSparkbar(const DataTypes & arguments, const Array & params)
: IAggregateFunctionDataHelper<AggregateFunctionSparkbarData<X, Y>, AggregateFunctionSparkbar>(
arguments, params)
arguments, params, std::make_shared<DataTypeString>())
{
width = params.at(0).safeGet<UInt64>();
if (params.size() == 3)
@ -283,11 +283,6 @@ public:
return "sparkbar";
}
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeString>();
}
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * /*arena*/) const override
{
X x = assert_cast<const ColumnVector<X> *>(columns[0])->getData()[row_num];

View File

@ -23,7 +23,7 @@ private:
public:
AggregateFunctionState(AggregateFunctionPtr nested_, const DataTypes & arguments_, const Array & params_)
: IAggregateFunctionHelper<AggregateFunctionState>(arguments_, params_)
: IAggregateFunctionHelper<AggregateFunctionState>(arguments_, params_, createResultType())
, nested_func(nested_)
{}
@ -32,7 +32,7 @@ public:
return nested_func->getName() + "State";
}
DataTypePtr getReturnType() const override
DataTypePtr createResultType() const
{
return getStateType();
}

View File

@ -115,15 +115,11 @@ class AggregateFunctionVariance final
{
public:
explicit AggregateFunctionVariance(const DataTypePtr & arg)
: IAggregateFunctionDataHelper<AggregateFunctionVarianceData<T, Op>, AggregateFunctionVariance<T, Op>>({arg}, {}) {}
: IAggregateFunctionDataHelper<AggregateFunctionVarianceData<T, Op>, AggregateFunctionVariance<T, Op>>({arg}, {}, std::make_shared<DataTypeFloat64>())
{}
String getName() const override { return Op::name; }
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeFloat64>();
}
bool allocatesMemoryInArena() const override { return false; }
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
@ -368,15 +364,11 @@ class AggregateFunctionCovariance final
public:
explicit AggregateFunctionCovariance(const DataTypes & args) : IAggregateFunctionDataHelper<
CovarianceData<T, U, Op, compute_marginal_moments>,
AggregateFunctionCovariance<T, U, Op, compute_marginal_moments>>(args, {}) {}
AggregateFunctionCovariance<T, U, Op, compute_marginal_moments>>(args, {}, std::make_shared<DataTypeFloat64>())
{}
String getName() const override { return Op::name; }
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeFloat64>();
}
bool allocatesMemoryInArena() const override { return false; }
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override

View File

@ -81,12 +81,12 @@ public:
using ColVecResult = ColumnVector<ResultType>;
explicit AggregateFunctionVarianceSimple(const DataTypes & argument_types_)
: IAggregateFunctionDataHelper<typename StatFunc::Data, AggregateFunctionVarianceSimple<StatFunc>>(argument_types_, {})
: IAggregateFunctionDataHelper<typename StatFunc::Data, AggregateFunctionVarianceSimple<StatFunc>>(argument_types_, {}, std::make_shared<DataTypeNumber<ResultType>>())
, src_scale(0)
{}
AggregateFunctionVarianceSimple(const IDataType & data_type, const DataTypes & argument_types_)
: IAggregateFunctionDataHelper<typename StatFunc::Data, AggregateFunctionVarianceSimple<StatFunc>>(argument_types_, {})
: IAggregateFunctionDataHelper<typename StatFunc::Data, AggregateFunctionVarianceSimple<StatFunc>>(argument_types_, {}, std::make_shared<DataTypeNumber<ResultType>>())
, src_scale(getDecimalScale(data_type))
{}
@ -117,11 +117,6 @@ public:
UNREACHABLE();
}
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeNumber<ResultType>>();
}
bool allocatesMemoryInArena() const override { return false; }
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override

View File

@ -411,23 +411,21 @@ public:
}
explicit AggregateFunctionSum(const DataTypes & argument_types_)
: IAggregateFunctionDataHelper<Data, AggregateFunctionSum<T, TResult, Data, Type>>(argument_types_, {})
, scale(0)
: IAggregateFunctionDataHelper<Data, AggregateFunctionSum<T, TResult, Data, Type>>(argument_types_, {}, createResultType(0))
{}
AggregateFunctionSum(const IDataType & data_type, const DataTypes & argument_types_)
: IAggregateFunctionDataHelper<Data, AggregateFunctionSum<T, TResult, Data, Type>>(argument_types_, {})
, scale(getDecimalScale(data_type))
: IAggregateFunctionDataHelper<Data, AggregateFunctionSum<T, TResult, Data, Type>>(argument_types_, {}, createResultType(getDecimalScale(data_type)))
{}
DataTypePtr getReturnType() const override
static DataTypePtr createResultType(UInt32 scale_)
{
if constexpr (!is_decimal<T>)
return std::make_shared<DataTypeNumber<TResult>>();
else
{
using DataType = DataTypeDecimal<TResult>;
return std::make_shared<DataType>(DataType::maxPrecision(), scale);
return std::make_shared<DataType>(DataType::maxPrecision(), scale_);
}
}
@ -548,7 +546,7 @@ public:
for (const auto & argument_type : this->argument_types)
can_be_compiled &= canBeNativeType(*argument_type);
auto return_type = getReturnType();
auto return_type = this->getResultType();
can_be_compiled &= canBeNativeType(*return_type);
return can_be_compiled;
@ -558,7 +556,7 @@ public:
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_type = toNativeType(b, getReturnType());
auto * return_type = toNativeType(b, this->getResultType());
auto * aggregate_sum_ptr = aggregate_data_ptr;
b.CreateStore(llvm::Constant::getNullValue(return_type), aggregate_sum_ptr);
@ -568,7 +566,7 @@ public:
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_type = toNativeType(b, getReturnType());
auto * return_type = toNativeType(b, this->getResultType());
auto * sum_value_ptr = aggregate_data_ptr;
auto * sum_value = b.CreateLoad(return_type, sum_value_ptr);
@ -586,7 +584,7 @@ public:
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_type = toNativeType(b, getReturnType());
auto * return_type = toNativeType(b, this->getResultType());
auto * sum_value_dst_ptr = aggregate_data_dst_ptr;
auto * sum_value_dst = b.CreateLoad(return_type, sum_value_dst_ptr);
@ -602,7 +600,7 @@ public:
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_type = toNativeType(b, getReturnType());
auto * return_type = toNativeType(b, this->getResultType());
auto * sum_value_ptr = aggregate_data_ptr;
return b.CreateLoad(return_type, sum_value_ptr);
@ -611,8 +609,6 @@ public:
#endif
private:
UInt32 scale;
static constexpr auto & castColumnToResult(IColumn & to)
{
if constexpr (is_decimal<T>)

View File

@ -14,12 +14,13 @@ public:
using Base = AggregateFunctionAvg<T>;
explicit AggregateFunctionSumCount(const DataTypes & argument_types_, UInt32 num_scale_ = 0)
: Base(argument_types_, num_scale_), scale(num_scale_) {}
: Base(argument_types_, createResultType(num_scale_), num_scale_)
{}
DataTypePtr getReturnType() const override
static DataTypePtr createResultType(UInt32 num_scale_)
{
auto second_elem = std::make_shared<DataTypeUInt64>();
return std::make_shared<DataTypeTuple>(DataTypes{getReturnTypeFirstElement(), std::move(second_elem)});
return std::make_shared<DataTypeTuple>(DataTypes{getReturnTypeFirstElement(num_scale_), std::move(second_elem)});
}
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const final
@ -43,9 +44,7 @@ public:
#endif
private:
UInt32 scale;
auto getReturnTypeFirstElement() const
static auto getReturnTypeFirstElement(UInt32 num_scale_)
{
using FieldType = AvgFieldType<T>;
@ -54,7 +53,7 @@ private:
else
{
using DataType = DataTypeDecimal<FieldType>;
return std::make_shared<DataType>(DataType::maxPrecision(), scale);
return std::make_shared<DataType>(DataType::maxPrecision(), num_scale_);
}
}
};

View File

@ -80,7 +80,7 @@ public:
AggregateFunctionMapBase(const DataTypePtr & keys_type_,
const DataTypes & values_types_, const DataTypes & argument_types_)
: Base(argument_types_, {} /* parameters */)
: Base(argument_types_, {} /* parameters */, createResultType(keys_type_, values_types_, getName()))
, keys_type(keys_type_)
, keys_serialization(keys_type->getDefaultSerialization())
, values_types(values_types_)
@ -117,19 +117,22 @@ public:
return 0;
}
DataTypePtr getReturnType() const override
static DataTypePtr createResultType(
const DataTypePtr & keys_type_,
const DataTypes & values_types_,
const String & name_)
{
DataTypes types;
types.emplace_back(std::make_shared<DataTypeArray>(keys_type));
types.emplace_back(std::make_shared<DataTypeArray>(keys_type_));
for (const auto & value_type : values_types)
for (const auto & value_type : values_types_)
{
if constexpr (std::is_same_v<Visitor, FieldVisitorSum>)
{
if (!value_type->isSummable())
throw Exception{ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Values for {} cannot be summed, passed type {}",
getName(), value_type->getName()};
name_, value_type->getName()};
}
DataTypePtr result_type;
@ -139,7 +142,7 @@ public:
if (value_type->onlyNull())
throw Exception{ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Cannot calculate {} of type {}",
getName(), value_type->getName()};
name_, value_type->getName()};
// Overflow, meaning that the returned type is the same as
// the input type. Nulls are skipped.
@ -153,7 +156,7 @@ public:
if (!value_type_without_nullable->canBePromoted())
throw Exception{ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Values for {} are expected to be Numeric, Float or Decimal, passed type {}",
getName(), value_type->getName()};
name_, value_type->getName()};
WhichDataType value_type_to_check(value_type_without_nullable);

View File

@ -46,7 +46,7 @@ private:
Float64 confidence_level;
public:
AggregateFunctionTTest(const DataTypes & arguments, const Array & params)
: IAggregateFunctionDataHelper<Data, AggregateFunctionTTest<Data>>({arguments}, params)
: IAggregateFunctionDataHelper<Data, AggregateFunctionTTest<Data>>({arguments}, params, createResultType(!params.empty()))
{
if (!params.empty())
{
@ -71,9 +71,9 @@ public:
return Data::name;
}
DataTypePtr getReturnType() const override
static DataTypePtr createResultType(bool need_confidence_interval_)
{
if (need_confidence_interval)
if (need_confidence_interval_)
{
DataTypes types
{

View File

@ -31,15 +31,33 @@ namespace
template <bool is_weighted>
class AggregateFunctionTopKDate : public AggregateFunctionTopK<DataTypeDate::FieldType, is_weighted>
{
public:
using AggregateFunctionTopK<DataTypeDate::FieldType, is_weighted>::AggregateFunctionTopK;
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeArray>(std::make_shared<DataTypeDate>()); }
AggregateFunctionTopKDate(UInt64 threshold_, UInt64 load_factor, const DataTypes & argument_types_, const Array & params)
: AggregateFunctionTopK<DataTypeDate::FieldType, is_weighted>(
threshold_,
load_factor,
argument_types_,
params,
std::make_shared<DataTypeArray>(std::make_shared<DataTypeDate>()))
{}
};
template <bool is_weighted>
class AggregateFunctionTopKDateTime : public AggregateFunctionTopK<DataTypeDateTime::FieldType, is_weighted>
{
public:
using AggregateFunctionTopK<DataTypeDateTime::FieldType, is_weighted>::AggregateFunctionTopK;
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeArray>(std::make_shared<DataTypeDateTime>()); }
AggregateFunctionTopKDateTime(UInt64 threshold_, UInt64 load_factor, const DataTypes & argument_types_, const Array & params)
: AggregateFunctionTopK<DataTypeDateTime::FieldType, is_weighted>(
threshold_,
load_factor,
argument_types_,
params,
std::make_shared<DataTypeArray>(std::make_shared<DataTypeDateTime>()))
{}
};

View File

@ -11,6 +11,7 @@
#include <Common/SpaceSaving.h>
#include <Common/assert_cast.h>
#include "DataTypes/Serializations/ISerialization.h"
#include <AggregateFunctions/IAggregateFunction.h>
@ -40,14 +41,20 @@ protected:
public:
AggregateFunctionTopK(UInt64 threshold_, UInt64 load_factor, const DataTypes & argument_types_, const Array & params)
: IAggregateFunctionDataHelper<AggregateFunctionTopKData<T>, AggregateFunctionTopK<T, is_weighted>>(argument_types_, params)
, threshold(threshold_), reserved(load_factor * threshold) {}
: IAggregateFunctionDataHelper<AggregateFunctionTopKData<T>, AggregateFunctionTopK<T, is_weighted>>(argument_types_, params, createResultType(argument_types_))
, threshold(threshold_), reserved(load_factor * threshold)
{}
AggregateFunctionTopK(UInt64 threshold_, UInt64 load_factor, const DataTypes & argument_types_, const Array & params, const DataTypePtr & result_type_)
: IAggregateFunctionDataHelper<AggregateFunctionTopKData<T>, AggregateFunctionTopK<T, is_weighted>>(argument_types_, params, result_type_)
, threshold(threshold_), reserved(load_factor * threshold)
{}
String getName() const override { return is_weighted ? "topKWeighted" : "topK"; }
DataTypePtr getReturnType() const override
static DataTypePtr createResultType(const DataTypes & argument_types_)
{
return std::make_shared<DataTypeArray>(this->argument_types[0]);
return std::make_shared<DataTypeArray>(argument_types_[0]);
}
bool allocatesMemoryInArena() const override { return false; }
@ -126,21 +133,20 @@ private:
UInt64 threshold;
UInt64 reserved;
DataTypePtr & input_data_type;
static void deserializeAndInsert(StringRef str, IColumn & data_to);
public:
AggregateFunctionTopKGeneric(
UInt64 threshold_, UInt64 load_factor, const DataTypes & argument_types_, const Array & params)
: IAggregateFunctionDataHelper<AggregateFunctionTopKGenericData, AggregateFunctionTopKGeneric<is_plain_column, is_weighted>>(argument_types_, params)
, threshold(threshold_), reserved(load_factor * threshold), input_data_type(this->argument_types[0]) {}
: IAggregateFunctionDataHelper<AggregateFunctionTopKGenericData, AggregateFunctionTopKGeneric<is_plain_column, is_weighted>>(argument_types_, params, createResultType(argument_types_))
, threshold(threshold_), reserved(load_factor * threshold) {}
String getName() const override { return is_weighted ? "topKWeighted" : "topK"; }
DataTypePtr getReturnType() const override
static DataTypePtr createResultType(const DataTypes & argument_types_)
{
return std::make_shared<DataTypeArray>(input_data_type);
return std::make_shared<DataTypeArray>(argument_types_[0]);
}
bool allocatesMemoryInArena() const override

View File

@ -358,17 +358,12 @@ private:
public:
explicit AggregateFunctionUniq(const DataTypes & argument_types_)
: IAggregateFunctionDataHelper<Data, AggregateFunctionUniq<T, Data>>(argument_types_, {})
: IAggregateFunctionDataHelper<Data, AggregateFunctionUniq<T, Data>>(argument_types_, {}, std::make_shared<DataTypeUInt64>())
{
}
String getName() const override { return Data::getName(); }
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeUInt64>();
}
bool allocatesMemoryInArena() const override { return false; }
/// ALWAYS_INLINE is required to have better code layout for uniqHLL12 function
@ -462,7 +457,7 @@ private:
public:
explicit AggregateFunctionUniqVariadic(const DataTypes & arguments)
: IAggregateFunctionDataHelper<Data, AggregateFunctionUniqVariadic<Data>>(arguments, {})
: IAggregateFunctionDataHelper<Data, AggregateFunctionUniqVariadic<Data>>(arguments, {}, std::make_shared<DataTypeUInt64>())
{
if (argument_is_tuple)
num_args = typeid_cast<const DataTypeTuple &>(*arguments[0]).getElements().size();
@ -472,11 +467,6 @@ public:
String getName() const override { return Data::getName(); }
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeUInt64>();
}
bool allocatesMemoryInArena() const override { return false; }
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override

View File

@ -126,7 +126,8 @@ class AggregateFunctionUniqCombined final
{
public:
AggregateFunctionUniqCombined(const DataTypes & argument_types_, const Array & params_)
: IAggregateFunctionDataHelper<AggregateFunctionUniqCombinedData<T, K, HashValueType>, AggregateFunctionUniqCombined<T, K, HashValueType>>(argument_types_, params_) {}
: IAggregateFunctionDataHelper<AggregateFunctionUniqCombinedData<T, K, HashValueType>, AggregateFunctionUniqCombined<T, K, HashValueType>>(argument_types_, params_, std::make_shared<DataTypeUInt64>())
{}
String getName() const override
{
@ -136,11 +137,6 @@ public:
return "uniqCombined";
}
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeUInt64>();
}
bool allocatesMemoryInArena() const override { return false; }
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
@ -192,7 +188,7 @@ private:
public:
explicit AggregateFunctionUniqCombinedVariadic(const DataTypes & arguments, const Array & params)
: IAggregateFunctionDataHelper<AggregateFunctionUniqCombinedData<UInt64, K, HashValueType>,
AggregateFunctionUniqCombinedVariadic<is_exact, argument_is_tuple, K, HashValueType>>(arguments, params)
AggregateFunctionUniqCombinedVariadic<is_exact, argument_is_tuple, K, HashValueType>>(arguments, params, std::make_shared<DataTypeUInt64>())
{
if (argument_is_tuple)
num_args = typeid_cast<const DataTypeTuple &>(*arguments[0]).getElements().size();
@ -208,11 +204,6 @@ public:
return "uniqCombined";
}
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeUInt64>();
}
bool allocatesMemoryInArena() const override { return false; }
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override

View File

@ -174,7 +174,7 @@ private:
public:
AggregateFunctionUniqUpTo(UInt8 threshold_, const DataTypes & argument_types_, const Array & params_)
: IAggregateFunctionDataHelper<AggregateFunctionUniqUpToData<T>, AggregateFunctionUniqUpTo<T>>(argument_types_, params_)
: IAggregateFunctionDataHelper<AggregateFunctionUniqUpToData<T>, AggregateFunctionUniqUpTo<T>>(argument_types_, params_, std::make_shared<DataTypeUInt64>())
, threshold(threshold_)
{
}
@ -186,11 +186,6 @@ public:
String getName() const override { return "uniqUpTo"; }
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeUInt64>();
}
bool allocatesMemoryInArena() const override { return false; }
/// ALWAYS_INLINE is required to have better code layout for uniqUpTo function
@ -235,7 +230,7 @@ private:
public:
AggregateFunctionUniqUpToVariadic(const DataTypes & arguments, const Array & params, UInt8 threshold_)
: IAggregateFunctionDataHelper<AggregateFunctionUniqUpToData<UInt64>, AggregateFunctionUniqUpToVariadic<is_exact, argument_is_tuple>>(arguments, params)
: IAggregateFunctionDataHelper<AggregateFunctionUniqUpToData<UInt64>, AggregateFunctionUniqUpToVariadic<is_exact, argument_is_tuple>>(arguments, params, std::make_shared<DataTypeUInt64>())
, threshold(threshold_)
{
if (argument_is_tuple)
@ -251,11 +246,6 @@ public:
String getName() const override { return "uniqUpTo"; }
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeUInt64>();
}
bool allocatesMemoryInArena() const override { return false; }
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override

View File

@ -221,7 +221,7 @@ public:
}
AggregateFunctionWindowFunnel(const DataTypes & arguments, const Array & params)
: IAggregateFunctionDataHelper<Data, AggregateFunctionWindowFunnel<T, Data>>(arguments, params)
: IAggregateFunctionDataHelper<Data, AggregateFunctionWindowFunnel<T, Data>>(arguments, params, std::make_shared<DataTypeUInt8>())
{
events_size = arguments.size() - 1;
window = params.at(0).safeGet<UInt64>();
@ -245,11 +245,6 @@ public:
}
}
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeUInt8>();
}
bool allocatesMemoryInArena() const override { return false; }
void add(AggregateDataPtr __restrict place, const IColumn ** columns, const size_t row_num, Arena *) const override

View File

@ -118,7 +118,7 @@ class AggregateFunctionCrossTab : public IAggregateFunctionDataHelper<Data, Aggr
{
public:
explicit AggregateFunctionCrossTab(const DataTypes & arguments)
: IAggregateFunctionDataHelper<Data, AggregateFunctionCrossTab<Data>>({arguments}, {})
: IAggregateFunctionDataHelper<Data, AggregateFunctionCrossTab<Data>>({arguments}, {}, createResultType())
{
}
@ -132,7 +132,7 @@ public:
return false;
}
DataTypePtr getReturnType() const override
static DataTypePtr createResultType()
{
return std::make_shared<DataTypeNumber<Float64>>();
}

View File

@ -10,6 +10,7 @@
#include <base/types.h>
#include <Common/Exception.h>
#include <Common/ThreadPool.h>
#include <Core/IResolvedFunction.h>
#include "config.h"
@ -48,7 +49,9 @@ using AggregateDataPtr = char *;
using ConstAggregateDataPtr = const char *;
class IAggregateFunction;
using AggregateFunctionPtr = std::shared_ptr<const IAggregateFunction>;
using AggregateFunctionPtr = std::shared_ptr<IAggregateFunction>;
using ConstAggregateFunctionPtr = std::shared_ptr<const IAggregateFunction>;
struct AggregateFunctionProperties;
/** Aggregate functions interface.
@ -59,18 +62,18 @@ struct AggregateFunctionProperties;
* (which can be created in some memory pool),
* and IAggregateFunction is the external interface for manipulating them.
*/
class IAggregateFunction : public std::enable_shared_from_this<IAggregateFunction>
class IAggregateFunction : public std::enable_shared_from_this<IAggregateFunction>, public IResolvedFunction
{
public:
IAggregateFunction(const DataTypes & argument_types_, const Array & parameters_)
: argument_types(argument_types_), parameters(parameters_) {}
IAggregateFunction(const DataTypes & argument_types_, const Array & parameters_, const DataTypePtr & result_type_)
: result_type(result_type_)
, argument_types(argument_types_)
, parameters(parameters_)
{}
/// Get main function name.
virtual String getName() const = 0;
/// Get the result type.
virtual DataTypePtr getReturnType() const = 0;
/// Get the data type of internal state. By default it is AggregateFunction(name(params), argument_types...).
virtual DataTypePtr getStateType() const;
@ -102,7 +105,7 @@ public:
virtual size_t getDefaultVersion() const { return 0; }
virtual ~IAggregateFunction() = default;
~IAggregateFunction() override = default;
/** Data manipulating functions. */
@ -348,8 +351,9 @@ public:
*/
virtual AggregateFunctionPtr getNestedFunction() const { return {}; }
const DataTypes & getArgumentTypes() const { return argument_types; }
const Array & getParameters() const { return parameters; }
const DataTypePtr & getResultType() const override { return result_type; }
const DataTypes & getArgumentTypes() const override { return argument_types; }
const Array & getParameters() const override { return parameters; }
// Any aggregate function can be calculated over a window, but there are some
// window functions such as rank() that require a different interface, e.g.
@ -398,6 +402,7 @@ public:
#endif
protected:
DataTypePtr result_type;
DataTypes argument_types;
Array parameters;
};
@ -414,8 +419,8 @@ private:
}
public:
IAggregateFunctionHelper(const DataTypes & argument_types_, const Array & parameters_)
: IAggregateFunction(argument_types_, parameters_) {}
IAggregateFunctionHelper(const DataTypes & argument_types_, const Array & parameters_, const DataTypePtr & result_type_)
: IAggregateFunction(argument_types_, parameters_, result_type_) {}
AddFunc getAddressOfAddFunction() const override { return &addFree; }
@ -695,15 +700,15 @@ public:
// Derived class can `override` this to flag that DateTime64 is not supported.
static constexpr bool DateTime64Supported = true;
IAggregateFunctionDataHelper(const DataTypes & argument_types_, const Array & parameters_)
: IAggregateFunctionHelper<Derived>(argument_types_, parameters_)
IAggregateFunctionDataHelper(const DataTypes & argument_types_, const Array & parameters_, const DataTypePtr & result_type_)
: IAggregateFunctionHelper<Derived>(argument_types_, parameters_, result_type_)
{
/// To prevent derived classes changing the destroy() without updating hasTrivialDestructor() to match it
/// Enforce that either both of them are changed or none are
constexpr bool declares_destroy_and_hasTrivialDestructor =
constexpr bool declares_destroy_and_has_trivial_destructor =
std::is_same_v<decltype(&IAggregateFunctionDataHelper::destroy), decltype(&Derived::destroy)> ==
std::is_same_v<decltype(&IAggregateFunctionDataHelper::hasTrivialDestructor), decltype(&Derived::hasTrivialDestructor)>;
static_assert(declares_destroy_and_hasTrivialDestructor,
static_assert(declares_destroy_and_has_trivial_destructor,
"destroy() and hasTrivialDestructor() methods of an aggregate function must be either both overridden or not");
}

View File

@ -2,6 +2,7 @@
#include <Common/SipHash.h>
#include <Common/FieldVisitorToString.h>
#include "Core/ColumnsWithTypeAndName.h"
#include <IO/WriteBufferFromString.h>
#include <IO/Operators.h>
@ -25,25 +26,54 @@ FunctionNode::FunctionNode(String function_name_)
children[arguments_child_index] = std::make_shared<ListNode>();
}
void FunctionNode::resolveAsFunction(FunctionOverloadResolverPtr function_value, DataTypePtr result_type_value)
ColumnsWithTypeAndName FunctionNode::getArgumentTypes() const
{
aggregate_function = nullptr;
ColumnsWithTypeAndName argument_types;
for (const auto & arg : getArguments().getNodes())
{
ColumnWithTypeAndName argument;
argument.type = arg->getResultType();
argument_types.push_back(argument);
}
return argument_types;
}
FunctionBasePtr FunctionNode::getFunction() const
{
return std::dynamic_pointer_cast<IFunctionBase>(function);
}
AggregateFunctionPtr FunctionNode::getAggregateFunction() const
{
return std::dynamic_pointer_cast<IAggregateFunction>(function);
}
bool FunctionNode::isAggregateFunction() const
{
return typeid_cast<AggregateFunctionPtr>(function) != nullptr && !isWindowFunction();
}
bool FunctionNode::isOrdinaryFunction() const
{
return typeid_cast<FunctionBasePtr>(function) != nullptr;
}
void FunctionNode::resolveAsFunction(FunctionBasePtr function_value)
{
function_name = function_value->getName();
function = std::move(function_value);
result_type = std::move(result_type_value);
function_name = function->getName();
}
void FunctionNode::resolveAsAggregateFunction(AggregateFunctionPtr aggregate_function_value, DataTypePtr result_type_value)
void FunctionNode::resolveAsAggregateFunction(AggregateFunctionPtr aggregate_function_value)
{
function = nullptr;
aggregate_function = std::move(aggregate_function_value);
result_type = std::move(result_type_value);
function_name = aggregate_function->getName();
function_name = aggregate_function_value->getName();
function = std::move(aggregate_function_value);
}
void FunctionNode::resolveAsWindowFunction(AggregateFunctionPtr window_function_value, DataTypePtr result_type_value)
void FunctionNode::resolveAsWindowFunction(AggregateFunctionPtr window_function_value)
{
resolveAsAggregateFunction(window_function_value, result_type_value);
resolveAsAggregateFunction(window_function_value);
}
void FunctionNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const
@ -63,8 +93,8 @@ void FunctionNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state
buffer << ", function_type: " << function_type;
if (result_type)
buffer << ", result_type: " + result_type->getName();
if (function)
buffer << ", result_type: " + function->getResultType()->getName();
const auto & parameters = getParameters();
if (!parameters.getNodes().empty())
@ -95,12 +125,14 @@ bool FunctionNode::isEqualImpl(const IQueryTreeNode & rhs) const
isOrdinaryFunction() != rhs_typed.isOrdinaryFunction() ||
isWindowFunction() != rhs_typed.isWindowFunction())
return false;
auto lhs_result_type = getResultType();
auto rhs_result_type = rhs.getResultType();
if (result_type && rhs_typed.result_type && !result_type->equals(*rhs_typed.getResultType()))
if (lhs_result_type && rhs_result_type && !lhs_result_type->equals(*rhs_result_type))
return false;
else if (result_type && !rhs_typed.result_type)
else if (lhs_result_type && !rhs_result_type)
return false;
else if (!result_type && rhs_typed.result_type)
else if (!lhs_result_type && rhs_result_type)
return false;
return true;
@ -114,7 +146,7 @@ void FunctionNode::updateTreeHashImpl(HashState & hash_state) const
hash_state.update(isAggregateFunction());
hash_state.update(isWindowFunction());
if (result_type)
if (auto result_type = getResultType())
{
auto result_type_name = result_type->getName();
hash_state.update(result_type_name.size());
@ -130,8 +162,6 @@ QueryTreeNodePtr FunctionNode::cloneImpl() const
* because ordinary functions or aggregate functions must be stateless.
*/
result_function->function = function;
result_function->aggregate_function = aggregate_function;
result_function->result_type = result_type;
return result_function;
}

View File

@ -1,8 +1,12 @@
#pragma once
#include <memory>
#include <Core/IResolvedFunction.h>
#include <Analyzer/IQueryTreeNode.h>
#include <Analyzer/ListNode.h>
#include <Analyzer/ConstantValue.h>
#include <Common/typeid_cast.h>
#include "Core/ColumnsWithTypeAndName.h"
namespace DB
{
@ -10,8 +14,11 @@ namespace DB
class IFunctionOverloadResolver;
using FunctionOverloadResolverPtr = std::shared_ptr<IFunctionOverloadResolver>;
class IFunctionBase;
using FunctionBasePtr = std::shared_ptr<IFunctionBase>;
class IAggregateFunction;
using AggregateFunctionPtr = std::shared_ptr<const IAggregateFunction>;
using AggregateFunctionPtr = std::shared_ptr<IAggregateFunction>;
/** Function node represents function in query tree.
* Function syntax: function_name(parameter_1, ...)(argument_1, ...).
@ -96,6 +103,8 @@ public:
return children[arguments_child_index];
}
ColumnsWithTypeAndName getArgumentTypes() const;
/// Returns true if function node has window, false otherwise
bool hasWindow() const
{
@ -124,24 +133,18 @@ public:
/** Get non aggregate function.
* If function is not resolved nullptr returned.
*/
const FunctionOverloadResolverPtr & getFunction() const
{
return function;
}
FunctionBasePtr getFunction() const;
/** Get aggregate function.
* If function is not resolved nullptr returned.
* If function is resolved as non aggregate function nullptr returned.
*/
const AggregateFunctionPtr & getAggregateFunction() const
{
return aggregate_function;
}
AggregateFunctionPtr getAggregateFunction() const;
/// Is function node resolved
bool isResolved() const
{
return result_type != nullptr && (function != nullptr || aggregate_function != nullptr);
return function != nullptr;
}
/// Is function node window function
@ -151,16 +154,10 @@ public:
}
/// Is function node aggregate function
bool isAggregateFunction() const
{
return aggregate_function != nullptr && !isWindowFunction();
}
bool isAggregateFunction() const;
/// Is function node ordinary function
bool isOrdinaryFunction() const
{
return function != nullptr;
}
bool isOrdinaryFunction() const;
/** Resolve function node as non aggregate function.
* It is important that function name is updated with resolved function name.
@ -168,19 +165,19 @@ public:
* Assume we have `multiIf` function with single condition, it can be converted to `if` function.
* Function name must be updated accordingly.
*/
void resolveAsFunction(FunctionOverloadResolverPtr function_value, DataTypePtr result_type_value);
void resolveAsFunction(FunctionBasePtr function_value);
/** Resolve function node as aggregate function.
* It is important that function name is updated with resolved function name.
* Main motivation for this is query tree optimizations.
*/
void resolveAsAggregateFunction(AggregateFunctionPtr aggregate_function_value, DataTypePtr result_type_value);
void resolveAsAggregateFunction(AggregateFunctionPtr aggregate_function_value);
/** Resolve function node as window function.
* It is important that function name is updated with resolved function name.
* Main motivation for this is query tree optimizations.
*/
void resolveAsWindowFunction(AggregateFunctionPtr window_function_value, DataTypePtr result_type_value);
void resolveAsWindowFunction(AggregateFunctionPtr window_function_value);
QueryTreeNodeType getNodeType() const override
{
@ -189,7 +186,7 @@ public:
DataTypePtr getResultType() const override
{
return result_type;
return function->getResultType();
}
void dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const override;
@ -205,9 +202,7 @@ protected:
private:
String function_name;
FunctionOverloadResolverPtr function;
AggregateFunctionPtr aggregate_function;
DataTypePtr result_type;
IResolvedFunctionPtr function;
static constexpr size_t parameters_child_index = 0;
static constexpr size_t arguments_child_index = 1;

View File

@ -147,7 +147,6 @@ public:
private:
static inline void resolveAggregateFunctionNode(FunctionNode & function_node, const String & aggregate_function_name)
{
auto function_result_type = function_node.getResultType();
auto function_aggregate_function = function_node.getAggregateFunction();
AggregateFunctionProperties properties;
@ -156,7 +155,7 @@ private:
function_aggregate_function->getParameters(),
properties);
function_node.resolveAsAggregateFunction(std::move(aggregate_function), std::move(function_result_type));
function_node.resolveAsAggregateFunction(std::move(aggregate_function));
}
};

View File

@ -69,7 +69,7 @@ public:
auto result_type = function_node->getResultType();
AggregateFunctionProperties properties;
auto aggregate_function = AggregateFunctionFactory::instance().get("count", {}, {}, properties);
function_node->resolveAsAggregateFunction(std::move(aggregate_function), std::move(result_type));
function_node->resolveAsAggregateFunction(std::move(aggregate_function));
function_node->getArguments().getNodes().clear();
}
};

View File

@ -9,6 +9,7 @@
#include <Analyzer/InDepthQueryTreeVisitor.h>
#include <Analyzer/FunctionNode.h>
#include "Core/ColumnWithTypeAndName.h"
namespace DB
{
@ -138,7 +139,6 @@ public:
static inline void resolveAggregateOrWindowFunctionNode(FunctionNode & function_node, const String & aggregate_function_name)
{
auto function_result_type = function_node.getResultType();
auto function_aggregate_function = function_node.getAggregateFunction();
AggregateFunctionProperties properties;
@ -148,16 +148,15 @@ public:
properties);
if (function_node.isAggregateFunction())
function_node.resolveAsAggregateFunction(std::move(aggregate_function), std::move(function_result_type));
function_node.resolveAsAggregateFunction(std::move(aggregate_function));
else if (function_node.isWindowFunction())
function_node.resolveAsWindowFunction(std::move(aggregate_function), std::move(function_result_type));
function_node.resolveAsWindowFunction(std::move(aggregate_function));
}
inline void resolveOrdinaryFunctionNode(FunctionNode & function_node, const String & function_name) const
{
auto function_result_type = function_node.getResultType();
auto function = FunctionFactory::instance().get(function_name, context);
function_node.resolveAsFunction(function, std::move(function_result_type));
function_node.resolveAsFunction(function->build(function_node.getArgumentTypes()));
}
private:

View File

@ -78,11 +78,11 @@ public:
column.name += ".size0";
column.type = std::make_shared<DataTypeUInt64>();
resolveOrdinaryFunctionNode(*function_node, "equals");
function_arguments_nodes.clear();
function_arguments_nodes.push_back(std::make_shared<ColumnNode>(column, column_source));
function_arguments_nodes.push_back(std::make_shared<ConstantNode>(static_cast<UInt64>(0)));
resolveOrdinaryFunctionNode(*function_node, "equals");
}
else if (function_name == "notEmpty")
{
@ -90,11 +90,11 @@ public:
column.name += ".size0";
column.type = std::make_shared<DataTypeUInt64>();
resolveOrdinaryFunctionNode(*function_node, "notEquals");
function_arguments_nodes.clear();
function_arguments_nodes.push_back(std::make_shared<ColumnNode>(column, column_source));
function_arguments_nodes.push_back(std::make_shared<ConstantNode>(static_cast<UInt64>(0)));
resolveOrdinaryFunctionNode(*function_node, "notEquals");
}
}
else if (column_type.isNullable())
@ -112,9 +112,9 @@ public:
column.name += ".null";
column.type = std::make_shared<DataTypeUInt8>();
resolveOrdinaryFunctionNode(*function_node, "not");
function_arguments_nodes = {std::make_shared<ColumnNode>(column, column_source)};
resolveOrdinaryFunctionNode(*function_node, "not");
}
}
else if (column_type.isMap())
@ -182,9 +182,9 @@ public:
column.type = data_type_map.getKeyType();
auto has_function_argument = std::make_shared<ColumnNode>(column, column_source);
resolveOrdinaryFunctionNode(*function_node, "has");
function_arguments_nodes[0] = std::move(has_function_argument);
resolveOrdinaryFunctionNode(*function_node, "has");
}
}
}
@ -192,9 +192,8 @@ public:
private:
inline void resolveOrdinaryFunctionNode(FunctionNode & function_node, const String & function_name) const
{
auto function_result_type = function_node.getResultType();
auto function = FunctionFactory::instance().get(function_name, context);
function_node.resolveAsFunction(function, std::move(function_result_type));
function_node.resolveAsFunction(function->build(function_node.getArgumentTypes()));
}
ContextPtr & context;

View File

@ -59,14 +59,13 @@ private:
std::unordered_set<String> names_to_collect;
};
QueryTreeNodePtr createResolvedFunction(const ContextPtr & context, const String & name, const DataTypePtr & result_type, QueryTreeNodes arguments)
QueryTreeNodePtr createResolvedFunction(const ContextPtr & context, const String & name, QueryTreeNodes arguments)
{
auto function_node = std::make_shared<FunctionNode>(name);
auto function = FunctionFactory::instance().get(name, context);
function_node->resolveAsFunction(std::move(function), result_type);
function_node->getArguments().getNodes() = std::move(arguments);
function_node->resolveAsFunction(function->build(function_node->getArgumentTypes()));
return function_node;
}
@ -74,11 +73,6 @@ FunctionNodePtr createResolvedAggregateFunction(const String & name, const Query
{
auto function_node = std::make_shared<FunctionNode>(name);
AggregateFunctionProperties properties;
auto aggregate_function = AggregateFunctionFactory::instance().get(name, {argument->getResultType()}, parameters, properties);
function_node->resolveAsAggregateFunction(aggregate_function, aggregate_function->getReturnType());
function_node->getArguments().getNodes() = { argument };
if (!parameters.empty())
{
QueryTreeNodes parameter_nodes;
@ -86,18 +80,27 @@ FunctionNodePtr createResolvedAggregateFunction(const String & name, const Query
parameter_nodes.emplace_back(std::make_shared<ConstantNode>(param));
function_node->getParameters().getNodes() = std::move(parameter_nodes);
}
function_node->getArguments().getNodes() = { argument };
AggregateFunctionProperties properties;
auto aggregate_function = AggregateFunctionFactory::instance().get(
name,
{ argument->getResultType() },
parameters,
properties);
function_node->resolveAsAggregateFunction(aggregate_function);
return function_node;
}
QueryTreeNodePtr createTupleElementFunction(const ContextPtr & context, const DataTypePtr & result_type, QueryTreeNodePtr argument, UInt64 index)
QueryTreeNodePtr createTupleElementFunction(const ContextPtr & context, QueryTreeNodePtr argument, UInt64 index)
{
return createResolvedFunction(context, "tupleElement", result_type, {std::move(argument), std::make_shared<ConstantNode>(index)});
return createResolvedFunction(context, "tupleElement", {argument, std::make_shared<ConstantNode>(index)});
}
QueryTreeNodePtr createArrayElementFunction(const ContextPtr & context, const DataTypePtr & result_type, QueryTreeNodePtr argument, UInt64 index)
QueryTreeNodePtr createArrayElementFunction(const ContextPtr & context, QueryTreeNodePtr argument, UInt64 index)
{
return createResolvedFunction(context, "arrayElement", result_type, {std::move(argument), std::make_shared<ConstantNode>(index)});
return createResolvedFunction(context, "arrayElement", {argument, std::make_shared<ConstantNode>(index)});
}
void replaceWithSumCount(QueryTreeNodePtr & node, const FunctionNodePtr & sum_count_node, ContextPtr context)
@ -115,20 +118,20 @@ void replaceWithSumCount(QueryTreeNodePtr & node, const FunctionNodePtr & sum_co
if (function_name == "sum")
{
assert(node->getResultType()->equals(*sum_count_result_type->getElement(0)));
node = createTupleElementFunction(context, node->getResultType(), sum_count_node, 1);
node = createTupleElementFunction(context, sum_count_node, 1);
}
else if (function_name == "count")
{
assert(node->getResultType()->equals(*sum_count_result_type->getElement(1)));
node = createTupleElementFunction(context, node->getResultType(), sum_count_node, 2);
node = createTupleElementFunction(context, sum_count_node, 2);
}
else if (function_name == "avg")
{
auto sum_result = createTupleElementFunction(context, sum_count_result_type->getElement(0), sum_count_node, 1);
auto count_result = createTupleElementFunction(context, sum_count_result_type->getElement(1), sum_count_node, 2);
auto sum_result = createTupleElementFunction(context, sum_count_node, 1);
auto count_result = createTupleElementFunction(context, sum_count_node, 2);
/// To avoid integer division by zero
auto count_float_result = createResolvedFunction(context, "toFloat64", std::make_shared<DataTypeFloat64>(), {count_result});
node = createResolvedFunction(context, "divide", node->getResultType(), {sum_result, count_float_result});
auto count_float_result = createResolvedFunction(context, "toFloat64", {count_result});
node = createResolvedFunction(context, "divide", {sum_result, count_float_result});
}
else
{
@ -238,7 +241,7 @@ void tryFuseQuantiles(QueryTreeNodePtr query_tree_node, ContextPtr context)
for (size_t i = 0; i < nodes_set.size(); ++i)
{
size_t array_index = i + 1;
*nodes[i] = createArrayElementFunction(context, result_array_type->getNestedType(), quantiles_node, array_index);
*nodes[i] = createArrayElementFunction(context, quantiles_node, array_index);
}
}
}

View File

@ -55,8 +55,8 @@ public:
return;
auto multi_if_function = std::make_shared<FunctionNode>("multiIf");
multi_if_function->resolveAsFunction(multi_if_function_ptr, std::make_shared<DataTypeUInt8>());
multi_if_function->getArguments().getNodes() = std::move(multi_if_arguments);
multi_if_function->resolveAsFunction(multi_if_function_ptr->build(multi_if_function->getArgumentTypes()));
node = std::move(multi_if_function);
}

View File

@ -27,7 +27,7 @@ public:
return;
auto result_type = function_node->getResultType();
function_node->resolveAsFunction(if_function_ptr, std::move(result_type));
function_node->resolveAsFunction(if_function_ptr->build(function_node->getArgumentTypes()));
}
private:

View File

@ -48,12 +48,10 @@ public:
private:
static inline void resolveAsCountAggregateFunction(FunctionNode & function_node)
{
auto function_result_type = function_node.getResultType();
AggregateFunctionProperties properties;
auto aggregate_function = AggregateFunctionFactory::instance().get("count", {}, {}, properties);
function_node.resolveAsAggregateFunction(std::move(aggregate_function), std::move(function_result_type));
function_node.resolveAsAggregateFunction(std::move(aggregate_function));
}
};

View File

@ -4287,7 +4287,7 @@ ProjectionNames QueryAnalyzer::resolveFunction(QueryTreeNodePtr & node, Identifi
bool force_grouping_standard_compatibility = scope.context->getSettingsRef().force_grouping_standard_compatibility;
auto grouping_function = std::make_shared<FunctionGrouping>(force_grouping_standard_compatibility);
auto grouping_function_adaptor = std::make_shared<FunctionToOverloadResolverAdaptor>(std::move(grouping_function));
function_node.resolveAsFunction(std::move(grouping_function_adaptor), std::make_shared<DataTypeUInt64>());
function_node.resolveAsFunction(grouping_function_adaptor->build({}));
return result_projection_names;
}
}
@ -4307,7 +4307,7 @@ ProjectionNames QueryAnalyzer::resolveFunction(QueryTreeNodePtr & node, Identifi
AggregateFunctionProperties properties;
auto aggregate_function = AggregateFunctionFactory::instance().get(function_name, argument_types, parameters, properties);
function_node.resolveAsWindowFunction(aggregate_function, aggregate_function->getReturnType());
function_node.resolveAsWindowFunction(aggregate_function);
bool window_node_is_identifier = function_node.getWindowNode()->getNodeType() == QueryTreeNodeType::IDENTIFIER;
ProjectionName window_projection_name = resolveWindow(function_node.getWindowNode(), scope);
@ -4361,7 +4361,7 @@ ProjectionNames QueryAnalyzer::resolveFunction(QueryTreeNodePtr & node, Identifi
AggregateFunctionProperties properties;
auto aggregate_function = AggregateFunctionFactory::instance().get(function_name, argument_types, parameters, properties);
function_node.resolveAsAggregateFunction(aggregate_function, aggregate_function->getReturnType());
function_node.resolveAsAggregateFunction(aggregate_function);
return result_projection_names;
}
@ -4538,6 +4538,8 @@ ProjectionNames QueryAnalyzer::resolveFunction(QueryTreeNodePtr & node, Identifi
constant_value = std::make_shared<ConstantValue>(std::move(column_constant_value), result_type);
}
}
function_node.resolveAsFunction(std::move(function_base));
}
catch (Exception & e)
{
@ -4545,8 +4547,6 @@ ProjectionNames QueryAnalyzer::resolveFunction(QueryTreeNodePtr & node, Identifi
throw;
}
function_node.resolveAsFunction(std::move(function), std::move(result_type));
if (constant_value)
node = std::make_shared<ConstantNode>(std::move(constant_value), node);

View File

@ -13,6 +13,7 @@
#include <Analyzer/InDepthQueryTreeVisitor.h>
#include <Analyzer/ConstantNode.h>
#include <Analyzer/FunctionNode.h>
#include "Core/ColumnsWithTypeAndName.h"
namespace DB
{
@ -117,11 +118,12 @@ public:
not_function_result_type = makeNullable(not_function_result_type);
auto not_function = std::make_shared<FunctionNode>("not");
not_function->resolveAsFunction(FunctionFactory::instance().get("not", context), std::move(not_function_result_type));
auto & not_function_arguments = not_function->getArguments().getNodes();
not_function_arguments.push_back(std::move(nested_if_function_arguments_nodes[0]));
not_function->resolveAsFunction(FunctionFactory::instance().get("not", context)->build(not_function->getArgumentTypes()));
function_node_arguments_nodes[0] = std::move(not_function);
function_node_arguments_nodes.resize(1);
@ -139,8 +141,7 @@ private:
function_node.getAggregateFunction()->getParameters(),
properties);
auto function_result_type = function_node.getResultType();
function_node.resolveAsAggregateFunction(std::move(aggregate_function), std::move(function_result_type));
function_node.resolveAsAggregateFunction(std::move(aggregate_function));
}
ContextPtr & context;

View File

@ -76,7 +76,7 @@ public:
properties);
auto function_result_type = function_node->getResultType();
function_node->resolveAsAggregateFunction(std::move(aggregate_function), std::move(function_result_type));
function_node->resolveAsAggregateFunction(std::move(aggregate_function));
}
};

View File

@ -28,7 +28,7 @@ namespace ErrorCodes
}
static String getTypeString(const AggregateFunctionPtr & func, std::optional<size_t> version = std::nullopt)
static String getTypeString(const ConstAggregateFunctionPtr & func, std::optional<size_t> version = std::nullopt)
{
WriteBufferFromOwnString stream;
@ -62,18 +62,18 @@ static String getTypeString(const AggregateFunctionPtr & func, std::optional<siz
}
ColumnAggregateFunction::ColumnAggregateFunction(const AggregateFunctionPtr & func_, std::optional<size_t> version_)
ColumnAggregateFunction::ColumnAggregateFunction(const ConstAggregateFunctionPtr & func_, std::optional<size_t> version_)
: func(func_), type_string(getTypeString(func, version_)), version(version_)
{
}
ColumnAggregateFunction::ColumnAggregateFunction(const AggregateFunctionPtr & func_, const ConstArenas & arenas_)
ColumnAggregateFunction::ColumnAggregateFunction(const ConstAggregateFunctionPtr & func_, const ConstArenas & arenas_)
: foreign_arenas(arenas_), func(func_), type_string(getTypeString(func))
{
}
void ColumnAggregateFunction::set(const AggregateFunctionPtr & func_, size_t version_)
void ColumnAggregateFunction::set(const ConstAggregateFunctionPtr & func_, size_t version_)
{
func = func_;
version = version_;
@ -146,7 +146,7 @@ MutableColumnPtr ColumnAggregateFunction::convertToValues(MutableColumnPtr colum
/// insertResultInto may invalidate states, so we must unshare ownership of them
column_aggregate_func.ensureOwnership();
MutableColumnPtr res = func->getReturnType()->createColumn();
MutableColumnPtr res = func->getResultType()->createColumn();
res->reserve(data.size());
/// If there are references to states in final column, we must hold their ownership

View File

@ -70,7 +70,7 @@ private:
ArenaPtr my_arena;
/// Used for destroying states and for finalization of values.
AggregateFunctionPtr func;
ConstAggregateFunctionPtr func;
/// Source column. Used (holds source from destruction),
/// if this column has been constructed from another and uses all or part of its values.
@ -92,9 +92,9 @@ private:
/// Create a new column that has another column as a source.
MutablePtr createView() const;
explicit ColumnAggregateFunction(const AggregateFunctionPtr & func_, std::optional<size_t> version_ = std::nullopt);
explicit ColumnAggregateFunction(const ConstAggregateFunctionPtr & func_, std::optional<size_t> version_ = std::nullopt);
ColumnAggregateFunction(const AggregateFunctionPtr & func_, const ConstArenas & arenas_);
ColumnAggregateFunction(const ConstAggregateFunctionPtr & func_, const ConstArenas & arenas_);
ColumnAggregateFunction(const ColumnAggregateFunction & src_);
@ -103,10 +103,10 @@ private:
public:
~ColumnAggregateFunction() override;
void set(const AggregateFunctionPtr & func_, size_t version_);
void set(const ConstAggregateFunctionPtr & func_, size_t version_);
AggregateFunctionPtr getAggregateFunction() { return func; }
AggregateFunctionPtr getAggregateFunction() const { return func; }
ConstAggregateFunctionPtr getAggregateFunction() { return func; }
ConstAggregateFunctionPtr getAggregateFunction() const { return func; }
/// If we have another column as a source (owner of data), copy all data to ourself and reset source.
/// This is needed before inserting new elements, because we must own these elements (to destroy them in destructor),

View File

@ -0,0 +1,29 @@
#pragma once
#include <memory>
#include <vector>
namespace DB
{
class IDataType;
using DataTypePtr = std::shared_ptr<const IDataType>;
using DataTypes = std::vector<DataTypePtr>;
struct Array;
class IResolvedFunction
{
public:
virtual const DataTypePtr & getResultType() const = 0;
virtual const DataTypes & getArgumentTypes() const = 0;
virtual const Array & getParameters() const = 0;
virtual ~IResolvedFunction() = default;
};
using IResolvedFunctionPtr = std::shared_ptr<IResolvedFunction>;
}

View File

@ -19,7 +19,7 @@ namespace DB
class DataTypeAggregateFunction final : public IDataType
{
private:
AggregateFunctionPtr function;
ConstAggregateFunctionPtr function;
DataTypes argument_types;
Array parameters;
mutable std::optional<size_t> version;
@ -30,9 +30,9 @@ private:
public:
static constexpr bool is_parametric = true;
DataTypeAggregateFunction(const AggregateFunctionPtr & function_, const DataTypes & argument_types_,
DataTypeAggregateFunction(ConstAggregateFunctionPtr function_, const DataTypes & argument_types_,
const Array & parameters_, std::optional<size_t> version_ = std::nullopt)
: function(function_)
: function(std::move(function_))
, argument_types(argument_types_)
, parameters(parameters_)
, version(version_)
@ -40,7 +40,7 @@ public:
}
String getFunctionName() const { return function->getName(); }
AggregateFunctionPtr getFunction() const { return function; }
ConstAggregateFunctionPtr getFunction() const { return function; }
String doGetName() const override;
String getNameWithoutVersion() const;
@ -51,7 +51,7 @@ public:
bool canBeInsideNullable() const override { return false; }
DataTypePtr getReturnType() const { return function->getReturnType(); }
DataTypePtr getReturnType() const { return function->getResultType(); }
DataTypePtr getReturnTypeToPredict() const { return function->getReturnTypeToPredict(); }
DataTypes getArgumentsDataTypes() const { return argument_types; }

View File

@ -131,9 +131,9 @@ static std::pair<DataTypePtr, DataTypeCustomDescPtr> create(const ASTPtr & argum
DataTypePtr storage_type = DataTypeFactory::instance().get(argument_types[0]->getName());
if (!function->getReturnType()->equals(*removeLowCardinality(storage_type)))
if (!function->getResultType()->equals(*removeLowCardinality(storage_type)))
{
throw Exception("Incompatible data types between aggregate function '" + function->getName() + "' which returns " + function->getReturnType()->getName() + " and column storage type " + storage_type->getName(),
throw Exception("Incompatible data types between aggregate function '" + function->getName() + "' which returns " + function->getResultType()->getName() + " and column storage type " + storage_type->getName(),
ErrorCodes::BAD_ARGUMENTS);
}

View File

@ -108,14 +108,14 @@ void SerializationAggregateFunction::deserializeBinaryBulk(IColumn & column, Rea
}
}
static String serializeToString(const AggregateFunctionPtr & function, const IColumn & column, size_t row_num, size_t version)
static String serializeToString(const ConstAggregateFunctionPtr & function, const IColumn & column, size_t row_num, size_t version)
{
WriteBufferFromOwnString buffer;
function->serialize(assert_cast<const ColumnAggregateFunction &>(column).getData()[row_num], buffer, version);
return buffer.str();
}
static void deserializeFromString(const AggregateFunctionPtr & function, IColumn & column, const String & s, size_t version)
static void deserializeFromString(const ConstAggregateFunctionPtr & function, IColumn & column, const String & s, size_t version)
{
ColumnAggregateFunction & column_concrete = assert_cast<ColumnAggregateFunction &>(column);

View File

@ -11,14 +11,14 @@ namespace DB
class SerializationAggregateFunction final : public ISerialization
{
private:
AggregateFunctionPtr function;
ConstAggregateFunctionPtr function;
String type_name;
size_t version;
public:
static constexpr bool is_parametric = true;
SerializationAggregateFunction(const AggregateFunctionPtr & function_, String type_name_, size_t version_)
SerializationAggregateFunction(const ConstAggregateFunctionPtr & function_, String type_name_, size_t version_)
: function(function_), type_name(std::move(type_name_)), version(version_) {}
/// NOTE These two functions for serializing single values are incompatible with the functions below.

View File

@ -1736,7 +1736,7 @@ namespace
}
const std::shared_ptr<const DataTypeAggregateFunction> aggregate_function_data_type;
const AggregateFunctionPtr aggregate_function;
ConstAggregateFunctionPtr aggregate_function;
String text_buffer;
};

View File

@ -895,7 +895,7 @@ class FunctionBinaryArithmetic : public IFunction
const ColumnAggregateFunction & column = typeid_cast<const ColumnAggregateFunction &>(
agg_state_is_const ? assert_cast<const ColumnConst &>(agg_state_column).getDataColumn() : agg_state_column);
AggregateFunctionPtr function = column.getAggregateFunction();
ConstAggregateFunctionPtr function = column.getAggregateFunction();
size_t size = agg_state_is_const ? 1 : input_rows_count;
@ -960,7 +960,7 @@ class FunctionBinaryArithmetic : public IFunction
const ColumnAggregateFunction & rhs = typeid_cast<const ColumnAggregateFunction &>(
rhs_is_const ? assert_cast<const ColumnConst &>(rhs_column).getDataColumn() : rhs_column);
AggregateFunctionPtr function = lhs.getAggregateFunction();
ConstAggregateFunctionPtr function = lhs.getAggregateFunction();
size_t size = (lhs_is_const && rhs_is_const) ? 1 : input_rows_count;

View File

@ -3,6 +3,8 @@
#include <Core/ColumnNumbers.h>
#include <Core/ColumnsWithTypeAndName.h>
#include <Core/Names.h>
#include <Core/IResolvedFunction.h>
#include <Common/Exception.h>
#include <DataTypes/IDataType.h>
#include "config.h"
@ -122,11 +124,11 @@ using Values = std::vector<llvm::Value *>;
/** Function with known arguments and return type (when the specific overload was chosen).
* It is also the point where all function-specific properties are known.
*/
class IFunctionBase
class IFunctionBase : public IResolvedFunction
{
public:
virtual ~IFunctionBase() = default;
~IFunctionBase() override = default;
virtual ColumnPtr execute( /// NOLINT
const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count, bool dry_run = false) const
@ -137,8 +139,10 @@ public:
/// Get the main function name.
virtual String getName() const = 0;
virtual const DataTypes & getArgumentTypes() const = 0;
virtual const DataTypePtr & getResultType() const = 0;
const Array & getParameters() const final
{
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "IFunctionBase doesn't support getParameters method");
}
/// Do preparations and return executable.
/// sample_columns should contain data types of arguments and values of constants, if relevant.

View File

@ -51,6 +51,8 @@ public:
const DataTypes & getArgumentTypes() const override { return arguments; }
const DataTypePtr & getResultType() const override { return result_type; }
const FunctionPtr & getFunction() const { return function; }
#if USE_EMBEDDED_COMPILER
bool isCompilable() const override { return function->isCompilable(getArgumentTypes()); }

View File

@ -104,7 +104,7 @@ DataTypePtr FunctionArrayReduce::getReturnTypeImpl(const ColumnsWithTypeAndName
aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name, argument_types, params_row, properties);
}
return aggregate_function->getReturnType();
return aggregate_function->getResultType();
}

View File

@ -122,7 +122,7 @@ DataTypePtr FunctionArrayReduceInRanges::getReturnTypeImpl(const ColumnsWithType
aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name, argument_types, params_row, properties);
}
return std::make_shared<DataTypeArray>(aggregate_function->getReturnType());
return std::make_shared<DataTypeArray>(aggregate_function->getResultType());
}

View File

@ -87,7 +87,7 @@ DataTypePtr FunctionInitializeAggregation::getReturnTypeImpl(const ColumnsWithTy
aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name, argument_types, params_row, properties);
}
return aggregate_function->getReturnType();
return aggregate_function->getResultType();
}

View File

@ -91,7 +91,7 @@ public:
if (arguments.size() == 2)
column_with_groups = arguments[1].column;
AggregateFunctionPtr aggregate_function_ptr = column_with_states->getAggregateFunction();
ConstAggregateFunctionPtr aggregate_function_ptr = column_with_states->getAggregateFunction();
const IAggregateFunction & agg_func = *aggregate_function_ptr;
AlignedBuffer place(agg_func.sizeOfData(), agg_func.alignOfData());
@ -99,7 +99,7 @@ public:
/// Will pass empty arena if agg_func does not allocate memory in arena
std::unique_ptr<Arena> arena = agg_func.allocatesMemoryInArena() ? std::make_unique<Arena>() : nullptr;
auto result_column_ptr = agg_func.getReturnType()->createColumn();
auto result_column_ptr = agg_func.getResultType()->createColumn();
IColumn & result_column = *result_column_ptr;
result_column.reserve(column_with_states->size());

View File

@ -47,8 +47,6 @@ void ActionsDAG::Node::toTree(JSONBuilder::JSONMap & map) const
if (function_base)
map.add("Function", function_base->getName());
else if (function_builder)
map.add("Function", function_builder->getName());
if (type == ActionType::FUNCTION)
map.add("Compiled", is_function_compiled);
@ -166,7 +164,6 @@ const ActionsDAG::Node & ActionsDAG::addFunction(
Node node;
node.type = ActionType::FUNCTION;
node.function_builder = function;
node.children = std::move(children);
bool all_const = true;
@ -238,6 +235,86 @@ const ActionsDAG::Node & ActionsDAG::addFunction(
return addNode(std::move(node));
}
const ActionsDAG::Node & ActionsDAG::addFunction(
const FunctionBasePtr & function_base,
NodeRawConstPtrs children,
std::string result_name)
{
size_t num_arguments = children.size();
Node node;
node.type = ActionType::FUNCTION;
node.children = std::move(children);
bool all_const = true;
ColumnsWithTypeAndName arguments(num_arguments);
for (size_t i = 0; i < num_arguments; ++i)
{
const auto & child = *node.children[i];
ColumnWithTypeAndName argument;
argument.column = child.column;
argument.type = child.result_type;
argument.name = child.result_name;
if (!argument.column || !isColumnConst(*argument.column))
all_const = false;
arguments[i] = std::move(argument);
}
node.function_base = function_base;
node.result_type = node.function_base->getResultType();
node.function = node.function_base->prepare(arguments);
node.is_deterministic = node.function_base->isDeterministic();
/// If all arguments are constants, and function is suitable to be executed in 'prepare' stage - execute function.
if (node.function_base->isSuitableForConstantFolding())
{
ColumnPtr column;
if (all_const)
{
size_t num_rows = arguments.empty() ? 0 : arguments.front().column->size();
column = node.function->execute(arguments, node.result_type, num_rows, true);
}
else
{
column = node.function_base->getConstantResultForNonConstArguments(arguments, node.result_type);
}
/// If the result is not a constant, just in case, we will consider the result as unknown.
if (column && isColumnConst(*column))
{
/// All constant (literal) columns in block are added with size 1.
/// But if there was no columns in block before executing a function, the result has size 0.
/// Change the size to 1.
if (column->empty())
column = column->cloneResized(1);
node.column = std::move(column);
}
}
if (result_name.empty())
{
result_name = function_base->getName() + "(";
for (size_t i = 0; i < num_arguments; ++i)
{
if (i)
result_name += ", ";
result_name += node.children[i]->result_name;
}
result_name += ")";
}
node.result_name = std::move(result_name);
return addNode(std::move(node));
}
const ActionsDAG::Node & ActionsDAG::findInOutputs(const std::string & name) const
{
if (const auto * node = tryFindInOutputs(name))
@ -1927,8 +2004,7 @@ ActionsDAGPtr ActionsDAG::cloneActionsForFilterPushDown(
FunctionOverloadResolverPtr func_builder_cast = CastInternalOverloadResolver<CastType::nonAccurate>::createImpl();
predicate->function_builder = func_builder_cast;
predicate->function_base = predicate->function_builder->build(arguments);
predicate->function_base = func_builder_cast->build(arguments);
predicate->function = predicate->function_base->prepare(arguments);
}
}
@ -1939,7 +2015,9 @@ ActionsDAGPtr ActionsDAG::cloneActionsForFilterPushDown(
predicate->children.swap(new_children);
auto arguments = prepareFunctionArguments(predicate->children);
predicate->function_base = predicate->function_builder->build(arguments);
FunctionOverloadResolverPtr func_builder_and = std::make_unique<FunctionToOverloadResolverAdaptor>(std::make_shared<FunctionAnd>());
predicate->function_base = func_builder_and->build(arguments);
predicate->function = predicate->function_base->prepare(arguments);
}
}

View File

@ -74,7 +74,6 @@ public:
std::string result_name;
DataTypePtr result_type;
FunctionOverloadResolverPtr function_builder;
/// Can be used to get function signature or properties like monotonicity.
FunctionBasePtr function_base;
/// Prepared function which is used in function execution.
@ -139,6 +138,10 @@ public:
const FunctionOverloadResolverPtr & function,
NodeRawConstPtrs children,
std::string result_name);
const Node & addFunction(
const FunctionBasePtr & function_base,
NodeRawConstPtrs children,
std::string result_name);
/// Find first column by name in output nodes. This search is linear.
const Node & findInOutputs(const std::string & name) const;

View File

@ -53,7 +53,7 @@ void AggregateDescription::explain(WriteBuffer & out, size_t indent) const
out << type->getName();
}
out << ") → " << function->getReturnType()->getName() << "\n";
out << ") → " << function->getResultType()->getName() << "\n";
}
else
out << prefix << " Function: nullptr\n";
@ -109,7 +109,7 @@ void AggregateDescription::explain(JSONBuilder::JSONMap & map) const
args_array->add(type->getName());
function_map->add("Argument Types", std::move(args_array));
function_map->add("Result Type", function->getReturnType()->getName());
function_map->add("Result Type", function->getResultType()->getName());
map.add("Function", std::move(function_map));
}

View File

@ -45,7 +45,7 @@ OutputBlockColumns prepareOutputBlockColumns(
}
else
{
final_aggregate_columns[i] = aggregate_functions[i]->getReturnType()->createColumn();
final_aggregate_columns[i] = aggregate_functions[i]->getResultType()->createColumn();
final_aggregate_columns[i]->reserve(rows);
if (aggregate_functions[i]->isState())

View File

@ -433,7 +433,7 @@ Block Aggregator::Params::getHeader(
{
auto & elem = res.getByName(aggregate.column_name);
elem.type = aggregate.function->getReturnType();
elem.type = aggregate.function->getResultType();
elem.column = elem.type->createColumn();
}
}
@ -452,7 +452,7 @@ Block Aggregator::Params::getHeader(
DataTypePtr type;
if (final)
type = aggregate.function->getReturnType();
type = aggregate.function->getResultType();
else
type = std::make_shared<DataTypeAggregateFunction>(aggregate.function, argument_types, aggregate.parameters);

View File

@ -423,7 +423,7 @@ void ExpressionAnalyzer::analyzeAggregation(ActionsDAGPtr & temp_actions)
aggregated_columns = temp_actions->getNamesAndTypesList();
for (const auto & desc : aggregate_descriptions)
aggregated_columns.emplace_back(desc.column_name, desc.function->getReturnType());
aggregated_columns.emplace_back(desc.column_name, desc.function->getResultType());
}
@ -2021,7 +2021,7 @@ ExpressionAnalysisResult::ExpressionAnalysisResult(
for (const auto & f : w.window_functions)
{
query_analyzer.columns_after_window.push_back(
{f.column_name, f.aggregate_function->getReturnType()});
{f.column_name, f.aggregate_function->getResultType()});
}
}

View File

@ -403,7 +403,7 @@ static void compileInsertAggregatesIntoResultColumns(llvm::Module & module, cons
std::vector<ColumnDataPlaceholder> columns(functions.size());
for (size_t i = 0; i < functions.size(); ++i)
{
auto return_type = functions[i].function->getReturnType();
auto return_type = functions[i].function->getResultType();
auto * data = b.CreateLoad(column_type, b.CreateConstInBoundsGEP1_64(column_type, columns_arg, i));
auto * column_data_type = toNativeType(b, removeNullable(return_type));

View File

@ -365,8 +365,8 @@ void Planner::buildQueryPlanIfNeeded()
{
auto function_node = std::make_shared<FunctionNode>("and");
auto and_function = FunctionFactory::instance().get("and", query_context);
function_node->resolveAsFunction(std::move(and_function), std::make_shared<DataTypeUInt8>());
function_node->getArguments().getNodes() = {query_node.getPrewhere(), query_node.getWhere()};
function_node->resolveAsFunction(and_function->build(function_node->getArgumentTypes()));
query_node.getWhere() = std::move(function_node);
query_node.getPrewhere() = {};
}

View File

@ -121,7 +121,7 @@ public:
return node;
}
const ActionsDAG::Node * addFunctionIfNecessary(const std::string & node_name, ActionsDAG::NodeRawConstPtrs children, FunctionOverloadResolverPtr function)
const ActionsDAG::Node * addFunctionIfNecessary(const std::string & node_name, ActionsDAG::NodeRawConstPtrs children, FunctionBasePtr function)
{
auto it = node_name_to_node.find(node_name);
if (it != node_name_to_node.end())
@ -325,7 +325,7 @@ PlannerActionsVisitorImpl::NodeNameAndNodeMinLevel PlannerActionsVisitorImpl::vi
lambda_actions, captured_column_names, lambda_arguments_names_and_types, result_type, lambda_expression_node_name);
actions_stack.pop_back();
actions_stack[level].addFunctionIfNecessary(lambda_node_name, std::move(lambda_children), std::move(function_capture));
actions_stack[level].addFunctionIfNecessary(lambda_node_name, std::move(lambda_children), function_capture->build({}));
size_t actions_stack_size = actions_stack.size();
for (size_t i = level + 1; i < actions_stack_size; ++i)

View File

@ -101,14 +101,14 @@ public:
{
auto grouping_ordinary_function = std::make_shared<FunctionGroupingOrdinary>(arguments_indexes, force_grouping_standard_compatibility);
auto grouping_ordinary_function_adaptor = std::make_shared<FunctionToOverloadResolverAdaptor>(std::move(grouping_ordinary_function));
function_node->resolveAsFunction(std::move(grouping_ordinary_function_adaptor), std::make_shared<DataTypeUInt64>());
function_node->resolveAsFunction(grouping_ordinary_function_adaptor->build({}));
break;
}
case GroupByKind::ROLLUP:
{
auto grouping_rollup_function = std::make_shared<FunctionGroupingForRollup>(arguments_indexes, aggregation_keys_size, force_grouping_standard_compatibility);
auto grouping_rollup_function_adaptor = std::make_shared<FunctionToOverloadResolverAdaptor>(std::move(grouping_rollup_function));
function_node->resolveAsFunction(std::move(grouping_rollup_function_adaptor), std::make_shared<DataTypeUInt64>());
function_node->resolveAsFunction(grouping_rollup_function_adaptor->build({}));
function_node->getArguments().getNodes().push_back(std::move(grouping_set_argument_column));
break;
}
@ -116,7 +116,7 @@ public:
{
auto grouping_cube_function = std::make_shared<FunctionGroupingForCube>(arguments_indexes, aggregation_keys_size, force_grouping_standard_compatibility);
auto grouping_cube_function_adaptor = std::make_shared<FunctionToOverloadResolverAdaptor>(std::move(grouping_cube_function));
function_node->resolveAsFunction(std::move(grouping_cube_function_adaptor), std::make_shared<DataTypeUInt64>());
function_node->resolveAsFunction(grouping_cube_function_adaptor->build({}));
function_node->getArguments().getNodes().push_back(std::move(grouping_set_argument_column));
break;
}
@ -124,7 +124,7 @@ public:
{
auto grouping_grouping_sets_function = std::make_shared<FunctionGroupingForGroupingSets>(arguments_indexes, grouping_sets_keys_indices, force_grouping_standard_compatibility);
auto grouping_grouping_sets_function_adaptor = std::make_shared<FunctionToOverloadResolverAdaptor>(std::move(grouping_grouping_sets_function));
function_node->resolveAsFunction(std::move(grouping_grouping_sets_function_adaptor), std::make_shared<DataTypeUInt64>());
function_node->resolveAsFunction(grouping_grouping_sets_function_adaptor->build({}));
function_node->getArguments().getNodes().push_back(std::move(grouping_set_argument_column));
break;
}

View File

@ -65,7 +65,7 @@ std::optional<AggregationAnalysisResult> analyzeAggregation(QueryTreeNodePtr & q
ColumnsWithTypeAndName aggregates_columns;
aggregates_columns.reserve(aggregates_descriptions.size());
for (auto & aggregate_description : aggregates_descriptions)
aggregates_columns.emplace_back(nullptr, aggregate_description.function->getReturnType(), aggregate_description.column_name);
aggregates_columns.emplace_back(nullptr, aggregate_description.function->getResultType(), aggregate_description.column_name);
Names aggregation_keys;
@ -284,7 +284,7 @@ std::optional<WindowAnalysisResult> analyzeWindow(QueryTreeNodePtr & query_tree,
for (auto & window_description : window_descriptions)
for (auto & window_function : window_description.window_functions)
window_functions_additional_columns.emplace_back(nullptr, window_function.aggregate_function->getReturnType(), window_function.column_name);
window_functions_additional_columns.emplace_back(nullptr, window_function.aggregate_function->getResultType(), window_function.column_name);
auto before_window_step = std::make_unique<ActionsChainStep>(before_window_actions,
ActionsChainStep::AvailableOutputColumnsStrategy::ALL_NODES,

Some files were not shown because too many files have changed in this diff Show More