mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-22 15:42:02 +00:00
Refactor FunctionNode
This commit is contained in:
parent
b6eddbac0d
commit
2c70dbc76a
@ -49,14 +49,16 @@ private:
|
||||
|
||||
public:
|
||||
AggregateFunctionThrow(const DataTypes & argument_types_, const Array & parameters_, Float64 throw_probability_)
|
||||
: IAggregateFunctionDataHelper(argument_types_, parameters_), throw_probability(throw_probability_) {}
|
||||
: IAggregateFunctionDataHelper(argument_types_, parameters_, createResultType())
|
||||
, throw_probability(throw_probability_)
|
||||
{}
|
||||
|
||||
String getName() const override
|
||||
{
|
||||
return "aggThrow";
|
||||
}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
static DataTypePtr createResultType()
|
||||
{
|
||||
return std::make_shared<DataTypeUInt8>();
|
||||
}
|
||||
|
@ -37,10 +37,10 @@ class AggregateFunctionAnalysisOfVariance final : public IAggregateFunctionDataH
|
||||
{
|
||||
public:
|
||||
explicit AggregateFunctionAnalysisOfVariance(const DataTypes & arguments, const Array & params)
|
||||
: IAggregateFunctionDataHelper(arguments, params)
|
||||
: IAggregateFunctionDataHelper(arguments, params, createResultType())
|
||||
{}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
DataTypePtr createResultType() const
|
||||
{
|
||||
DataTypes types {std::make_shared<DataTypeNumber<Float64>>(), std::make_shared<DataTypeNumber<Float64>>() };
|
||||
Strings names {"f_statistic", "p_value"};
|
||||
|
@ -38,7 +38,6 @@ template <typename Data>
|
||||
class AggregateFunctionArgMinMax final : public IAggregateFunctionDataHelper<Data, AggregateFunctionArgMinMax<Data>>
|
||||
{
|
||||
private:
|
||||
const DataTypePtr & type_res;
|
||||
const DataTypePtr & type_val;
|
||||
const SerializationPtr serialization_res;
|
||||
const SerializationPtr serialization_val;
|
||||
@ -47,10 +46,9 @@ private:
|
||||
|
||||
public:
|
||||
AggregateFunctionArgMinMax(const DataTypePtr & type_res_, const DataTypePtr & type_val_)
|
||||
: Base({type_res_, type_val_}, {})
|
||||
, type_res(this->argument_types[0])
|
||||
: Base({type_res_, type_val_}, {}, type_res_)
|
||||
, type_val(this->argument_types[1])
|
||||
, serialization_res(type_res->getDefaultSerialization())
|
||||
, serialization_res(type_res_->getDefaultSerialization())
|
||||
, serialization_val(type_val->getDefaultSerialization())
|
||||
{
|
||||
if (!type_val->isComparable())
|
||||
@ -63,11 +61,6 @@ public:
|
||||
return StringRef(Data::ValueData_t::name()) == StringRef("min") ? "argMin" : "argMax";
|
||||
}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
return type_res;
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
|
||||
{
|
||||
if (this->data(place).value.changeIfBetter(*columns[1], row_num, arena))
|
||||
|
@ -30,7 +30,7 @@ private:
|
||||
|
||||
public:
|
||||
AggregateFunctionArray(AggregateFunctionPtr nested_, const DataTypes & arguments, const Array & params_)
|
||||
: IAggregateFunctionHelper<AggregateFunctionArray>(arguments, params_)
|
||||
: IAggregateFunctionHelper<AggregateFunctionArray>(arguments, params_, createResultType())
|
||||
, nested_func(nested_), num_arguments(arguments.size())
|
||||
{
|
||||
assert(parameters == nested_func->getParameters());
|
||||
@ -44,9 +44,9 @@ public:
|
||||
return nested_func->getName() + "Array";
|
||||
}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
DataTypePtr createResultType() const
|
||||
{
|
||||
return nested_func->getReturnType();
|
||||
return nested_func->getResultType();
|
||||
}
|
||||
|
||||
const IAggregateFunction & getBaseAggregateFunctionWithSameStateRepresentation() const override
|
||||
|
@ -11,6 +11,9 @@
|
||||
#include <AggregateFunctions/AggregateFunctionSum.h>
|
||||
#include <Core/DecimalFunctions.h>
|
||||
|
||||
#include "Core/IResolvedFunction.h"
|
||||
#include "DataTypes/IDataType.h"
|
||||
#include "DataTypes/Serializations/ISerialization.h"
|
||||
#include "config.h"
|
||||
|
||||
#if USE_EMBEDDED_COMPILER
|
||||
@ -83,10 +86,20 @@ public:
|
||||
using Fraction = AvgFraction<Numerator, Denominator>;
|
||||
|
||||
explicit AggregateFunctionAvgBase(const DataTypes & argument_types_,
|
||||
UInt32 num_scale_ = 0, UInt32 denom_scale_ = 0)
|
||||
: Base(argument_types_, {}), num_scale(num_scale_), denom_scale(denom_scale_) {}
|
||||
UInt32 num_scale_ = 0, UInt32 denom_scale_ = 0)
|
||||
: Base(argument_types_, {}, createResultType())
|
||||
, num_scale(num_scale_)
|
||||
, denom_scale(denom_scale_)
|
||||
{}
|
||||
|
||||
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeNumber<Float64>>(); }
|
||||
AggregateFunctionAvgBase(const DataTypes & argument_types_, const DataTypePtr & result_type_,
|
||||
UInt32 num_scale_ = 0, UInt32 denom_scale_ = 0)
|
||||
: Base(argument_types_, {}, result_type_)
|
||||
, num_scale(num_scale_)
|
||||
, denom_scale(denom_scale_)
|
||||
{}
|
||||
|
||||
DataTypePtr createResultType() const { return std::make_shared<DataTypeNumber<Float64>>(); }
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
@ -135,7 +148,7 @@ public:
|
||||
for (const auto & argument : this->argument_types)
|
||||
can_be_compiled &= canBeNativeType(*argument);
|
||||
|
||||
auto return_type = getReturnType();
|
||||
auto return_type = this->getResultType();
|
||||
can_be_compiled &= canBeNativeType(*return_type);
|
||||
|
||||
return can_be_compiled;
|
||||
|
@ -97,11 +97,12 @@ class AggregateFunctionBitwise final : public IAggregateFunctionDataHelper<Data,
|
||||
{
|
||||
public:
|
||||
explicit AggregateFunctionBitwise(const DataTypePtr & type)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionBitwise<T, Data>>({type}, {}) {}
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionBitwise<T, Data>>({type}, {}, createResultType())
|
||||
{}
|
||||
|
||||
String getName() const override { return Data::name(); }
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
static DataTypePtr createResultType()
|
||||
{
|
||||
return std::make_shared<DataTypeNumber<T>>();
|
||||
}
|
||||
@ -137,7 +138,7 @@ public:
|
||||
|
||||
bool isCompilable() const override
|
||||
{
|
||||
auto return_type = getReturnType();
|
||||
auto return_type = this->getResultType();
|
||||
return canBeNativeType(*return_type);
|
||||
}
|
||||
|
||||
@ -151,7 +152,7 @@ public:
|
||||
{
|
||||
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
||||
|
||||
auto * return_type = toNativeType(b, getReturnType());
|
||||
auto * return_type = toNativeType(b, this->getResultType());
|
||||
|
||||
auto * value_ptr = aggregate_data_ptr;
|
||||
auto * value = b.CreateLoad(return_type, value_ptr);
|
||||
@ -166,7 +167,7 @@ public:
|
||||
{
|
||||
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
||||
|
||||
auto * return_type = toNativeType(b, getReturnType());
|
||||
auto * return_type = toNativeType(b, this->getResultType());
|
||||
|
||||
auto * value_dst_ptr = aggregate_data_dst_ptr;
|
||||
auto * value_dst = b.CreateLoad(return_type, value_dst_ptr);
|
||||
@ -183,7 +184,7 @@ public:
|
||||
{
|
||||
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
||||
|
||||
auto * return_type = toNativeType(b, getReturnType());
|
||||
auto * return_type = toNativeType(b, this->getResultType());
|
||||
auto * value_ptr = aggregate_data_ptr;
|
||||
|
||||
return b.CreateLoad(return_type, value_ptr);
|
||||
|
@ -112,7 +112,7 @@ public:
|
||||
}
|
||||
|
||||
explicit AggregateFunctionBoundingRatio(const DataTypes & arguments)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionBoundingRatioData, AggregateFunctionBoundingRatio>(arguments, {})
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionBoundingRatioData, AggregateFunctionBoundingRatio>(arguments, {}, std::make_shared<DataTypeFloat64>())
|
||||
{
|
||||
const auto * x_arg = arguments.at(0).get();
|
||||
const auto * y_arg = arguments.at(1).get();
|
||||
@ -122,11 +122,6 @@ public:
|
||||
ErrorCodes::BAD_ARGUMENTS);
|
||||
}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
return std::make_shared<DataTypeFloat64>();
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, const size_t row_num, Arena *) const override
|
||||
|
@ -46,9 +46,9 @@ private:
|
||||
}
|
||||
|
||||
public:
|
||||
AggregateFunctionCategoricalIV(const DataTypes & arguments_, const Array & params_) :
|
||||
IAggregateFunctionHelper<AggregateFunctionCategoricalIV>{arguments_, params_},
|
||||
category_count{arguments_.size() - 1}
|
||||
AggregateFunctionCategoricalIV(const DataTypes & arguments_, const Array & params_)
|
||||
: IAggregateFunctionHelper<AggregateFunctionCategoricalIV>{arguments_, params_, createResultType()}
|
||||
, category_count{arguments_.size() - 1}
|
||||
{
|
||||
// notice: argument types has been checked before
|
||||
}
|
||||
@ -121,7 +121,7 @@ public:
|
||||
buf.readStrict(place, sizeOfData());
|
||||
}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
static DataTypePtr createResultType()
|
||||
{
|
||||
return std::make_shared<DataTypeArray>(
|
||||
std::make_shared<DataTypeNumber<Float64>>());
|
||||
|
@ -39,11 +39,13 @@ namespace ErrorCodes
|
||||
class AggregateFunctionCount final : public IAggregateFunctionDataHelper<AggregateFunctionCountData, AggregateFunctionCount>
|
||||
{
|
||||
public:
|
||||
explicit AggregateFunctionCount(const DataTypes & argument_types_) : IAggregateFunctionDataHelper(argument_types_, {}) {}
|
||||
explicit AggregateFunctionCount(const DataTypes & argument_types_)
|
||||
: IAggregateFunctionDataHelper(argument_types_, {}, createResultType())
|
||||
{}
|
||||
|
||||
String getName() const override { return "count"; }
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
static DataTypePtr createResultType()
|
||||
{
|
||||
return std::make_shared<DataTypeUInt64>();
|
||||
}
|
||||
@ -167,7 +169,7 @@ public:
|
||||
{
|
||||
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
||||
|
||||
auto * return_type = toNativeType(b, getReturnType());
|
||||
auto * return_type = toNativeType(b, this->getResultType());
|
||||
|
||||
auto * count_value_ptr = aggregate_data_ptr;
|
||||
auto * count_value = b.CreateLoad(return_type, count_value_ptr);
|
||||
@ -180,7 +182,7 @@ public:
|
||||
{
|
||||
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
||||
|
||||
auto * return_type = toNativeType(b, getReturnType());
|
||||
auto * return_type = toNativeType(b, this->getResultType());
|
||||
|
||||
auto * count_value_dst_ptr = aggregate_data_dst_ptr;
|
||||
auto * count_value_dst = b.CreateLoad(return_type, count_value_dst_ptr);
|
||||
@ -197,7 +199,7 @@ public:
|
||||
{
|
||||
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
||||
|
||||
auto * return_type = toNativeType(b, getReturnType());
|
||||
auto * return_type = toNativeType(b, this->getResultType());
|
||||
auto * count_value_ptr = aggregate_data_ptr;
|
||||
|
||||
return b.CreateLoad(return_type, count_value_ptr);
|
||||
@ -214,7 +216,7 @@ class AggregateFunctionCountNotNullUnary final
|
||||
{
|
||||
public:
|
||||
AggregateFunctionCountNotNullUnary(const DataTypePtr & argument, const Array & params)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionCountData, AggregateFunctionCountNotNullUnary>({argument}, params)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionCountData, AggregateFunctionCountNotNullUnary>({argument}, params, createResultType())
|
||||
{
|
||||
if (!argument->isNullable())
|
||||
throw Exception("Logical error: not Nullable data type passed to AggregateFunctionCountNotNullUnary", ErrorCodes::LOGICAL_ERROR);
|
||||
@ -222,7 +224,7 @@ public:
|
||||
|
||||
String getName() const override { return "count"; }
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
static DataTypePtr createResultType()
|
||||
{
|
||||
return std::make_shared<DataTypeUInt64>();
|
||||
}
|
||||
@ -311,7 +313,7 @@ public:
|
||||
{
|
||||
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
||||
|
||||
auto * return_type = toNativeType(b, getReturnType());
|
||||
auto * return_type = toNativeType(b, this->getResultType());
|
||||
|
||||
auto * is_null_value = b.CreateExtractValue(values[0], {1});
|
||||
auto * increment_value = b.CreateSelect(is_null_value, llvm::ConstantInt::get(return_type, 0), llvm::ConstantInt::get(return_type, 1));
|
||||
@ -327,7 +329,7 @@ public:
|
||||
{
|
||||
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
||||
|
||||
auto * return_type = toNativeType(b, getReturnType());
|
||||
auto * return_type = toNativeType(b, this->getResultType());
|
||||
|
||||
auto * count_value_dst_ptr = aggregate_data_dst_ptr;
|
||||
auto * count_value_dst = b.CreateLoad(return_type, count_value_dst_ptr);
|
||||
@ -344,7 +346,7 @@ public:
|
||||
{
|
||||
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
||||
|
||||
auto * return_type = toNativeType(b, getReturnType());
|
||||
auto * return_type = toNativeType(b, this->getResultType());
|
||||
auto * count_value_ptr = aggregate_data_ptr;
|
||||
|
||||
return b.CreateLoad(return_type, count_value_ptr);
|
||||
|
@ -31,7 +31,7 @@ class AggregationFunctionDeltaSum final
|
||||
{
|
||||
public:
|
||||
AggregationFunctionDeltaSum(const DataTypes & arguments, const Array & params)
|
||||
: IAggregateFunctionDataHelper<AggregationFunctionDeltaSumData<T>, AggregationFunctionDeltaSum<T>>{arguments, params}
|
||||
: IAggregateFunctionDataHelper<AggregationFunctionDeltaSumData<T>, AggregationFunctionDeltaSum<T>>{arguments, params, createResultType()}
|
||||
{}
|
||||
|
||||
AggregationFunctionDeltaSum()
|
||||
@ -40,7 +40,7 @@ public:
|
||||
|
||||
String getName() const override { return "deltaSum"; }
|
||||
|
||||
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeNumber<T>>(); }
|
||||
static DataTypePtr createResultType() { return std::make_shared<DataTypeNumber<T>>(); }
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
|
@ -38,7 +38,7 @@ public:
|
||||
: IAggregateFunctionDataHelper<
|
||||
AggregationFunctionDeltaSumTimestampData<ValueType, TimestampType>,
|
||||
AggregationFunctionDeltaSumTimestamp<ValueType, TimestampType>
|
||||
>{arguments, params}
|
||||
>{arguments, params, createResultType()}
|
||||
{}
|
||||
|
||||
AggregationFunctionDeltaSumTimestamp()
|
||||
@ -52,7 +52,7 @@ public:
|
||||
|
||||
String getName() const override { return "deltaSumTimestamp"; }
|
||||
|
||||
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeNumber<ValueType>>(); }
|
||||
static DataTypePtr createResultType() { return std::make_shared<DataTypeNumber<ValueType>>(); }
|
||||
|
||||
void NO_SANITIZE_UNDEFINED ALWAYS_INLINE add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
|
@ -168,7 +168,7 @@ private:
|
||||
|
||||
public:
|
||||
AggregateFunctionDistinct(AggregateFunctionPtr nested_func_, const DataTypes & arguments, const Array & params_)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionDistinct>(arguments, params_)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionDistinct>(arguments, params_, nested_func_->getResultType())
|
||||
, nested_func(nested_func_)
|
||||
, arguments_num(arguments.size())
|
||||
{
|
||||
@ -255,11 +255,6 @@ public:
|
||||
return nested_func->getName() + "Distinct";
|
||||
}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
return nested_func->getReturnType();
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override
|
||||
{
|
||||
return true;
|
||||
|
@ -92,14 +92,14 @@ private:
|
||||
|
||||
public:
|
||||
explicit AggregateFunctionEntropy(const DataTypes & argument_types_)
|
||||
: IAggregateFunctionDataHelper<EntropyData<Value>, AggregateFunctionEntropy<Value>>(argument_types_, {})
|
||||
: IAggregateFunctionDataHelper<EntropyData<Value>, AggregateFunctionEntropy<Value>>(argument_types_, {}, createResultType())
|
||||
, num_args(argument_types_.size())
|
||||
{
|
||||
}
|
||||
|
||||
String getName() const override { return "entropy"; }
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
static DataTypePtr createResultType()
|
||||
{
|
||||
return std::make_shared<DataTypeNumber<Float64>>();
|
||||
}
|
||||
|
@ -29,7 +29,7 @@ private:
|
||||
|
||||
public:
|
||||
AggregateFunctionExponentialMovingAverage(const DataTypes & argument_types_, const Array & params)
|
||||
: IAggregateFunctionDataHelper<ExponentiallySmoothedAverage, AggregateFunctionExponentialMovingAverage>(argument_types_, params)
|
||||
: IAggregateFunctionDataHelper<ExponentiallySmoothedAverage, AggregateFunctionExponentialMovingAverage>(argument_types_, params, createResultType())
|
||||
{
|
||||
if (params.size() != 1)
|
||||
throw Exception{"Aggregate function " + getName() + " requires exactly one parameter: half decay time.",
|
||||
@ -43,7 +43,7 @@ public:
|
||||
return "exponentialMovingAverage";
|
||||
}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
static DataTypePtr createResultType()
|
||||
{
|
||||
return std::make_shared<DataTypeNumber<Float64>>();
|
||||
}
|
||||
|
@ -107,7 +107,7 @@ private:
|
||||
|
||||
public:
|
||||
AggregateFunctionForEach(AggregateFunctionPtr nested_, const DataTypes & arguments, const Array & params_)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionForEachData, AggregateFunctionForEach>(arguments, params_)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionForEachData, AggregateFunctionForEach>(arguments, params_, nested_->getResultType())
|
||||
, nested_func(nested_), num_arguments(arguments.size())
|
||||
{
|
||||
nested_size_of_data = nested_func->sizeOfData();
|
||||
@ -125,11 +125,6 @@ public:
|
||||
return nested_func->getName() + "ForEach";
|
||||
}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
return std::make_shared<DataTypeArray>(nested_func->getReturnType());
|
||||
}
|
||||
|
||||
bool isVersioned() const override
|
||||
{
|
||||
return nested_func->isVersioned();
|
||||
|
@ -121,7 +121,7 @@ public:
|
||||
explicit GroupArrayNumericImpl(
|
||||
const DataTypePtr & data_type_, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max(), UInt64 seed_ = 123456)
|
||||
: IAggregateFunctionDataHelper<GroupArrayNumericData<T, Trait::sampler != Sampler::NONE>, GroupArrayNumericImpl<T, Trait>>(
|
||||
{data_type_}, parameters_)
|
||||
{data_type_}, parameters_, std::make_shared<DataTypeArray>(data_type_))
|
||||
, max_elems(max_elems_)
|
||||
, seed(seed_)
|
||||
{
|
||||
@ -129,8 +129,6 @@ public:
|
||||
|
||||
String getName() const override { return getNameByTrait<Trait>(); }
|
||||
|
||||
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeArray>(this->argument_types[0]); }
|
||||
|
||||
void insert(Data & a, const T & v, Arena * arena) const
|
||||
{
|
||||
++a.total_values;
|
||||
@ -423,7 +421,7 @@ class GroupArrayGeneralImpl final
|
||||
public:
|
||||
GroupArrayGeneralImpl(const DataTypePtr & data_type_, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max(), UInt64 seed_ = 123456)
|
||||
: IAggregateFunctionDataHelper<GroupArrayGeneralData<Node, Trait::sampler != Sampler::NONE>, GroupArrayGeneralImpl<Node, Trait>>(
|
||||
{data_type_}, parameters_)
|
||||
{data_type_}, parameters_, std::make_shared<DataTypeArray>(data_type_))
|
||||
, data_type(this->argument_types[0])
|
||||
, max_elems(max_elems_)
|
||||
, seed(seed_)
|
||||
@ -432,8 +430,6 @@ public:
|
||||
|
||||
String getName() const override { return getNameByTrait<Trait>(); }
|
||||
|
||||
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeArray>(data_type); }
|
||||
|
||||
void insert(Data & a, const Node * v, Arena * arena) const
|
||||
{
|
||||
++a.total_values;
|
||||
@ -697,7 +693,7 @@ class GroupArrayGeneralListImpl final
|
||||
|
||||
public:
|
||||
GroupArrayGeneralListImpl(const DataTypePtr & data_type_, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
|
||||
: IAggregateFunctionDataHelper<GroupArrayGeneralListData<Node>, GroupArrayGeneralListImpl<Node, Trait>>({data_type_}, parameters_)
|
||||
: IAggregateFunctionDataHelper<GroupArrayGeneralListData<Node>, GroupArrayGeneralListImpl<Node, Trait>>({data_type_}, parameters_, std::make_shared<DataTypeArray>(data_type_))
|
||||
, data_type(this->argument_types[0])
|
||||
, max_elems(max_elems_)
|
||||
{
|
||||
@ -705,8 +701,6 @@ public:
|
||||
|
||||
String getName() const override { return getNameByTrait<Trait>(); }
|
||||
|
||||
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeArray>(data_type); }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
|
||||
{
|
||||
if (limit_num_elems && data(place).elems >= max_elems)
|
||||
|
@ -64,7 +64,7 @@ private:
|
||||
|
||||
public:
|
||||
AggregateFunctionGroupArrayInsertAtGeneric(const DataTypes & arguments, const Array & params)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionGroupArrayInsertAtDataGeneric, AggregateFunctionGroupArrayInsertAtGeneric>(arguments, params)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionGroupArrayInsertAtDataGeneric, AggregateFunctionGroupArrayInsertAtGeneric>(arguments, params, std::make_shared<DataTypeArray>(arguments[0]))
|
||||
, type(argument_types[0])
|
||||
, serialization(type->getDefaultSerialization())
|
||||
{
|
||||
@ -101,11 +101,6 @@ public:
|
||||
|
||||
String getName() const override { return "groupArrayInsertAt"; }
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
return std::make_shared<DataTypeArray>(type);
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
|
@ -93,12 +93,15 @@ public:
|
||||
using ColumnResult = ColumnVectorOrDecimal<ResultT>;
|
||||
|
||||
explicit MovingImpl(const DataTypePtr & data_type_, UInt64 window_size_ = std::numeric_limits<UInt64>::max())
|
||||
: IAggregateFunctionDataHelper<Data, MovingImpl<T, LimitNumElements, Data>>({data_type_}, {})
|
||||
: IAggregateFunctionDataHelper<Data, MovingImpl<T, LimitNumElements, Data>>({data_type_}, {}, createResultType(data_type_))
|
||||
, window_size(window_size_) {}
|
||||
|
||||
String getName() const override { return Data::name; }
|
||||
|
||||
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeArray>(getReturnTypeElement()); }
|
||||
static DataTypePtr createResultType(const DataTypePtr & argument)
|
||||
{
|
||||
return std::make_shared<DataTypeArray>(getReturnTypeElement(argument));
|
||||
}
|
||||
|
||||
void NO_SANITIZE_UNDEFINED add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
|
||||
{
|
||||
@ -183,14 +186,14 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
auto getReturnTypeElement() const
|
||||
static auto getReturnTypeElement(const DataTypePtr & argument)
|
||||
{
|
||||
if constexpr (!is_decimal<ResultT>)
|
||||
return std::make_shared<DataTypeNumber<ResultT>>();
|
||||
else
|
||||
{
|
||||
using Res = DataTypeDecimal<ResultT>;
|
||||
return std::make_shared<Res>(Res::maxPrecision(), getDecimalScale(*this->argument_types.at(0)));
|
||||
return std::make_shared<Res>(Res::maxPrecision(), getDecimalScale(*argument));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -74,7 +74,7 @@ namespace
|
||||
/// groupBitmap needs to know about the data type that was used to create bitmaps.
|
||||
/// We need to look inside the type of its argument to obtain it.
|
||||
const DataTypeAggregateFunction & datatype_aggfunc = dynamic_cast<const DataTypeAggregateFunction &>(*argument_type_ptr);
|
||||
AggregateFunctionPtr aggfunc = datatype_aggfunc.getFunction();
|
||||
ConstAggregateFunctionPtr aggfunc = datatype_aggfunc.getFunction();
|
||||
|
||||
if (aggfunc->getName() != AggregateFunctionGroupBitmapData<UInt8>::name())
|
||||
throw Exception(
|
||||
|
@ -19,13 +19,13 @@ class AggregateFunctionBitmap final : public IAggregateFunctionDataHelper<Data,
|
||||
{
|
||||
public:
|
||||
explicit AggregateFunctionBitmap(const DataTypePtr & type)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionBitmap<T, Data>>({type}, {})
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionBitmap<T, Data>>({type}, {}, createResultType())
|
||||
{
|
||||
}
|
||||
|
||||
String getName() const override { return Data::name(); }
|
||||
|
||||
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeNumber<T>>(); }
|
||||
static DataTypePtr createResultType() { return std::make_shared<DataTypeNumber<T>>(); }
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
@ -59,13 +59,13 @@ private:
|
||||
static constexpr size_t STATE_VERSION_1_MIN_REVISION = 54455;
|
||||
public:
|
||||
explicit AggregateFunctionBitmapL2(const DataTypePtr & type)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionBitmapL2<T, Data, Policy>>({type}, {})
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionBitmapL2<T, Data, Policy>>({type}, {}, createResultType())
|
||||
{
|
||||
}
|
||||
|
||||
String getName() const override { return Policy::name; }
|
||||
|
||||
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeNumber<T>>(); }
|
||||
static DataTypePtr createResultType() { return std::make_shared<DataTypeNumber<T>>(); }
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
|
@ -26,8 +26,8 @@ class AggregateFunctionGroupUniqArrayDate : public AggregateFunctionGroupUniqArr
|
||||
{
|
||||
public:
|
||||
explicit AggregateFunctionGroupUniqArrayDate(const DataTypePtr & argument_type, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
|
||||
: AggregateFunctionGroupUniqArray<DataTypeDate::FieldType, HasLimit>(argument_type, parameters_, max_elems_) {}
|
||||
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeArray>(std::make_shared<DataTypeDate>()); }
|
||||
: AggregateFunctionGroupUniqArray<DataTypeDate::FieldType, HasLimit>(argument_type, parameters_, createResultType(), max_elems_) {}
|
||||
static DataTypePtr createResultType() { return std::make_shared<DataTypeArray>(std::make_shared<DataTypeDate>()); }
|
||||
};
|
||||
|
||||
template <typename HasLimit>
|
||||
@ -35,8 +35,8 @@ class AggregateFunctionGroupUniqArrayDateTime : public AggregateFunctionGroupUni
|
||||
{
|
||||
public:
|
||||
explicit AggregateFunctionGroupUniqArrayDateTime(const DataTypePtr & argument_type, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
|
||||
: AggregateFunctionGroupUniqArray<DataTypeDateTime::FieldType, HasLimit>(argument_type, parameters_, max_elems_) {}
|
||||
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeArray>(std::make_shared<DataTypeDateTime>()); }
|
||||
: AggregateFunctionGroupUniqArray<DataTypeDateTime::FieldType, HasLimit>(argument_type, parameters_, createResultType(), max_elems_) {}
|
||||
static DataTypePtr createResultType() { return std::make_shared<DataTypeArray>(std::make_shared<DataTypeDateTime>()); }
|
||||
};
|
||||
|
||||
template <typename HasLimit, typename ... TArgs>
|
||||
|
@ -50,15 +50,16 @@ private:
|
||||
public:
|
||||
AggregateFunctionGroupUniqArray(const DataTypePtr & argument_type, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionGroupUniqArrayData<T>,
|
||||
AggregateFunctionGroupUniqArray<T, LimitNumElems>>({argument_type}, parameters_),
|
||||
AggregateFunctionGroupUniqArray<T, LimitNumElems>>({argument_type}, parameters_, std::make_shared<DataTypeArray>(argument_type)),
|
||||
max_elems(max_elems_) {}
|
||||
|
||||
String getName() const override { return "groupUniqArray"; }
|
||||
AggregateFunctionGroupUniqArray(const DataTypePtr & argument_type, const Array & parameters_, const DataTypePtr & result_type_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionGroupUniqArrayData<T>,
|
||||
AggregateFunctionGroupUniqArray<T, LimitNumElems>>({argument_type}, parameters_, result_type_),
|
||||
max_elems(max_elems_) {}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
return std::make_shared<DataTypeArray>(this->argument_types[0]);
|
||||
}
|
||||
|
||||
String getName() const override { return "groupUniqArray"; }
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
@ -153,17 +154,12 @@ class AggregateFunctionGroupUniqArrayGeneric
|
||||
|
||||
public:
|
||||
AggregateFunctionGroupUniqArrayGeneric(const DataTypePtr & input_data_type_, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionGroupUniqArrayGenericData, AggregateFunctionGroupUniqArrayGeneric<is_plain_column, LimitNumElems>>({input_data_type_}, parameters_)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionGroupUniqArrayGenericData, AggregateFunctionGroupUniqArrayGeneric<is_plain_column, LimitNumElems>>({input_data_type_}, parameters_, std::make_shared<DataTypeArray>(input_data_type_))
|
||||
, input_data_type(this->argument_types[0])
|
||||
, max_elems(max_elems_) {}
|
||||
|
||||
String getName() const override { return "groupUniqArray"; }
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
return std::make_shared<DataTypeArray>(input_data_type);
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override
|
||||
{
|
||||
return true;
|
||||
|
@ -307,7 +307,7 @@ private:
|
||||
|
||||
public:
|
||||
AggregateFunctionHistogram(UInt32 max_bins_, const DataTypes & arguments, const Array & params)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionHistogramData, AggregateFunctionHistogram<T>>(arguments, params)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionHistogramData, AggregateFunctionHistogram<T>>(arguments, params, createResultType())
|
||||
, max_bins(max_bins_)
|
||||
{
|
||||
}
|
||||
@ -316,7 +316,7 @@ public:
|
||||
{
|
||||
return Data::structSize(max_bins);
|
||||
}
|
||||
DataTypePtr getReturnType() const override
|
||||
static DataTypePtr createResultType()
|
||||
{
|
||||
DataTypes types;
|
||||
auto mean = std::make_shared<DataTypeNumber<Data::Mean>>();
|
||||
|
@ -448,7 +448,7 @@ AggregateFunctionPtr AggregateFunctionIf::getOwnNullAdapter(
|
||||
|
||||
/// Nullability of the last argument (condition) does not affect the nullability of the result (NULL is processed as false).
|
||||
/// For other arguments it is as usual (at least one is NULL then the result is NULL if possible).
|
||||
bool return_type_is_nullable = !properties.returns_default_when_only_null && getReturnType()->canBeInsideNullable()
|
||||
bool return_type_is_nullable = !properties.returns_default_when_only_null && getResultType()->canBeInsideNullable()
|
||||
&& std::any_of(arguments.begin(), arguments.end() - 1, [](const auto & element) { return element->isNullable(); });
|
||||
|
||||
bool need_to_serialize_flag = return_type_is_nullable || properties.returns_default_when_only_null;
|
||||
|
@ -36,7 +36,7 @@ private:
|
||||
|
||||
public:
|
||||
AggregateFunctionIf(AggregateFunctionPtr nested, const DataTypes & types, const Array & params_)
|
||||
: IAggregateFunctionHelper<AggregateFunctionIf>(types, params_)
|
||||
: IAggregateFunctionHelper<AggregateFunctionIf>(types, params_, nested->getResultType())
|
||||
, nested_func(nested), num_arguments(types.size())
|
||||
{
|
||||
if (num_arguments == 0)
|
||||
@ -51,11 +51,6 @@ public:
|
||||
return nested_func->getName() + "If";
|
||||
}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
return nested_func->getReturnType();
|
||||
}
|
||||
|
||||
const IAggregateFunction & getBaseAggregateFunctionWithSameStateRepresentation() const override
|
||||
{
|
||||
return nested_func->getBaseAggregateFunctionWithSameStateRepresentation();
|
||||
|
@ -177,11 +177,11 @@ public:
|
||||
String getName() const override { return "intervalLengthSum"; }
|
||||
|
||||
explicit AggregateFunctionIntervalLengthSum(const DataTypes & arguments)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionIntervalLengthSum<T, Data>>(arguments, {})
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionIntervalLengthSum<T, Data>>(arguments, {}, createResultType())
|
||||
{
|
||||
}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
static DataTypePtr createResultType()
|
||||
{
|
||||
if constexpr (std::is_floating_point_v<T>)
|
||||
return std::make_shared<DataTypeFloat64>();
|
||||
|
@ -309,7 +309,7 @@ public:
|
||||
UInt64 batch_size_,
|
||||
const DataTypes & arguments_types,
|
||||
const Array & params)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionMLMethod<Data, Name>>(arguments_types, params)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionMLMethod<Data, Name>>(arguments_types, params, createResultType())
|
||||
, param_num(param_num_)
|
||||
, learning_rate(learning_rate_)
|
||||
, l2_reg_coef(l2_reg_coef_)
|
||||
@ -319,8 +319,7 @@ public:
|
||||
{
|
||||
}
|
||||
|
||||
/// This function is called when SELECT linearRegression(...) is called
|
||||
DataTypePtr getReturnType() const override
|
||||
static DataTypePtr createResultType()
|
||||
{
|
||||
return std::make_shared<DataTypeArray>(std::make_shared<DataTypeFloat64>());
|
||||
}
|
||||
|
@ -133,7 +133,7 @@ private:
|
||||
|
||||
public:
|
||||
explicit AggregateFunctionMannWhitney(const DataTypes & arguments, const Array & params)
|
||||
:IAggregateFunctionDataHelper<MannWhitneyData, AggregateFunctionMannWhitney> ({arguments}, {})
|
||||
:IAggregateFunctionDataHelper<MannWhitneyData, AggregateFunctionMannWhitney> ({arguments}, {}, createResultType())
|
||||
{
|
||||
if (params.size() > 2)
|
||||
throw Exception("Aggregate function " + getName() + " require two parameter or less", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||
@ -174,7 +174,7 @@ public:
|
||||
|
||||
bool allocatesMemoryInArena() const override { return true; }
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
static DataTypePtr createResultType()
|
||||
{
|
||||
DataTypes types
|
||||
{
|
||||
|
@ -18,6 +18,7 @@
|
||||
#include <Functions/FunctionHelpers.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include "DataTypes/Serializations/ISerialization.h"
|
||||
#include "base/types.h"
|
||||
#include <Common/Arena.h>
|
||||
#include "AggregateFunctions/AggregateFunctionFactory.h"
|
||||
@ -104,26 +105,32 @@ public:
|
||||
return nested_func->getDefaultVersion();
|
||||
}
|
||||
|
||||
AggregateFunctionMap(AggregateFunctionPtr nested, const DataTypes & types) : Base(types, nested->getParameters()), nested_func(nested)
|
||||
AggregateFunctionMap(AggregateFunctionPtr nested, const DataTypes & types)
|
||||
: Base(types, nested->getParameters(), std::make_shared<DataTypeMap>(DataTypes{getKeyType(types, nested), nested->getResultType()}))
|
||||
, nested_func(nested)
|
||||
{
|
||||
if (types.empty())
|
||||
throw Exception(
|
||||
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function " + getName() + " requires at least one argument");
|
||||
|
||||
if (types.size() > 1)
|
||||
throw Exception(
|
||||
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function " + getName() + " requires only one map argument");
|
||||
|
||||
const auto * map_type = checkAndGetDataType<DataTypeMap>(types[0].get());
|
||||
if (!map_type)
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Aggregate function " + getName() + " requires map as argument");
|
||||
|
||||
key_type = map_type->getKeyType();
|
||||
key_type = getKeyType(types, nested_func);
|
||||
}
|
||||
|
||||
String getName() const override { return nested_func->getName() + "Map"; }
|
||||
|
||||
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeMap>(DataTypes{key_type, nested_func->getReturnType()}); }
|
||||
static DataTypePtr getKeyType(const DataTypes & types, const AggregateFunctionPtr & nested)
|
||||
{
|
||||
if (types.empty())
|
||||
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
|
||||
"Aggregate function {}Map requires at least one argument", nested->getName());
|
||||
|
||||
if (types.size() > 1)
|
||||
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
|
||||
"Aggregate function {}Map requires only one map argument", nested->getName());
|
||||
|
||||
const auto * map_type = checkAndGetDataType<DataTypeMap>(types[0].get());
|
||||
if (!map_type)
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
|
||||
"Aggregate function {}Map requires map as argument", nested->getName());
|
||||
|
||||
return map_type->getKeyType();
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
|
||||
{
|
||||
|
@ -62,7 +62,8 @@ private:
|
||||
|
||||
public:
|
||||
AggregateFunctionIntersectionsMax(AggregateFunctionIntersectionsKind kind_, const DataTypes & arguments)
|
||||
: IAggregateFunctionDataHelper<MaxIntersectionsData<PointType>, AggregateFunctionIntersectionsMax<PointType>>(arguments, {}), kind(kind_)
|
||||
: IAggregateFunctionDataHelper<MaxIntersectionsData<PointType>, AggregateFunctionIntersectionsMax<PointType>>(arguments, {}, createResultType(kind_))
|
||||
, kind(kind_)
|
||||
{
|
||||
if (!isNativeNumber(arguments[0]))
|
||||
throw Exception{getName() + ": first argument must be represented by integer", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
|
||||
@ -81,9 +82,9 @@ public:
|
||||
: "maxIntersectionsPosition";
|
||||
}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
static DataTypePtr createResultType(AggregateFunctionIntersectionsKind kind_)
|
||||
{
|
||||
if (kind == AggregateFunctionIntersectionsKind::Count)
|
||||
if (kind_ == AggregateFunctionIntersectionsKind::Count)
|
||||
return std::make_shared<DataTypeUInt64>();
|
||||
else
|
||||
return std::make_shared<DataTypeNumber<PointType>>();
|
||||
|
@ -36,7 +36,7 @@ private:
|
||||
|
||||
public:
|
||||
AggregateFunctionMeanZTest(const DataTypes & arguments, const Array & params)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionMeanZTest<Data>>({arguments}, params)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionMeanZTest<Data>>({arguments}, params, createResultType())
|
||||
{
|
||||
pop_var_x = params.at(0).safeGet<Float64>();
|
||||
pop_var_y = params.at(1).safeGet<Float64>();
|
||||
@ -63,7 +63,7 @@ public:
|
||||
return Data::name;
|
||||
}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
static DataTypePtr createResultType()
|
||||
{
|
||||
DataTypes types
|
||||
{
|
||||
|
@ -30,7 +30,7 @@ private:
|
||||
|
||||
public:
|
||||
AggregateFunctionMerge(const AggregateFunctionPtr & nested_, const DataTypePtr & argument, const Array & params_)
|
||||
: IAggregateFunctionHelper<AggregateFunctionMerge>({argument}, params_)
|
||||
: IAggregateFunctionHelper<AggregateFunctionMerge>({argument}, params_, createResultType())
|
||||
, nested_func(nested_)
|
||||
{
|
||||
const DataTypeAggregateFunction * data_type = typeid_cast<const DataTypeAggregateFunction *>(argument.get());
|
||||
@ -45,9 +45,9 @@ public:
|
||||
return nested_func->getName() + "Merge";
|
||||
}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
DataTypePtr createResultType() const
|
||||
{
|
||||
return nested_func->getReturnType();
|
||||
return nested_func->getResultType();
|
||||
}
|
||||
|
||||
const IAggregateFunction & getBaseAggregateFunctionWithSameStateRepresentation() const override
|
||||
|
@ -1219,7 +1219,7 @@ private:
|
||||
|
||||
public:
|
||||
explicit AggregateFunctionsSingleValue(const DataTypePtr & type)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionsSingleValue<Data>>({type}, {})
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionsSingleValue<Data>>({type}, {}, createResultType())
|
||||
, serialization(type->getDefaultSerialization())
|
||||
{
|
||||
if (StringRef(Data::name()) == StringRef("min")
|
||||
@ -1233,7 +1233,7 @@ public:
|
||||
|
||||
String getName() const override { return Data::name(); }
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
DataTypePtr createResultType() const
|
||||
{
|
||||
auto result_type = this->argument_types.at(0);
|
||||
if constexpr (Data::is_nullable)
|
||||
|
@ -6,6 +6,7 @@
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include "DataTypes/IDataType.h"
|
||||
|
||||
|
||||
namespace DB
|
||||
@ -19,16 +20,16 @@ class AggregateFunctionNothing final : public IAggregateFunctionHelper<Aggregate
|
||||
{
|
||||
public:
|
||||
AggregateFunctionNothing(const DataTypes & arguments, const Array & params)
|
||||
: IAggregateFunctionHelper<AggregateFunctionNothing>(arguments, params) {}
|
||||
: IAggregateFunctionHelper<AggregateFunctionNothing>(arguments, params, createResultType(arguments)) {}
|
||||
|
||||
String getName() const override
|
||||
{
|
||||
return "nothing";
|
||||
}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
static DataTypePtr createResultType(const DataTypes & arguments)
|
||||
{
|
||||
return argument_types.empty() ? std::make_shared<DataTypeNullable>(std::make_shared<DataTypeNothing>()) : argument_types.front();
|
||||
return arguments.empty() ? std::make_shared<DataTypeNullable>(std::make_shared<DataTypeNothing>()) : arguments.front();
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
@ -87,7 +87,7 @@ public:
|
||||
transformed_nested_function->getParameters());
|
||||
}
|
||||
|
||||
bool return_type_is_nullable = !properties.returns_default_when_only_null && nested_function->getReturnType()->canBeInsideNullable();
|
||||
bool return_type_is_nullable = !properties.returns_default_when_only_null && nested_function->getResultType()->canBeInsideNullable();
|
||||
bool serialize_flag = return_type_is_nullable || properties.returns_default_when_only_null;
|
||||
|
||||
if (arguments.size() == 1)
|
||||
|
@ -82,7 +82,8 @@ protected:
|
||||
|
||||
public:
|
||||
AggregateFunctionNullBase(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)
|
||||
: IAggregateFunctionHelper<Derived>(arguments, params), nested_function{nested_function_}
|
||||
: IAggregateFunctionHelper<Derived>(arguments, params, createResultType(nested_function_))
|
||||
, nested_function{nested_function_}
|
||||
{
|
||||
if constexpr (result_is_nullable)
|
||||
prefix_size = nested_function->alignOfData();
|
||||
@ -96,11 +97,11 @@ public:
|
||||
return nested_function->getName();
|
||||
}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
static DataTypePtr createResultType(const AggregateFunctionPtr & nested_function_)
|
||||
{
|
||||
return result_is_nullable
|
||||
? makeNullable(nested_function->getReturnType())
|
||||
: nested_function->getReturnType();
|
||||
? makeNullable(nested_function_->getResultType())
|
||||
: nested_function_->getResultType();
|
||||
}
|
||||
|
||||
void create(AggregateDataPtr __restrict place) const override
|
||||
@ -270,7 +271,7 @@ public:
|
||||
{
|
||||
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
||||
|
||||
auto * return_type = toNativeType(b, this->getReturnType());
|
||||
auto * return_type = toNativeType(b, this->getResultType());
|
||||
|
||||
llvm::Value * result = nullptr;
|
||||
|
||||
|
@ -4,6 +4,7 @@
|
||||
#include <Columns/ColumnNullable.h>
|
||||
#include <Columns/ColumnsCommon.h>
|
||||
#include <Common/typeid_cast.h>
|
||||
#include "DataTypes/Serializations/ISerialization.h"
|
||||
#include <DataTypes/DataTypeNullable.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
@ -30,16 +31,14 @@ private:
|
||||
AggregateFunctionPtr nested_function;
|
||||
|
||||
size_t size_of_data;
|
||||
DataTypePtr inner_type;
|
||||
bool inner_nullable;
|
||||
|
||||
public:
|
||||
AggregateFunctionOrFill(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)
|
||||
: IAggregateFunctionHelper<AggregateFunctionOrFill>{arguments, params}
|
||||
: IAggregateFunctionHelper<AggregateFunctionOrFill>{arguments, params, createResultType(nested_function_->getResultType())}
|
||||
, nested_function{nested_function_}
|
||||
, size_of_data {nested_function->sizeOfData()}
|
||||
, inner_type {nested_function->getReturnType()}
|
||||
, inner_nullable {inner_type->isNullable()}
|
||||
, inner_nullable {nested_function->getResultType()->isNullable()}
|
||||
{
|
||||
// nothing
|
||||
}
|
||||
@ -246,22 +245,22 @@ public:
|
||||
readChar(place[size_of_data], buf);
|
||||
}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
static DataTypePtr createResultType(const DataTypePtr & inner_type_)
|
||||
{
|
||||
if constexpr (UseNull)
|
||||
{
|
||||
// -OrNull
|
||||
|
||||
if (inner_nullable)
|
||||
return inner_type;
|
||||
if (inner_type_->isNullable())
|
||||
return inner_type_;
|
||||
|
||||
return std::make_shared<DataTypeNullable>(inner_type);
|
||||
return std::make_shared<DataTypeNullable>(inner_type_);
|
||||
}
|
||||
else
|
||||
{
|
||||
// -OrDefault
|
||||
|
||||
return inner_type;
|
||||
return inner_type_;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -72,7 +72,7 @@ private:
|
||||
public:
|
||||
AggregateFunctionQuantile(const DataTypes & argument_types_, const Array & params)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionQuantile<Value, Data, Name, has_second_arg, FloatReturnType, returns_many>>(
|
||||
argument_types_, params)
|
||||
argument_types_, params, createResultType(argument_types_))
|
||||
, levels(params, returns_many)
|
||||
, level(levels.levels[0])
|
||||
, argument_type(this->argument_types[0])
|
||||
@ -83,14 +83,14 @@ public:
|
||||
|
||||
String getName() const override { return Name::name; }
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
static DataTypePtr createResultType(const DataTypes & argument_types_)
|
||||
{
|
||||
DataTypePtr res;
|
||||
|
||||
if constexpr (returns_float)
|
||||
res = std::make_shared<DataTypeNumber<FloatReturnType>>();
|
||||
else
|
||||
res = argument_type;
|
||||
res = argument_types_[0];
|
||||
|
||||
if constexpr (returns_many)
|
||||
return std::make_shared<DataTypeArray>(res);
|
||||
|
@ -51,7 +51,7 @@ class AggregateFunctionRankCorrelation :
|
||||
{
|
||||
public:
|
||||
explicit AggregateFunctionRankCorrelation(const DataTypes & arguments)
|
||||
:IAggregateFunctionDataHelper<RankCorrelationData, AggregateFunctionRankCorrelation> ({arguments}, {})
|
||||
:IAggregateFunctionDataHelper<RankCorrelationData, AggregateFunctionRankCorrelation> ({arguments}, {}, std::make_shared<DataTypeNumber<Float64>>())
|
||||
{}
|
||||
|
||||
String getName() const override
|
||||
@ -61,11 +61,6 @@ public:
|
||||
|
||||
bool allocatesMemoryInArena() const override { return true; }
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
return std::make_shared<DataTypeNumber<Float64>>();
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
|
||||
{
|
||||
Float64 new_x = columns[0]->getFloat64(row_num);
|
||||
|
@ -43,7 +43,7 @@ public:
|
||||
size_t step_,
|
||||
const DataTypes & arguments,
|
||||
const Array & params)
|
||||
: IAggregateFunctionHelper<AggregateFunctionResample<Key>>{arguments, params}
|
||||
: IAggregateFunctionHelper<AggregateFunctionResample<Key>>{arguments, params, createResultType(nested_function_)}
|
||||
, nested_function{nested_function_}
|
||||
, last_col{arguments.size() - 1}
|
||||
, begin{begin_}
|
||||
@ -190,9 +190,9 @@ public:
|
||||
nested_function->deserialize(place + i * size_of_data, buf, version, arena);
|
||||
}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
static DataTypePtr createResultType(const AggregateFunctionPtr & nested_function_)
|
||||
{
|
||||
return std::make_shared<DataTypeArray>(nested_function->getReturnType());
|
||||
return std::make_shared<DataTypeArray>(nested_function_->getResultType());
|
||||
}
|
||||
|
||||
template <bool merge>
|
||||
|
@ -76,7 +76,7 @@ public:
|
||||
}
|
||||
|
||||
explicit AggregateFunctionRetention(const DataTypes & arguments)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionRetentionData, AggregateFunctionRetention>(arguments, {})
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionRetentionData, AggregateFunctionRetention>(arguments, {}, std::make_shared<DataTypeArray>(std::make_shared<DataTypeUInt8>()))
|
||||
{
|
||||
for (const auto i : collections::range(0, arguments.size()))
|
||||
{
|
||||
@ -90,12 +90,6 @@ public:
|
||||
events_size = static_cast<UInt8>(arguments.size());
|
||||
}
|
||||
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
return std::make_shared<DataTypeArray>(std::make_shared<DataTypeUInt8>());
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, const size_t row_num, Arena *) const override
|
||||
|
@ -8,6 +8,7 @@
|
||||
#include <base/range.h>
|
||||
#include <base/sort.h>
|
||||
#include <Common/PODArray.h>
|
||||
#include "DataTypes/Serializations/ISerialization.h"
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <bitset>
|
||||
@ -126,8 +127,8 @@ template <typename T, typename Data, typename Derived>
|
||||
class AggregateFunctionSequenceBase : public IAggregateFunctionDataHelper<Data, Derived>
|
||||
{
|
||||
public:
|
||||
AggregateFunctionSequenceBase(const DataTypes & arguments, const Array & params, const String & pattern_)
|
||||
: IAggregateFunctionDataHelper<Data, Derived>(arguments, params)
|
||||
AggregateFunctionSequenceBase(const DataTypes & arguments, const Array & params, const String & pattern_, const DataTypePtr & result_type_)
|
||||
: IAggregateFunctionDataHelper<Data, Derived>(arguments, params, result_type_)
|
||||
, pattern(pattern_)
|
||||
{
|
||||
arg_count = arguments.size();
|
||||
@ -617,14 +618,12 @@ class AggregateFunctionSequenceMatch final : public AggregateFunctionSequenceBas
|
||||
{
|
||||
public:
|
||||
AggregateFunctionSequenceMatch(const DataTypes & arguments, const Array & params, const String & pattern_)
|
||||
: AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceMatch<T, Data>>(arguments, params, pattern_) {}
|
||||
: AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceMatch<T, Data>>(arguments, params, pattern_, std::make_shared<DataTypeUInt8>()) {}
|
||||
|
||||
using AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceMatch<T, Data>>::AggregateFunctionSequenceBase;
|
||||
|
||||
String getName() const override { return "sequenceMatch"; }
|
||||
|
||||
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeUInt8>(); }
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
@ -655,14 +654,12 @@ class AggregateFunctionSequenceCount final : public AggregateFunctionSequenceBas
|
||||
{
|
||||
public:
|
||||
AggregateFunctionSequenceCount(const DataTypes & arguments, const Array & params, const String & pattern_)
|
||||
: AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceCount<T, Data>>(arguments, params, pattern_) {}
|
||||
: AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceCount<T, Data>>(arguments, params, pattern_, std::make_shared<DataTypeUInt64>()) {}
|
||||
|
||||
using AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceCount<T, Data>>::AggregateFunctionSequenceBase;
|
||||
|
||||
String getName() const override { return "sequenceCount"; }
|
||||
|
||||
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeUInt64>(); }
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
|
@ -190,7 +190,7 @@ public:
|
||||
SequenceDirection seq_direction_,
|
||||
size_t min_required_args_,
|
||||
UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
|
||||
: IAggregateFunctionDataHelper<SequenceNextNodeGeneralData<Node>, Self>({data_type_}, parameters_)
|
||||
: IAggregateFunctionDataHelper<SequenceNextNodeGeneralData<Node>, Self>({data_type_}, parameters_, data_type_)
|
||||
, seq_base_kind(seq_base_kind_)
|
||||
, seq_direction(seq_direction_)
|
||||
, min_required_args(min_required_args_)
|
||||
@ -202,8 +202,6 @@ public:
|
||||
|
||||
String getName() const override { return "sequenceNextNode"; }
|
||||
|
||||
DataTypePtr getReturnType() const override { return data_type; }
|
||||
|
||||
bool haveSameStateRepresentationImpl(const IAggregateFunction & rhs) const override
|
||||
{
|
||||
return this->getName() == rhs.getName() && this->haveEqualArgumentTypes(rhs);
|
||||
|
@ -99,7 +99,7 @@ public:
|
||||
IAggregateFunctionDataHelper<
|
||||
AggregateFunctionSimpleLinearRegressionData<Ret>,
|
||||
AggregateFunctionSimpleLinearRegression<X, Y, Ret>
|
||||
> {arguments, params}
|
||||
> {arguments, params, createResultType()}
|
||||
{
|
||||
// notice: arguments has been checked before
|
||||
}
|
||||
@ -140,7 +140,7 @@ public:
|
||||
this->data(place).deserialize(buf);
|
||||
}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
static DataTypePtr createResultType()
|
||||
{
|
||||
DataTypes types
|
||||
{
|
||||
|
@ -20,28 +20,28 @@ private:
|
||||
|
||||
public:
|
||||
AggregateFunctionSimpleState(AggregateFunctionPtr nested_, const DataTypes & arguments_, const Array & params_)
|
||||
: IAggregateFunctionHelper<AggregateFunctionSimpleState>(arguments_, params_)
|
||||
: IAggregateFunctionHelper<AggregateFunctionSimpleState>(arguments_, params_, createResultType(nested_, params_))
|
||||
, nested_func(nested_)
|
||||
{
|
||||
}
|
||||
|
||||
String getName() const override { return nested_func->getName() + "SimpleState"; }
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
static DataTypePtr createResultType(const AggregateFunctionPtr & nested_, const Array & params_)
|
||||
{
|
||||
DataTypeCustomSimpleAggregateFunction::checkSupportedFunctions(nested_func);
|
||||
DataTypeCustomSimpleAggregateFunction::checkSupportedFunctions(nested_);
|
||||
|
||||
// Need to make a clone to avoid recursive reference.
|
||||
auto storage_type_out = DataTypeFactory::instance().get(nested_func->getReturnType()->getName());
|
||||
auto storage_type_out = DataTypeFactory::instance().get(nested_->getResultType()->getName());
|
||||
// Need to make a new function with promoted argument types because SimpleAggregates requires arg_type = return_type.
|
||||
AggregateFunctionProperties properties;
|
||||
auto function
|
||||
= AggregateFunctionFactory::instance().get(nested_func->getName(), {storage_type_out}, nested_func->getParameters(), properties);
|
||||
= AggregateFunctionFactory::instance().get(nested_->getName(), {storage_type_out}, nested_->getParameters(), properties);
|
||||
|
||||
// Need to make a clone because it'll be customized.
|
||||
auto storage_type_arg = DataTypeFactory::instance().get(nested_func->getReturnType()->getName());
|
||||
auto storage_type_arg = DataTypeFactory::instance().get(nested_->getResultType()->getName());
|
||||
DataTypeCustomNamePtr custom_name
|
||||
= std::make_unique<DataTypeCustomSimpleAggregateFunction>(function, DataTypes{nested_func->getReturnType()}, parameters);
|
||||
= std::make_unique<DataTypeCustomSimpleAggregateFunction>(function, DataTypes{nested_->getResultType()}, params_);
|
||||
storage_type_arg->setCustomization(std::make_unique<DataTypeCustomDesc>(std::move(custom_name), nullptr));
|
||||
return storage_type_arg;
|
||||
}
|
||||
|
@ -261,7 +261,7 @@ private:
|
||||
public:
|
||||
AggregateFunctionSparkbar(const DataTypes & arguments, const Array & params)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionSparkbarData<X, Y>, AggregateFunctionSparkbar>(
|
||||
arguments, params)
|
||||
arguments, params, std::make_shared<DataTypeString>())
|
||||
{
|
||||
width = params.at(0).safeGet<UInt64>();
|
||||
if (params.size() == 3)
|
||||
@ -283,11 +283,6 @@ public:
|
||||
return "sparkbar";
|
||||
}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
return std::make_shared<DataTypeString>();
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * /*arena*/) const override
|
||||
{
|
||||
X x = assert_cast<const ColumnVector<X> *>(columns[0])->getData()[row_num];
|
||||
|
@ -23,7 +23,7 @@ private:
|
||||
|
||||
public:
|
||||
AggregateFunctionState(AggregateFunctionPtr nested_, const DataTypes & arguments_, const Array & params_)
|
||||
: IAggregateFunctionHelper<AggregateFunctionState>(arguments_, params_)
|
||||
: IAggregateFunctionHelper<AggregateFunctionState>(arguments_, params_, createResultType())
|
||||
, nested_func(nested_)
|
||||
{}
|
||||
|
||||
@ -32,7 +32,7 @@ public:
|
||||
return nested_func->getName() + "State";
|
||||
}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
DataTypePtr createResultType() const
|
||||
{
|
||||
return getStateType();
|
||||
}
|
||||
|
@ -115,15 +115,11 @@ class AggregateFunctionVariance final
|
||||
{
|
||||
public:
|
||||
explicit AggregateFunctionVariance(const DataTypePtr & arg)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionVarianceData<T, Op>, AggregateFunctionVariance<T, Op>>({arg}, {}) {}
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionVarianceData<T, Op>, AggregateFunctionVariance<T, Op>>({arg}, {}, std::make_shared<DataTypeFloat64>())
|
||||
{}
|
||||
|
||||
String getName() const override { return Op::name; }
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
return std::make_shared<DataTypeFloat64>();
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
@ -368,15 +364,11 @@ class AggregateFunctionCovariance final
|
||||
public:
|
||||
explicit AggregateFunctionCovariance(const DataTypes & args) : IAggregateFunctionDataHelper<
|
||||
CovarianceData<T, U, Op, compute_marginal_moments>,
|
||||
AggregateFunctionCovariance<T, U, Op, compute_marginal_moments>>(args, {}) {}
|
||||
AggregateFunctionCovariance<T, U, Op, compute_marginal_moments>>(args, {}, std::make_shared<DataTypeFloat64>())
|
||||
{}
|
||||
|
||||
String getName() const override { return Op::name; }
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
return std::make_shared<DataTypeFloat64>();
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
|
@ -81,12 +81,12 @@ public:
|
||||
using ColVecResult = ColumnVector<ResultType>;
|
||||
|
||||
explicit AggregateFunctionVarianceSimple(const DataTypes & argument_types_)
|
||||
: IAggregateFunctionDataHelper<typename StatFunc::Data, AggregateFunctionVarianceSimple<StatFunc>>(argument_types_, {})
|
||||
: IAggregateFunctionDataHelper<typename StatFunc::Data, AggregateFunctionVarianceSimple<StatFunc>>(argument_types_, {}, std::make_shared<DataTypeNumber<ResultType>>())
|
||||
, src_scale(0)
|
||||
{}
|
||||
|
||||
AggregateFunctionVarianceSimple(const IDataType & data_type, const DataTypes & argument_types_)
|
||||
: IAggregateFunctionDataHelper<typename StatFunc::Data, AggregateFunctionVarianceSimple<StatFunc>>(argument_types_, {})
|
||||
: IAggregateFunctionDataHelper<typename StatFunc::Data, AggregateFunctionVarianceSimple<StatFunc>>(argument_types_, {}, std::make_shared<DataTypeNumber<ResultType>>())
|
||||
, src_scale(getDecimalScale(data_type))
|
||||
{}
|
||||
|
||||
@ -117,11 +117,6 @@ public:
|
||||
UNREACHABLE();
|
||||
}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
return std::make_shared<DataTypeNumber<ResultType>>();
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
|
@ -411,23 +411,21 @@ public:
|
||||
}
|
||||
|
||||
explicit AggregateFunctionSum(const DataTypes & argument_types_)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionSum<T, TResult, Data, Type>>(argument_types_, {})
|
||||
, scale(0)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionSum<T, TResult, Data, Type>>(argument_types_, {}, createResultType(0))
|
||||
{}
|
||||
|
||||
AggregateFunctionSum(const IDataType & data_type, const DataTypes & argument_types_)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionSum<T, TResult, Data, Type>>(argument_types_, {})
|
||||
, scale(getDecimalScale(data_type))
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionSum<T, TResult, Data, Type>>(argument_types_, {}, createResultType(getDecimalScale(data_type)))
|
||||
{}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
static DataTypePtr createResultType(UInt32 scale_)
|
||||
{
|
||||
if constexpr (!is_decimal<T>)
|
||||
return std::make_shared<DataTypeNumber<TResult>>();
|
||||
else
|
||||
{
|
||||
using DataType = DataTypeDecimal<TResult>;
|
||||
return std::make_shared<DataType>(DataType::maxPrecision(), scale);
|
||||
return std::make_shared<DataType>(DataType::maxPrecision(), scale_);
|
||||
}
|
||||
}
|
||||
|
||||
@ -548,7 +546,7 @@ public:
|
||||
for (const auto & argument_type : this->argument_types)
|
||||
can_be_compiled &= canBeNativeType(*argument_type);
|
||||
|
||||
auto return_type = getReturnType();
|
||||
auto return_type = this->getResultType();
|
||||
can_be_compiled &= canBeNativeType(*return_type);
|
||||
|
||||
return can_be_compiled;
|
||||
@ -558,7 +556,7 @@ public:
|
||||
{
|
||||
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
||||
|
||||
auto * return_type = toNativeType(b, getReturnType());
|
||||
auto * return_type = toNativeType(b, this->getResultType());
|
||||
auto * aggregate_sum_ptr = aggregate_data_ptr;
|
||||
|
||||
b.CreateStore(llvm::Constant::getNullValue(return_type), aggregate_sum_ptr);
|
||||
@ -568,7 +566,7 @@ public:
|
||||
{
|
||||
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
||||
|
||||
auto * return_type = toNativeType(b, getReturnType());
|
||||
auto * return_type = toNativeType(b, this->getResultType());
|
||||
|
||||
auto * sum_value_ptr = aggregate_data_ptr;
|
||||
auto * sum_value = b.CreateLoad(return_type, sum_value_ptr);
|
||||
@ -586,7 +584,7 @@ public:
|
||||
{
|
||||
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
||||
|
||||
auto * return_type = toNativeType(b, getReturnType());
|
||||
auto * return_type = toNativeType(b, this->getResultType());
|
||||
|
||||
auto * sum_value_dst_ptr = aggregate_data_dst_ptr;
|
||||
auto * sum_value_dst = b.CreateLoad(return_type, sum_value_dst_ptr);
|
||||
@ -602,7 +600,7 @@ public:
|
||||
{
|
||||
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
||||
|
||||
auto * return_type = toNativeType(b, getReturnType());
|
||||
auto * return_type = toNativeType(b, this->getResultType());
|
||||
auto * sum_value_ptr = aggregate_data_ptr;
|
||||
|
||||
return b.CreateLoad(return_type, sum_value_ptr);
|
||||
@ -611,8 +609,6 @@ public:
|
||||
#endif
|
||||
|
||||
private:
|
||||
UInt32 scale;
|
||||
|
||||
static constexpr auto & castColumnToResult(IColumn & to)
|
||||
{
|
||||
if constexpr (is_decimal<T>)
|
||||
|
@ -14,12 +14,13 @@ public:
|
||||
using Base = AggregateFunctionAvg<T>;
|
||||
|
||||
explicit AggregateFunctionSumCount(const DataTypes & argument_types_, UInt32 num_scale_ = 0)
|
||||
: Base(argument_types_, num_scale_), scale(num_scale_) {}
|
||||
: Base(argument_types_, createResultType(num_scale_), num_scale_)
|
||||
{}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
static DataTypePtr createResultType(UInt32 num_scale_)
|
||||
{
|
||||
auto second_elem = std::make_shared<DataTypeUInt64>();
|
||||
return std::make_shared<DataTypeTuple>(DataTypes{getReturnTypeFirstElement(), std::move(second_elem)});
|
||||
return std::make_shared<DataTypeTuple>(DataTypes{getReturnTypeFirstElement(num_scale_), std::move(second_elem)});
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const final
|
||||
@ -43,9 +44,7 @@ public:
|
||||
#endif
|
||||
|
||||
private:
|
||||
UInt32 scale;
|
||||
|
||||
auto getReturnTypeFirstElement() const
|
||||
static auto getReturnTypeFirstElement(UInt32 num_scale_)
|
||||
{
|
||||
using FieldType = AvgFieldType<T>;
|
||||
|
||||
@ -54,7 +53,7 @@ private:
|
||||
else
|
||||
{
|
||||
using DataType = DataTypeDecimal<FieldType>;
|
||||
return std::make_shared<DataType>(DataType::maxPrecision(), scale);
|
||||
return std::make_shared<DataType>(DataType::maxPrecision(), num_scale_);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -80,7 +80,7 @@ public:
|
||||
|
||||
AggregateFunctionMapBase(const DataTypePtr & keys_type_,
|
||||
const DataTypes & values_types_, const DataTypes & argument_types_)
|
||||
: Base(argument_types_, {} /* parameters */)
|
||||
: Base(argument_types_, {} /* parameters */, createResultType(keys_type_, values_types_, getName()))
|
||||
, keys_type(keys_type_)
|
||||
, keys_serialization(keys_type->getDefaultSerialization())
|
||||
, values_types(values_types_)
|
||||
@ -117,19 +117,22 @@ public:
|
||||
return 0;
|
||||
}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
static DataTypePtr createResultType(
|
||||
const DataTypePtr & keys_type_,
|
||||
const DataTypes & values_types_,
|
||||
const String & name_)
|
||||
{
|
||||
DataTypes types;
|
||||
types.emplace_back(std::make_shared<DataTypeArray>(keys_type));
|
||||
types.emplace_back(std::make_shared<DataTypeArray>(keys_type_));
|
||||
|
||||
for (const auto & value_type : values_types)
|
||||
for (const auto & value_type : values_types_)
|
||||
{
|
||||
if constexpr (std::is_same_v<Visitor, FieldVisitorSum>)
|
||||
{
|
||||
if (!value_type->isSummable())
|
||||
throw Exception{ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
|
||||
"Values for {} cannot be summed, passed type {}",
|
||||
getName(), value_type->getName()};
|
||||
name_, value_type->getName()};
|
||||
}
|
||||
|
||||
DataTypePtr result_type;
|
||||
@ -139,7 +142,7 @@ public:
|
||||
if (value_type->onlyNull())
|
||||
throw Exception{ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
|
||||
"Cannot calculate {} of type {}",
|
||||
getName(), value_type->getName()};
|
||||
name_, value_type->getName()};
|
||||
|
||||
// Overflow, meaning that the returned type is the same as
|
||||
// the input type. Nulls are skipped.
|
||||
@ -153,7 +156,7 @@ public:
|
||||
if (!value_type_without_nullable->canBePromoted())
|
||||
throw Exception{ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
|
||||
"Values for {} are expected to be Numeric, Float or Decimal, passed type {}",
|
||||
getName(), value_type->getName()};
|
||||
name_, value_type->getName()};
|
||||
|
||||
WhichDataType value_type_to_check(value_type_without_nullable);
|
||||
|
||||
|
@ -46,7 +46,7 @@ private:
|
||||
Float64 confidence_level;
|
||||
public:
|
||||
AggregateFunctionTTest(const DataTypes & arguments, const Array & params)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionTTest<Data>>({arguments}, params)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionTTest<Data>>({arguments}, params, createResultType(!params.empty()))
|
||||
{
|
||||
if (!params.empty())
|
||||
{
|
||||
@ -71,9 +71,9 @@ public:
|
||||
return Data::name;
|
||||
}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
static DataTypePtr createResultType(bool need_confidence_interval_)
|
||||
{
|
||||
if (need_confidence_interval)
|
||||
if (need_confidence_interval_)
|
||||
{
|
||||
DataTypes types
|
||||
{
|
||||
|
@ -31,15 +31,33 @@ namespace
|
||||
template <bool is_weighted>
|
||||
class AggregateFunctionTopKDate : public AggregateFunctionTopK<DataTypeDate::FieldType, is_weighted>
|
||||
{
|
||||
public:
|
||||
using AggregateFunctionTopK<DataTypeDate::FieldType, is_weighted>::AggregateFunctionTopK;
|
||||
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeArray>(std::make_shared<DataTypeDate>()); }
|
||||
|
||||
AggregateFunctionTopKDate(UInt64 threshold_, UInt64 load_factor, const DataTypes & argument_types_, const Array & params)
|
||||
: AggregateFunctionTopK<DataTypeDate::FieldType, is_weighted>(
|
||||
threshold_,
|
||||
load_factor,
|
||||
argument_types_,
|
||||
params,
|
||||
std::make_shared<DataTypeArray>(std::make_shared<DataTypeDate>()))
|
||||
{}
|
||||
};
|
||||
|
||||
template <bool is_weighted>
|
||||
class AggregateFunctionTopKDateTime : public AggregateFunctionTopK<DataTypeDateTime::FieldType, is_weighted>
|
||||
{
|
||||
public:
|
||||
using AggregateFunctionTopK<DataTypeDateTime::FieldType, is_weighted>::AggregateFunctionTopK;
|
||||
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeArray>(std::make_shared<DataTypeDateTime>()); }
|
||||
|
||||
AggregateFunctionTopKDateTime(UInt64 threshold_, UInt64 load_factor, const DataTypes & argument_types_, const Array & params)
|
||||
: AggregateFunctionTopK<DataTypeDateTime::FieldType, is_weighted>(
|
||||
threshold_,
|
||||
load_factor,
|
||||
argument_types_,
|
||||
params,
|
||||
std::make_shared<DataTypeArray>(std::make_shared<DataTypeDateTime>()))
|
||||
{}
|
||||
};
|
||||
|
||||
|
||||
|
@ -11,6 +11,7 @@
|
||||
|
||||
#include <Common/SpaceSaving.h>
|
||||
#include <Common/assert_cast.h>
|
||||
#include "DataTypes/Serializations/ISerialization.h"
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
|
||||
@ -40,14 +41,20 @@ protected:
|
||||
|
||||
public:
|
||||
AggregateFunctionTopK(UInt64 threshold_, UInt64 load_factor, const DataTypes & argument_types_, const Array & params)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionTopKData<T>, AggregateFunctionTopK<T, is_weighted>>(argument_types_, params)
|
||||
, threshold(threshold_), reserved(load_factor * threshold) {}
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionTopKData<T>, AggregateFunctionTopK<T, is_weighted>>(argument_types_, params, createResultType(argument_types_))
|
||||
, threshold(threshold_), reserved(load_factor * threshold)
|
||||
{}
|
||||
|
||||
AggregateFunctionTopK(UInt64 threshold_, UInt64 load_factor, const DataTypes & argument_types_, const Array & params, const DataTypePtr & result_type_)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionTopKData<T>, AggregateFunctionTopK<T, is_weighted>>(argument_types_, params, result_type_)
|
||||
, threshold(threshold_), reserved(load_factor * threshold)
|
||||
{}
|
||||
|
||||
String getName() const override { return is_weighted ? "topKWeighted" : "topK"; }
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
static DataTypePtr createResultType(const DataTypes & argument_types_)
|
||||
{
|
||||
return std::make_shared<DataTypeArray>(this->argument_types[0]);
|
||||
return std::make_shared<DataTypeArray>(argument_types_[0]);
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
@ -126,21 +133,20 @@ private:
|
||||
|
||||
UInt64 threshold;
|
||||
UInt64 reserved;
|
||||
DataTypePtr & input_data_type;
|
||||
|
||||
static void deserializeAndInsert(StringRef str, IColumn & data_to);
|
||||
|
||||
public:
|
||||
AggregateFunctionTopKGeneric(
|
||||
UInt64 threshold_, UInt64 load_factor, const DataTypes & argument_types_, const Array & params)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionTopKGenericData, AggregateFunctionTopKGeneric<is_plain_column, is_weighted>>(argument_types_, params)
|
||||
, threshold(threshold_), reserved(load_factor * threshold), input_data_type(this->argument_types[0]) {}
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionTopKGenericData, AggregateFunctionTopKGeneric<is_plain_column, is_weighted>>(argument_types_, params, createResultType(argument_types_))
|
||||
, threshold(threshold_), reserved(load_factor * threshold) {}
|
||||
|
||||
String getName() const override { return is_weighted ? "topKWeighted" : "topK"; }
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
static DataTypePtr createResultType(const DataTypes & argument_types_)
|
||||
{
|
||||
return std::make_shared<DataTypeArray>(input_data_type);
|
||||
return std::make_shared<DataTypeArray>(argument_types_[0]);
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override
|
||||
|
@ -358,17 +358,12 @@ private:
|
||||
|
||||
public:
|
||||
explicit AggregateFunctionUniq(const DataTypes & argument_types_)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionUniq<T, Data>>(argument_types_, {})
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionUniq<T, Data>>(argument_types_, {}, std::make_shared<DataTypeUInt64>())
|
||||
{
|
||||
}
|
||||
|
||||
String getName() const override { return Data::getName(); }
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
return std::make_shared<DataTypeUInt64>();
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
/// ALWAYS_INLINE is required to have better code layout for uniqHLL12 function
|
||||
@ -462,7 +457,7 @@ private:
|
||||
|
||||
public:
|
||||
explicit AggregateFunctionUniqVariadic(const DataTypes & arguments)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionUniqVariadic<Data>>(arguments, {})
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionUniqVariadic<Data>>(arguments, {}, std::make_shared<DataTypeUInt64>())
|
||||
{
|
||||
if (argument_is_tuple)
|
||||
num_args = typeid_cast<const DataTypeTuple &>(*arguments[0]).getElements().size();
|
||||
@ -472,11 +467,6 @@ public:
|
||||
|
||||
String getName() const override { return Data::getName(); }
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
return std::make_shared<DataTypeUInt64>();
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
|
@ -126,7 +126,8 @@ class AggregateFunctionUniqCombined final
|
||||
{
|
||||
public:
|
||||
AggregateFunctionUniqCombined(const DataTypes & argument_types_, const Array & params_)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionUniqCombinedData<T, K, HashValueType>, AggregateFunctionUniqCombined<T, K, HashValueType>>(argument_types_, params_) {}
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionUniqCombinedData<T, K, HashValueType>, AggregateFunctionUniqCombined<T, K, HashValueType>>(argument_types_, params_, std::make_shared<DataTypeUInt64>())
|
||||
{}
|
||||
|
||||
String getName() const override
|
||||
{
|
||||
@ -136,11 +137,6 @@ public:
|
||||
return "uniqCombined";
|
||||
}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
return std::make_shared<DataTypeUInt64>();
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
@ -192,7 +188,7 @@ private:
|
||||
public:
|
||||
explicit AggregateFunctionUniqCombinedVariadic(const DataTypes & arguments, const Array & params)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionUniqCombinedData<UInt64, K, HashValueType>,
|
||||
AggregateFunctionUniqCombinedVariadic<is_exact, argument_is_tuple, K, HashValueType>>(arguments, params)
|
||||
AggregateFunctionUniqCombinedVariadic<is_exact, argument_is_tuple, K, HashValueType>>(arguments, params, std::make_shared<DataTypeUInt64>())
|
||||
{
|
||||
if (argument_is_tuple)
|
||||
num_args = typeid_cast<const DataTypeTuple &>(*arguments[0]).getElements().size();
|
||||
@ -208,11 +204,6 @@ public:
|
||||
return "uniqCombined";
|
||||
}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
return std::make_shared<DataTypeUInt64>();
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
|
@ -174,7 +174,7 @@ private:
|
||||
|
||||
public:
|
||||
AggregateFunctionUniqUpTo(UInt8 threshold_, const DataTypes & argument_types_, const Array & params_)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionUniqUpToData<T>, AggregateFunctionUniqUpTo<T>>(argument_types_, params_)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionUniqUpToData<T>, AggregateFunctionUniqUpTo<T>>(argument_types_, params_, std::make_shared<DataTypeUInt64>())
|
||||
, threshold(threshold_)
|
||||
{
|
||||
}
|
||||
@ -186,11 +186,6 @@ public:
|
||||
|
||||
String getName() const override { return "uniqUpTo"; }
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
return std::make_shared<DataTypeUInt64>();
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
/// ALWAYS_INLINE is required to have better code layout for uniqUpTo function
|
||||
@ -235,7 +230,7 @@ private:
|
||||
|
||||
public:
|
||||
AggregateFunctionUniqUpToVariadic(const DataTypes & arguments, const Array & params, UInt8 threshold_)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionUniqUpToData<UInt64>, AggregateFunctionUniqUpToVariadic<is_exact, argument_is_tuple>>(arguments, params)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionUniqUpToData<UInt64>, AggregateFunctionUniqUpToVariadic<is_exact, argument_is_tuple>>(arguments, params, std::make_shared<DataTypeUInt64>())
|
||||
, threshold(threshold_)
|
||||
{
|
||||
if (argument_is_tuple)
|
||||
@ -251,11 +246,6 @@ public:
|
||||
|
||||
String getName() const override { return "uniqUpTo"; }
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
return std::make_shared<DataTypeUInt64>();
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
|
@ -221,7 +221,7 @@ public:
|
||||
}
|
||||
|
||||
AggregateFunctionWindowFunnel(const DataTypes & arguments, const Array & params)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionWindowFunnel<T, Data>>(arguments, params)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionWindowFunnel<T, Data>>(arguments, params, std::make_shared<DataTypeUInt8>())
|
||||
{
|
||||
events_size = arguments.size() - 1;
|
||||
window = params.at(0).safeGet<UInt64>();
|
||||
@ -245,11 +245,6 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
return std::make_shared<DataTypeUInt8>();
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, const size_t row_num, Arena *) const override
|
||||
|
@ -118,7 +118,7 @@ class AggregateFunctionCrossTab : public IAggregateFunctionDataHelper<Data, Aggr
|
||||
{
|
||||
public:
|
||||
explicit AggregateFunctionCrossTab(const DataTypes & arguments)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionCrossTab<Data>>({arguments}, {})
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionCrossTab<Data>>({arguments}, {}, createResultType())
|
||||
{
|
||||
}
|
||||
|
||||
@ -132,7 +132,7 @@ public:
|
||||
return false;
|
||||
}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
static DataTypePtr createResultType()
|
||||
{
|
||||
return std::make_shared<DataTypeNumber<Float64>>();
|
||||
}
|
||||
|
@ -10,6 +10,7 @@
|
||||
#include <base/types.h>
|
||||
#include <Common/Exception.h>
|
||||
#include <Common/ThreadPool.h>
|
||||
#include <Core/IResolvedFunction.h>
|
||||
|
||||
#include "config.h"
|
||||
|
||||
@ -48,7 +49,9 @@ using AggregateDataPtr = char *;
|
||||
using ConstAggregateDataPtr = const char *;
|
||||
|
||||
class IAggregateFunction;
|
||||
using AggregateFunctionPtr = std::shared_ptr<const IAggregateFunction>;
|
||||
using AggregateFunctionPtr = std::shared_ptr<IAggregateFunction>;
|
||||
using ConstAggregateFunctionPtr = std::shared_ptr<const IAggregateFunction>;
|
||||
|
||||
struct AggregateFunctionProperties;
|
||||
|
||||
/** Aggregate functions interface.
|
||||
@ -59,18 +62,18 @@ struct AggregateFunctionProperties;
|
||||
* (which can be created in some memory pool),
|
||||
* and IAggregateFunction is the external interface for manipulating them.
|
||||
*/
|
||||
class IAggregateFunction : public std::enable_shared_from_this<IAggregateFunction>
|
||||
class IAggregateFunction : public std::enable_shared_from_this<IAggregateFunction>, public IResolvedFunction
|
||||
{
|
||||
public:
|
||||
IAggregateFunction(const DataTypes & argument_types_, const Array & parameters_)
|
||||
: argument_types(argument_types_), parameters(parameters_) {}
|
||||
IAggregateFunction(const DataTypes & argument_types_, const Array & parameters_, const DataTypePtr & result_type_)
|
||||
: result_type(result_type_)
|
||||
, argument_types(argument_types_)
|
||||
, parameters(parameters_)
|
||||
{}
|
||||
|
||||
/// Get main function name.
|
||||
virtual String getName() const = 0;
|
||||
|
||||
/// Get the result type.
|
||||
virtual DataTypePtr getReturnType() const = 0;
|
||||
|
||||
/// Get the data type of internal state. By default it is AggregateFunction(name(params), argument_types...).
|
||||
virtual DataTypePtr getStateType() const;
|
||||
|
||||
@ -102,7 +105,7 @@ public:
|
||||
|
||||
virtual size_t getDefaultVersion() const { return 0; }
|
||||
|
||||
virtual ~IAggregateFunction() = default;
|
||||
~IAggregateFunction() override = default;
|
||||
|
||||
/** Data manipulating functions. */
|
||||
|
||||
@ -348,8 +351,9 @@ public:
|
||||
*/
|
||||
virtual AggregateFunctionPtr getNestedFunction() const { return {}; }
|
||||
|
||||
const DataTypes & getArgumentTypes() const { return argument_types; }
|
||||
const Array & getParameters() const { return parameters; }
|
||||
const DataTypePtr & getResultType() const override { return result_type; }
|
||||
const DataTypes & getArgumentTypes() const override { return argument_types; }
|
||||
const Array & getParameters() const override { return parameters; }
|
||||
|
||||
// Any aggregate function can be calculated over a window, but there are some
|
||||
// window functions such as rank() that require a different interface, e.g.
|
||||
@ -398,6 +402,7 @@ public:
|
||||
#endif
|
||||
|
||||
protected:
|
||||
DataTypePtr result_type;
|
||||
DataTypes argument_types;
|
||||
Array parameters;
|
||||
};
|
||||
@ -414,8 +419,8 @@ private:
|
||||
}
|
||||
|
||||
public:
|
||||
IAggregateFunctionHelper(const DataTypes & argument_types_, const Array & parameters_)
|
||||
: IAggregateFunction(argument_types_, parameters_) {}
|
||||
IAggregateFunctionHelper(const DataTypes & argument_types_, const Array & parameters_, const DataTypePtr & result_type_)
|
||||
: IAggregateFunction(argument_types_, parameters_, result_type_) {}
|
||||
|
||||
AddFunc getAddressOfAddFunction() const override { return &addFree; }
|
||||
|
||||
@ -695,15 +700,15 @@ public:
|
||||
// Derived class can `override` this to flag that DateTime64 is not supported.
|
||||
static constexpr bool DateTime64Supported = true;
|
||||
|
||||
IAggregateFunctionDataHelper(const DataTypes & argument_types_, const Array & parameters_)
|
||||
: IAggregateFunctionHelper<Derived>(argument_types_, parameters_)
|
||||
IAggregateFunctionDataHelper(const DataTypes & argument_types_, const Array & parameters_, const DataTypePtr & result_type_)
|
||||
: IAggregateFunctionHelper<Derived>(argument_types_, parameters_, result_type_)
|
||||
{
|
||||
/// To prevent derived classes changing the destroy() without updating hasTrivialDestructor() to match it
|
||||
/// Enforce that either both of them are changed or none are
|
||||
constexpr bool declares_destroy_and_hasTrivialDestructor =
|
||||
constexpr bool declares_destroy_and_has_trivial_destructor =
|
||||
std::is_same_v<decltype(&IAggregateFunctionDataHelper::destroy), decltype(&Derived::destroy)> ==
|
||||
std::is_same_v<decltype(&IAggregateFunctionDataHelper::hasTrivialDestructor), decltype(&Derived::hasTrivialDestructor)>;
|
||||
static_assert(declares_destroy_and_hasTrivialDestructor,
|
||||
static_assert(declares_destroy_and_has_trivial_destructor,
|
||||
"destroy() and hasTrivialDestructor() methods of an aggregate function must be either both overridden or not");
|
||||
}
|
||||
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
#include <Common/SipHash.h>
|
||||
#include <Common/FieldVisitorToString.h>
|
||||
#include "Core/ColumnsWithTypeAndName.h"
|
||||
|
||||
#include <IO/WriteBufferFromString.h>
|
||||
#include <IO/Operators.h>
|
||||
@ -25,25 +26,54 @@ FunctionNode::FunctionNode(String function_name_)
|
||||
children[arguments_child_index] = std::make_shared<ListNode>();
|
||||
}
|
||||
|
||||
void FunctionNode::resolveAsFunction(FunctionOverloadResolverPtr function_value, DataTypePtr result_type_value)
|
||||
ColumnsWithTypeAndName FunctionNode::getArgumentTypes() const
|
||||
{
|
||||
aggregate_function = nullptr;
|
||||
ColumnsWithTypeAndName argument_types;
|
||||
for (const auto & arg : getArguments().getNodes())
|
||||
{
|
||||
ColumnWithTypeAndName argument;
|
||||
argument.type = arg->getResultType();
|
||||
argument_types.push_back(argument);
|
||||
}
|
||||
return argument_types;
|
||||
}
|
||||
|
||||
|
||||
FunctionBasePtr FunctionNode::getFunction() const
|
||||
{
|
||||
return std::dynamic_pointer_cast<IFunctionBase>(function);
|
||||
}
|
||||
|
||||
AggregateFunctionPtr FunctionNode::getAggregateFunction() const
|
||||
{
|
||||
return std::dynamic_pointer_cast<IAggregateFunction>(function);
|
||||
}
|
||||
|
||||
bool FunctionNode::isAggregateFunction() const
|
||||
{
|
||||
return typeid_cast<AggregateFunctionPtr>(function) != nullptr && !isWindowFunction();
|
||||
}
|
||||
|
||||
bool FunctionNode::isOrdinaryFunction() const
|
||||
{
|
||||
return typeid_cast<FunctionBasePtr>(function) != nullptr;
|
||||
}
|
||||
|
||||
void FunctionNode::resolveAsFunction(FunctionBasePtr function_value)
|
||||
{
|
||||
function_name = function_value->getName();
|
||||
function = std::move(function_value);
|
||||
result_type = std::move(result_type_value);
|
||||
function_name = function->getName();
|
||||
}
|
||||
|
||||
void FunctionNode::resolveAsAggregateFunction(AggregateFunctionPtr aggregate_function_value, DataTypePtr result_type_value)
|
||||
void FunctionNode::resolveAsAggregateFunction(AggregateFunctionPtr aggregate_function_value)
|
||||
{
|
||||
function = nullptr;
|
||||
aggregate_function = std::move(aggregate_function_value);
|
||||
result_type = std::move(result_type_value);
|
||||
function_name = aggregate_function->getName();
|
||||
function_name = aggregate_function_value->getName();
|
||||
function = std::move(aggregate_function_value);
|
||||
}
|
||||
|
||||
void FunctionNode::resolveAsWindowFunction(AggregateFunctionPtr window_function_value, DataTypePtr result_type_value)
|
||||
void FunctionNode::resolveAsWindowFunction(AggregateFunctionPtr window_function_value)
|
||||
{
|
||||
resolveAsAggregateFunction(window_function_value, result_type_value);
|
||||
resolveAsAggregateFunction(window_function_value);
|
||||
}
|
||||
|
||||
void FunctionNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const
|
||||
@ -63,8 +93,8 @@ void FunctionNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state
|
||||
|
||||
buffer << ", function_type: " << function_type;
|
||||
|
||||
if (result_type)
|
||||
buffer << ", result_type: " + result_type->getName();
|
||||
if (function)
|
||||
buffer << ", result_type: " + function->getResultType()->getName();
|
||||
|
||||
const auto & parameters = getParameters();
|
||||
if (!parameters.getNodes().empty())
|
||||
@ -95,12 +125,14 @@ bool FunctionNode::isEqualImpl(const IQueryTreeNode & rhs) const
|
||||
isOrdinaryFunction() != rhs_typed.isOrdinaryFunction() ||
|
||||
isWindowFunction() != rhs_typed.isWindowFunction())
|
||||
return false;
|
||||
auto lhs_result_type = getResultType();
|
||||
auto rhs_result_type = rhs.getResultType();
|
||||
|
||||
if (result_type && rhs_typed.result_type && !result_type->equals(*rhs_typed.getResultType()))
|
||||
if (lhs_result_type && rhs_result_type && !lhs_result_type->equals(*rhs_result_type))
|
||||
return false;
|
||||
else if (result_type && !rhs_typed.result_type)
|
||||
else if (lhs_result_type && !rhs_result_type)
|
||||
return false;
|
||||
else if (!result_type && rhs_typed.result_type)
|
||||
else if (!lhs_result_type && rhs_result_type)
|
||||
return false;
|
||||
|
||||
return true;
|
||||
@ -114,7 +146,7 @@ void FunctionNode::updateTreeHashImpl(HashState & hash_state) const
|
||||
hash_state.update(isAggregateFunction());
|
||||
hash_state.update(isWindowFunction());
|
||||
|
||||
if (result_type)
|
||||
if (auto result_type = getResultType())
|
||||
{
|
||||
auto result_type_name = result_type->getName();
|
||||
hash_state.update(result_type_name.size());
|
||||
@ -130,8 +162,6 @@ QueryTreeNodePtr FunctionNode::cloneImpl() const
|
||||
* because ordinary functions or aggregate functions must be stateless.
|
||||
*/
|
||||
result_function->function = function;
|
||||
result_function->aggregate_function = aggregate_function;
|
||||
result_function->result_type = result_type;
|
||||
|
||||
return result_function;
|
||||
}
|
||||
|
@ -1,8 +1,12 @@
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <Core/IResolvedFunction.h>
|
||||
#include <Analyzer/IQueryTreeNode.h>
|
||||
#include <Analyzer/ListNode.h>
|
||||
#include <Analyzer/ConstantValue.h>
|
||||
#include <Common/typeid_cast.h>
|
||||
#include "Core/ColumnsWithTypeAndName.h"
|
||||
|
||||
namespace DB
|
||||
{
|
||||
@ -10,8 +14,11 @@ namespace DB
|
||||
class IFunctionOverloadResolver;
|
||||
using FunctionOverloadResolverPtr = std::shared_ptr<IFunctionOverloadResolver>;
|
||||
|
||||
class IFunctionBase;
|
||||
using FunctionBasePtr = std::shared_ptr<IFunctionBase>;
|
||||
|
||||
class IAggregateFunction;
|
||||
using AggregateFunctionPtr = std::shared_ptr<const IAggregateFunction>;
|
||||
using AggregateFunctionPtr = std::shared_ptr<IAggregateFunction>;
|
||||
|
||||
/** Function node represents function in query tree.
|
||||
* Function syntax: function_name(parameter_1, ...)(argument_1, ...).
|
||||
@ -96,6 +103,8 @@ public:
|
||||
return children[arguments_child_index];
|
||||
}
|
||||
|
||||
ColumnsWithTypeAndName getArgumentTypes() const;
|
||||
|
||||
/// Returns true if function node has window, false otherwise
|
||||
bool hasWindow() const
|
||||
{
|
||||
@ -124,24 +133,18 @@ public:
|
||||
/** Get non aggregate function.
|
||||
* If function is not resolved nullptr returned.
|
||||
*/
|
||||
const FunctionOverloadResolverPtr & getFunction() const
|
||||
{
|
||||
return function;
|
||||
}
|
||||
FunctionBasePtr getFunction() const;
|
||||
|
||||
/** Get aggregate function.
|
||||
* If function is not resolved nullptr returned.
|
||||
* If function is resolved as non aggregate function nullptr returned.
|
||||
*/
|
||||
const AggregateFunctionPtr & getAggregateFunction() const
|
||||
{
|
||||
return aggregate_function;
|
||||
}
|
||||
AggregateFunctionPtr getAggregateFunction() const;
|
||||
|
||||
/// Is function node resolved
|
||||
bool isResolved() const
|
||||
{
|
||||
return result_type != nullptr && (function != nullptr || aggregate_function != nullptr);
|
||||
return function != nullptr;
|
||||
}
|
||||
|
||||
/// Is function node window function
|
||||
@ -151,16 +154,10 @@ public:
|
||||
}
|
||||
|
||||
/// Is function node aggregate function
|
||||
bool isAggregateFunction() const
|
||||
{
|
||||
return aggregate_function != nullptr && !isWindowFunction();
|
||||
}
|
||||
bool isAggregateFunction() const;
|
||||
|
||||
/// Is function node ordinary function
|
||||
bool isOrdinaryFunction() const
|
||||
{
|
||||
return function != nullptr;
|
||||
}
|
||||
bool isOrdinaryFunction() const;
|
||||
|
||||
/** Resolve function node as non aggregate function.
|
||||
* It is important that function name is updated with resolved function name.
|
||||
@ -168,19 +165,19 @@ public:
|
||||
* Assume we have `multiIf` function with single condition, it can be converted to `if` function.
|
||||
* Function name must be updated accordingly.
|
||||
*/
|
||||
void resolveAsFunction(FunctionOverloadResolverPtr function_value, DataTypePtr result_type_value);
|
||||
void resolveAsFunction(FunctionBasePtr function_value);
|
||||
|
||||
/** Resolve function node as aggregate function.
|
||||
* It is important that function name is updated with resolved function name.
|
||||
* Main motivation for this is query tree optimizations.
|
||||
*/
|
||||
void resolveAsAggregateFunction(AggregateFunctionPtr aggregate_function_value, DataTypePtr result_type_value);
|
||||
void resolveAsAggregateFunction(AggregateFunctionPtr aggregate_function_value);
|
||||
|
||||
/** Resolve function node as window function.
|
||||
* It is important that function name is updated with resolved function name.
|
||||
* Main motivation for this is query tree optimizations.
|
||||
*/
|
||||
void resolveAsWindowFunction(AggregateFunctionPtr window_function_value, DataTypePtr result_type_value);
|
||||
void resolveAsWindowFunction(AggregateFunctionPtr window_function_value);
|
||||
|
||||
QueryTreeNodeType getNodeType() const override
|
||||
{
|
||||
@ -189,7 +186,7 @@ public:
|
||||
|
||||
DataTypePtr getResultType() const override
|
||||
{
|
||||
return result_type;
|
||||
return function->getResultType();
|
||||
}
|
||||
|
||||
void dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const override;
|
||||
@ -205,9 +202,7 @@ protected:
|
||||
|
||||
private:
|
||||
String function_name;
|
||||
FunctionOverloadResolverPtr function;
|
||||
AggregateFunctionPtr aggregate_function;
|
||||
DataTypePtr result_type;
|
||||
IResolvedFunctionPtr function;
|
||||
|
||||
static constexpr size_t parameters_child_index = 0;
|
||||
static constexpr size_t arguments_child_index = 1;
|
||||
|
@ -147,7 +147,6 @@ public:
|
||||
private:
|
||||
static inline void resolveAggregateFunctionNode(FunctionNode & function_node, const String & aggregate_function_name)
|
||||
{
|
||||
auto function_result_type = function_node.getResultType();
|
||||
auto function_aggregate_function = function_node.getAggregateFunction();
|
||||
|
||||
AggregateFunctionProperties properties;
|
||||
@ -156,7 +155,7 @@ private:
|
||||
function_aggregate_function->getParameters(),
|
||||
properties);
|
||||
|
||||
function_node.resolveAsAggregateFunction(std::move(aggregate_function), std::move(function_result_type));
|
||||
function_node.resolveAsAggregateFunction(std::move(aggregate_function));
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -69,7 +69,7 @@ public:
|
||||
auto result_type = function_node->getResultType();
|
||||
AggregateFunctionProperties properties;
|
||||
auto aggregate_function = AggregateFunctionFactory::instance().get("count", {}, {}, properties);
|
||||
function_node->resolveAsAggregateFunction(std::move(aggregate_function), std::move(result_type));
|
||||
function_node->resolveAsAggregateFunction(std::move(aggregate_function));
|
||||
function_node->getArguments().getNodes().clear();
|
||||
}
|
||||
};
|
||||
|
@ -9,6 +9,7 @@
|
||||
|
||||
#include <Analyzer/InDepthQueryTreeVisitor.h>
|
||||
#include <Analyzer/FunctionNode.h>
|
||||
#include "Core/ColumnWithTypeAndName.h"
|
||||
|
||||
namespace DB
|
||||
{
|
||||
@ -138,7 +139,6 @@ public:
|
||||
|
||||
static inline void resolveAggregateOrWindowFunctionNode(FunctionNode & function_node, const String & aggregate_function_name)
|
||||
{
|
||||
auto function_result_type = function_node.getResultType();
|
||||
auto function_aggregate_function = function_node.getAggregateFunction();
|
||||
|
||||
AggregateFunctionProperties properties;
|
||||
@ -148,16 +148,15 @@ public:
|
||||
properties);
|
||||
|
||||
if (function_node.isAggregateFunction())
|
||||
function_node.resolveAsAggregateFunction(std::move(aggregate_function), std::move(function_result_type));
|
||||
function_node.resolveAsAggregateFunction(std::move(aggregate_function));
|
||||
else if (function_node.isWindowFunction())
|
||||
function_node.resolveAsWindowFunction(std::move(aggregate_function), std::move(function_result_type));
|
||||
function_node.resolveAsWindowFunction(std::move(aggregate_function));
|
||||
}
|
||||
|
||||
inline void resolveOrdinaryFunctionNode(FunctionNode & function_node, const String & function_name) const
|
||||
{
|
||||
auto function_result_type = function_node.getResultType();
|
||||
auto function = FunctionFactory::instance().get(function_name, context);
|
||||
function_node.resolveAsFunction(function, std::move(function_result_type));
|
||||
function_node.resolveAsFunction(function->build(function_node.getArgumentTypes()));
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -78,11 +78,11 @@ public:
|
||||
column.name += ".size0";
|
||||
column.type = std::make_shared<DataTypeUInt64>();
|
||||
|
||||
resolveOrdinaryFunctionNode(*function_node, "equals");
|
||||
|
||||
function_arguments_nodes.clear();
|
||||
function_arguments_nodes.push_back(std::make_shared<ColumnNode>(column, column_source));
|
||||
function_arguments_nodes.push_back(std::make_shared<ConstantNode>(static_cast<UInt64>(0)));
|
||||
|
||||
resolveOrdinaryFunctionNode(*function_node, "equals");
|
||||
}
|
||||
else if (function_name == "notEmpty")
|
||||
{
|
||||
@ -90,11 +90,11 @@ public:
|
||||
column.name += ".size0";
|
||||
column.type = std::make_shared<DataTypeUInt64>();
|
||||
|
||||
resolveOrdinaryFunctionNode(*function_node, "notEquals");
|
||||
|
||||
function_arguments_nodes.clear();
|
||||
function_arguments_nodes.push_back(std::make_shared<ColumnNode>(column, column_source));
|
||||
function_arguments_nodes.push_back(std::make_shared<ConstantNode>(static_cast<UInt64>(0)));
|
||||
|
||||
resolveOrdinaryFunctionNode(*function_node, "notEquals");
|
||||
}
|
||||
}
|
||||
else if (column_type.isNullable())
|
||||
@ -112,9 +112,9 @@ public:
|
||||
column.name += ".null";
|
||||
column.type = std::make_shared<DataTypeUInt8>();
|
||||
|
||||
resolveOrdinaryFunctionNode(*function_node, "not");
|
||||
|
||||
function_arguments_nodes = {std::make_shared<ColumnNode>(column, column_source)};
|
||||
|
||||
resolveOrdinaryFunctionNode(*function_node, "not");
|
||||
}
|
||||
}
|
||||
else if (column_type.isMap())
|
||||
@ -182,9 +182,9 @@ public:
|
||||
column.type = data_type_map.getKeyType();
|
||||
|
||||
auto has_function_argument = std::make_shared<ColumnNode>(column, column_source);
|
||||
resolveOrdinaryFunctionNode(*function_node, "has");
|
||||
|
||||
function_arguments_nodes[0] = std::move(has_function_argument);
|
||||
|
||||
resolveOrdinaryFunctionNode(*function_node, "has");
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -192,9 +192,8 @@ public:
|
||||
private:
|
||||
inline void resolveOrdinaryFunctionNode(FunctionNode & function_node, const String & function_name) const
|
||||
{
|
||||
auto function_result_type = function_node.getResultType();
|
||||
auto function = FunctionFactory::instance().get(function_name, context);
|
||||
function_node.resolveAsFunction(function, std::move(function_result_type));
|
||||
function_node.resolveAsFunction(function->build(function_node.getArgumentTypes()));
|
||||
}
|
||||
|
||||
ContextPtr & context;
|
||||
|
@ -59,14 +59,13 @@ private:
|
||||
std::unordered_set<String> names_to_collect;
|
||||
};
|
||||
|
||||
QueryTreeNodePtr createResolvedFunction(const ContextPtr & context, const String & name, const DataTypePtr & result_type, QueryTreeNodes arguments)
|
||||
QueryTreeNodePtr createResolvedFunction(const ContextPtr & context, const String & name, QueryTreeNodes arguments)
|
||||
{
|
||||
auto function_node = std::make_shared<FunctionNode>(name);
|
||||
|
||||
auto function = FunctionFactory::instance().get(name, context);
|
||||
function_node->resolveAsFunction(std::move(function), result_type);
|
||||
function_node->getArguments().getNodes() = std::move(arguments);
|
||||
|
||||
function_node->resolveAsFunction(function->build(function_node->getArgumentTypes()));
|
||||
return function_node;
|
||||
}
|
||||
|
||||
@ -74,11 +73,6 @@ FunctionNodePtr createResolvedAggregateFunction(const String & name, const Query
|
||||
{
|
||||
auto function_node = std::make_shared<FunctionNode>(name);
|
||||
|
||||
AggregateFunctionProperties properties;
|
||||
auto aggregate_function = AggregateFunctionFactory::instance().get(name, {argument->getResultType()}, parameters, properties);
|
||||
function_node->resolveAsAggregateFunction(aggregate_function, aggregate_function->getReturnType());
|
||||
function_node->getArguments().getNodes() = { argument };
|
||||
|
||||
if (!parameters.empty())
|
||||
{
|
||||
QueryTreeNodes parameter_nodes;
|
||||
@ -86,18 +80,27 @@ FunctionNodePtr createResolvedAggregateFunction(const String & name, const Query
|
||||
parameter_nodes.emplace_back(std::make_shared<ConstantNode>(param));
|
||||
function_node->getParameters().getNodes() = std::move(parameter_nodes);
|
||||
}
|
||||
function_node->getArguments().getNodes() = { argument };
|
||||
|
||||
AggregateFunctionProperties properties;
|
||||
auto aggregate_function = AggregateFunctionFactory::instance().get(
|
||||
name,
|
||||
{ argument->getResultType() },
|
||||
parameters,
|
||||
properties);
|
||||
function_node->resolveAsAggregateFunction(aggregate_function);
|
||||
|
||||
return function_node;
|
||||
}
|
||||
|
||||
QueryTreeNodePtr createTupleElementFunction(const ContextPtr & context, const DataTypePtr & result_type, QueryTreeNodePtr argument, UInt64 index)
|
||||
QueryTreeNodePtr createTupleElementFunction(const ContextPtr & context, QueryTreeNodePtr argument, UInt64 index)
|
||||
{
|
||||
return createResolvedFunction(context, "tupleElement", result_type, {std::move(argument), std::make_shared<ConstantNode>(index)});
|
||||
return createResolvedFunction(context, "tupleElement", {argument, std::make_shared<ConstantNode>(index)});
|
||||
}
|
||||
|
||||
QueryTreeNodePtr createArrayElementFunction(const ContextPtr & context, const DataTypePtr & result_type, QueryTreeNodePtr argument, UInt64 index)
|
||||
QueryTreeNodePtr createArrayElementFunction(const ContextPtr & context, QueryTreeNodePtr argument, UInt64 index)
|
||||
{
|
||||
return createResolvedFunction(context, "arrayElement", result_type, {std::move(argument), std::make_shared<ConstantNode>(index)});
|
||||
return createResolvedFunction(context, "arrayElement", {argument, std::make_shared<ConstantNode>(index)});
|
||||
}
|
||||
|
||||
void replaceWithSumCount(QueryTreeNodePtr & node, const FunctionNodePtr & sum_count_node, ContextPtr context)
|
||||
@ -115,20 +118,20 @@ void replaceWithSumCount(QueryTreeNodePtr & node, const FunctionNodePtr & sum_co
|
||||
if (function_name == "sum")
|
||||
{
|
||||
assert(node->getResultType()->equals(*sum_count_result_type->getElement(0)));
|
||||
node = createTupleElementFunction(context, node->getResultType(), sum_count_node, 1);
|
||||
node = createTupleElementFunction(context, sum_count_node, 1);
|
||||
}
|
||||
else if (function_name == "count")
|
||||
{
|
||||
assert(node->getResultType()->equals(*sum_count_result_type->getElement(1)));
|
||||
node = createTupleElementFunction(context, node->getResultType(), sum_count_node, 2);
|
||||
node = createTupleElementFunction(context, sum_count_node, 2);
|
||||
}
|
||||
else if (function_name == "avg")
|
||||
{
|
||||
auto sum_result = createTupleElementFunction(context, sum_count_result_type->getElement(0), sum_count_node, 1);
|
||||
auto count_result = createTupleElementFunction(context, sum_count_result_type->getElement(1), sum_count_node, 2);
|
||||
auto sum_result = createTupleElementFunction(context, sum_count_node, 1);
|
||||
auto count_result = createTupleElementFunction(context, sum_count_node, 2);
|
||||
/// To avoid integer division by zero
|
||||
auto count_float_result = createResolvedFunction(context, "toFloat64", std::make_shared<DataTypeFloat64>(), {count_result});
|
||||
node = createResolvedFunction(context, "divide", node->getResultType(), {sum_result, count_float_result});
|
||||
auto count_float_result = createResolvedFunction(context, "toFloat64", {count_result});
|
||||
node = createResolvedFunction(context, "divide", {sum_result, count_float_result});
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -238,7 +241,7 @@ void tryFuseQuantiles(QueryTreeNodePtr query_tree_node, ContextPtr context)
|
||||
for (size_t i = 0; i < nodes_set.size(); ++i)
|
||||
{
|
||||
size_t array_index = i + 1;
|
||||
*nodes[i] = createArrayElementFunction(context, result_array_type->getNestedType(), quantiles_node, array_index);
|
||||
*nodes[i] = createArrayElementFunction(context, quantiles_node, array_index);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -55,8 +55,8 @@ public:
|
||||
return;
|
||||
|
||||
auto multi_if_function = std::make_shared<FunctionNode>("multiIf");
|
||||
multi_if_function->resolveAsFunction(multi_if_function_ptr, std::make_shared<DataTypeUInt8>());
|
||||
multi_if_function->getArguments().getNodes() = std::move(multi_if_arguments);
|
||||
multi_if_function->resolveAsFunction(multi_if_function_ptr->build(multi_if_function->getArgumentTypes()));
|
||||
node = std::move(multi_if_function);
|
||||
}
|
||||
|
||||
|
@ -27,7 +27,7 @@ public:
|
||||
return;
|
||||
|
||||
auto result_type = function_node->getResultType();
|
||||
function_node->resolveAsFunction(if_function_ptr, std::move(result_type));
|
||||
function_node->resolveAsFunction(if_function_ptr->build(function_node->getArgumentTypes()));
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -48,12 +48,10 @@ public:
|
||||
private:
|
||||
static inline void resolveAsCountAggregateFunction(FunctionNode & function_node)
|
||||
{
|
||||
auto function_result_type = function_node.getResultType();
|
||||
|
||||
AggregateFunctionProperties properties;
|
||||
auto aggregate_function = AggregateFunctionFactory::instance().get("count", {}, {}, properties);
|
||||
|
||||
function_node.resolveAsAggregateFunction(std::move(aggregate_function), std::move(function_result_type));
|
||||
function_node.resolveAsAggregateFunction(std::move(aggregate_function));
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -4287,7 +4287,7 @@ ProjectionNames QueryAnalyzer::resolveFunction(QueryTreeNodePtr & node, Identifi
|
||||
bool force_grouping_standard_compatibility = scope.context->getSettingsRef().force_grouping_standard_compatibility;
|
||||
auto grouping_function = std::make_shared<FunctionGrouping>(force_grouping_standard_compatibility);
|
||||
auto grouping_function_adaptor = std::make_shared<FunctionToOverloadResolverAdaptor>(std::move(grouping_function));
|
||||
function_node.resolveAsFunction(std::move(grouping_function_adaptor), std::make_shared<DataTypeUInt64>());
|
||||
function_node.resolveAsFunction(grouping_function_adaptor->build({}));
|
||||
return result_projection_names;
|
||||
}
|
||||
}
|
||||
@ -4307,7 +4307,7 @@ ProjectionNames QueryAnalyzer::resolveFunction(QueryTreeNodePtr & node, Identifi
|
||||
AggregateFunctionProperties properties;
|
||||
auto aggregate_function = AggregateFunctionFactory::instance().get(function_name, argument_types, parameters, properties);
|
||||
|
||||
function_node.resolveAsWindowFunction(aggregate_function, aggregate_function->getReturnType());
|
||||
function_node.resolveAsWindowFunction(aggregate_function);
|
||||
|
||||
bool window_node_is_identifier = function_node.getWindowNode()->getNodeType() == QueryTreeNodeType::IDENTIFIER;
|
||||
ProjectionName window_projection_name = resolveWindow(function_node.getWindowNode(), scope);
|
||||
@ -4361,7 +4361,7 @@ ProjectionNames QueryAnalyzer::resolveFunction(QueryTreeNodePtr & node, Identifi
|
||||
|
||||
AggregateFunctionProperties properties;
|
||||
auto aggregate_function = AggregateFunctionFactory::instance().get(function_name, argument_types, parameters, properties);
|
||||
function_node.resolveAsAggregateFunction(aggregate_function, aggregate_function->getReturnType());
|
||||
function_node.resolveAsAggregateFunction(aggregate_function);
|
||||
return result_projection_names;
|
||||
}
|
||||
|
||||
@ -4538,6 +4538,8 @@ ProjectionNames QueryAnalyzer::resolveFunction(QueryTreeNodePtr & node, Identifi
|
||||
constant_value = std::make_shared<ConstantValue>(std::move(column_constant_value), result_type);
|
||||
}
|
||||
}
|
||||
|
||||
function_node.resolveAsFunction(std::move(function_base));
|
||||
}
|
||||
catch (Exception & e)
|
||||
{
|
||||
@ -4545,8 +4547,6 @@ ProjectionNames QueryAnalyzer::resolveFunction(QueryTreeNodePtr & node, Identifi
|
||||
throw;
|
||||
}
|
||||
|
||||
function_node.resolveAsFunction(std::move(function), std::move(result_type));
|
||||
|
||||
if (constant_value)
|
||||
node = std::make_shared<ConstantNode>(std::move(constant_value), node);
|
||||
|
||||
|
@ -13,6 +13,7 @@
|
||||
#include <Analyzer/InDepthQueryTreeVisitor.h>
|
||||
#include <Analyzer/ConstantNode.h>
|
||||
#include <Analyzer/FunctionNode.h>
|
||||
#include "Core/ColumnsWithTypeAndName.h"
|
||||
|
||||
namespace DB
|
||||
{
|
||||
@ -117,11 +118,12 @@ public:
|
||||
not_function_result_type = makeNullable(not_function_result_type);
|
||||
|
||||
auto not_function = std::make_shared<FunctionNode>("not");
|
||||
not_function->resolveAsFunction(FunctionFactory::instance().get("not", context), std::move(not_function_result_type));
|
||||
|
||||
auto & not_function_arguments = not_function->getArguments().getNodes();
|
||||
not_function_arguments.push_back(std::move(nested_if_function_arguments_nodes[0]));
|
||||
|
||||
not_function->resolveAsFunction(FunctionFactory::instance().get("not", context)->build(not_function->getArgumentTypes()));
|
||||
|
||||
function_node_arguments_nodes[0] = std::move(not_function);
|
||||
function_node_arguments_nodes.resize(1);
|
||||
|
||||
@ -139,8 +141,7 @@ private:
|
||||
function_node.getAggregateFunction()->getParameters(),
|
||||
properties);
|
||||
|
||||
auto function_result_type = function_node.getResultType();
|
||||
function_node.resolveAsAggregateFunction(std::move(aggregate_function), std::move(function_result_type));
|
||||
function_node.resolveAsAggregateFunction(std::move(aggregate_function));
|
||||
}
|
||||
|
||||
ContextPtr & context;
|
||||
|
@ -76,7 +76,7 @@ public:
|
||||
properties);
|
||||
|
||||
auto function_result_type = function_node->getResultType();
|
||||
function_node->resolveAsAggregateFunction(std::move(aggregate_function), std::move(function_result_type));
|
||||
function_node->resolveAsAggregateFunction(std::move(aggregate_function));
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -28,7 +28,7 @@ namespace ErrorCodes
|
||||
}
|
||||
|
||||
|
||||
static String getTypeString(const AggregateFunctionPtr & func, std::optional<size_t> version = std::nullopt)
|
||||
static String getTypeString(const ConstAggregateFunctionPtr & func, std::optional<size_t> version = std::nullopt)
|
||||
{
|
||||
WriteBufferFromOwnString stream;
|
||||
|
||||
@ -62,18 +62,18 @@ static String getTypeString(const AggregateFunctionPtr & func, std::optional<siz
|
||||
}
|
||||
|
||||
|
||||
ColumnAggregateFunction::ColumnAggregateFunction(const AggregateFunctionPtr & func_, std::optional<size_t> version_)
|
||||
ColumnAggregateFunction::ColumnAggregateFunction(const ConstAggregateFunctionPtr & func_, std::optional<size_t> version_)
|
||||
: func(func_), type_string(getTypeString(func, version_)), version(version_)
|
||||
{
|
||||
}
|
||||
|
||||
ColumnAggregateFunction::ColumnAggregateFunction(const AggregateFunctionPtr & func_, const ConstArenas & arenas_)
|
||||
ColumnAggregateFunction::ColumnAggregateFunction(const ConstAggregateFunctionPtr & func_, const ConstArenas & arenas_)
|
||||
: foreign_arenas(arenas_), func(func_), type_string(getTypeString(func))
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
void ColumnAggregateFunction::set(const AggregateFunctionPtr & func_, size_t version_)
|
||||
void ColumnAggregateFunction::set(const ConstAggregateFunctionPtr & func_, size_t version_)
|
||||
{
|
||||
func = func_;
|
||||
version = version_;
|
||||
@ -146,7 +146,7 @@ MutableColumnPtr ColumnAggregateFunction::convertToValues(MutableColumnPtr colum
|
||||
/// insertResultInto may invalidate states, so we must unshare ownership of them
|
||||
column_aggregate_func.ensureOwnership();
|
||||
|
||||
MutableColumnPtr res = func->getReturnType()->createColumn();
|
||||
MutableColumnPtr res = func->getResultType()->createColumn();
|
||||
res->reserve(data.size());
|
||||
|
||||
/// If there are references to states in final column, we must hold their ownership
|
||||
|
@ -70,7 +70,7 @@ private:
|
||||
ArenaPtr my_arena;
|
||||
|
||||
/// Used for destroying states and for finalization of values.
|
||||
AggregateFunctionPtr func;
|
||||
ConstAggregateFunctionPtr func;
|
||||
|
||||
/// Source column. Used (holds source from destruction),
|
||||
/// if this column has been constructed from another and uses all or part of its values.
|
||||
@ -92,9 +92,9 @@ private:
|
||||
/// Create a new column that has another column as a source.
|
||||
MutablePtr createView() const;
|
||||
|
||||
explicit ColumnAggregateFunction(const AggregateFunctionPtr & func_, std::optional<size_t> version_ = std::nullopt);
|
||||
explicit ColumnAggregateFunction(const ConstAggregateFunctionPtr & func_, std::optional<size_t> version_ = std::nullopt);
|
||||
|
||||
ColumnAggregateFunction(const AggregateFunctionPtr & func_, const ConstArenas & arenas_);
|
||||
ColumnAggregateFunction(const ConstAggregateFunctionPtr & func_, const ConstArenas & arenas_);
|
||||
|
||||
ColumnAggregateFunction(const ColumnAggregateFunction & src_);
|
||||
|
||||
@ -103,10 +103,10 @@ private:
|
||||
public:
|
||||
~ColumnAggregateFunction() override;
|
||||
|
||||
void set(const AggregateFunctionPtr & func_, size_t version_);
|
||||
void set(const ConstAggregateFunctionPtr & func_, size_t version_);
|
||||
|
||||
AggregateFunctionPtr getAggregateFunction() { return func; }
|
||||
AggregateFunctionPtr getAggregateFunction() const { return func; }
|
||||
ConstAggregateFunctionPtr getAggregateFunction() { return func; }
|
||||
ConstAggregateFunctionPtr getAggregateFunction() const { return func; }
|
||||
|
||||
/// If we have another column as a source (owner of data), copy all data to ourself and reset source.
|
||||
/// This is needed before inserting new elements, because we must own these elements (to destroy them in destructor),
|
||||
|
29
src/Core/IResolvedFunction.h
Normal file
29
src/Core/IResolvedFunction.h
Normal file
@ -0,0 +1,29 @@
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
class IDataType;
|
||||
|
||||
using DataTypePtr = std::shared_ptr<const IDataType>;
|
||||
using DataTypes = std::vector<DataTypePtr>;
|
||||
|
||||
struct Array;
|
||||
|
||||
class IResolvedFunction
|
||||
{
|
||||
public:
|
||||
virtual const DataTypePtr & getResultType() const = 0;
|
||||
|
||||
virtual const DataTypes & getArgumentTypes() const = 0;
|
||||
|
||||
virtual const Array & getParameters() const = 0;
|
||||
|
||||
virtual ~IResolvedFunction() = default;
|
||||
};
|
||||
|
||||
using IResolvedFunctionPtr = std::shared_ptr<IResolvedFunction>;
|
||||
|
||||
}
|
@ -19,7 +19,7 @@ namespace DB
|
||||
class DataTypeAggregateFunction final : public IDataType
|
||||
{
|
||||
private:
|
||||
AggregateFunctionPtr function;
|
||||
ConstAggregateFunctionPtr function;
|
||||
DataTypes argument_types;
|
||||
Array parameters;
|
||||
mutable std::optional<size_t> version;
|
||||
@ -30,9 +30,9 @@ private:
|
||||
public:
|
||||
static constexpr bool is_parametric = true;
|
||||
|
||||
DataTypeAggregateFunction(const AggregateFunctionPtr & function_, const DataTypes & argument_types_,
|
||||
DataTypeAggregateFunction(ConstAggregateFunctionPtr function_, const DataTypes & argument_types_,
|
||||
const Array & parameters_, std::optional<size_t> version_ = std::nullopt)
|
||||
: function(function_)
|
||||
: function(std::move(function_))
|
||||
, argument_types(argument_types_)
|
||||
, parameters(parameters_)
|
||||
, version(version_)
|
||||
@ -40,7 +40,7 @@ public:
|
||||
}
|
||||
|
||||
String getFunctionName() const { return function->getName(); }
|
||||
AggregateFunctionPtr getFunction() const { return function; }
|
||||
ConstAggregateFunctionPtr getFunction() const { return function; }
|
||||
|
||||
String doGetName() const override;
|
||||
String getNameWithoutVersion() const;
|
||||
@ -51,7 +51,7 @@ public:
|
||||
|
||||
bool canBeInsideNullable() const override { return false; }
|
||||
|
||||
DataTypePtr getReturnType() const { return function->getReturnType(); }
|
||||
DataTypePtr getReturnType() const { return function->getResultType(); }
|
||||
DataTypePtr getReturnTypeToPredict() const { return function->getReturnTypeToPredict(); }
|
||||
DataTypes getArgumentsDataTypes() const { return argument_types; }
|
||||
|
||||
|
@ -131,9 +131,9 @@ static std::pair<DataTypePtr, DataTypeCustomDescPtr> create(const ASTPtr & argum
|
||||
|
||||
DataTypePtr storage_type = DataTypeFactory::instance().get(argument_types[0]->getName());
|
||||
|
||||
if (!function->getReturnType()->equals(*removeLowCardinality(storage_type)))
|
||||
if (!function->getResultType()->equals(*removeLowCardinality(storage_type)))
|
||||
{
|
||||
throw Exception("Incompatible data types between aggregate function '" + function->getName() + "' which returns " + function->getReturnType()->getName() + " and column storage type " + storage_type->getName(),
|
||||
throw Exception("Incompatible data types between aggregate function '" + function->getName() + "' which returns " + function->getResultType()->getName() + " and column storage type " + storage_type->getName(),
|
||||
ErrorCodes::BAD_ARGUMENTS);
|
||||
}
|
||||
|
||||
|
@ -108,14 +108,14 @@ void SerializationAggregateFunction::deserializeBinaryBulk(IColumn & column, Rea
|
||||
}
|
||||
}
|
||||
|
||||
static String serializeToString(const AggregateFunctionPtr & function, const IColumn & column, size_t row_num, size_t version)
|
||||
static String serializeToString(const ConstAggregateFunctionPtr & function, const IColumn & column, size_t row_num, size_t version)
|
||||
{
|
||||
WriteBufferFromOwnString buffer;
|
||||
function->serialize(assert_cast<const ColumnAggregateFunction &>(column).getData()[row_num], buffer, version);
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
static void deserializeFromString(const AggregateFunctionPtr & function, IColumn & column, const String & s, size_t version)
|
||||
static void deserializeFromString(const ConstAggregateFunctionPtr & function, IColumn & column, const String & s, size_t version)
|
||||
{
|
||||
ColumnAggregateFunction & column_concrete = assert_cast<ColumnAggregateFunction &>(column);
|
||||
|
||||
|
@ -11,14 +11,14 @@ namespace DB
|
||||
class SerializationAggregateFunction final : public ISerialization
|
||||
{
|
||||
private:
|
||||
AggregateFunctionPtr function;
|
||||
ConstAggregateFunctionPtr function;
|
||||
String type_name;
|
||||
size_t version;
|
||||
|
||||
public:
|
||||
static constexpr bool is_parametric = true;
|
||||
|
||||
SerializationAggregateFunction(const AggregateFunctionPtr & function_, String type_name_, size_t version_)
|
||||
SerializationAggregateFunction(const ConstAggregateFunctionPtr & function_, String type_name_, size_t version_)
|
||||
: function(function_), type_name(std::move(type_name_)), version(version_) {}
|
||||
|
||||
/// NOTE These two functions for serializing single values are incompatible with the functions below.
|
||||
|
@ -1736,7 +1736,7 @@ namespace
|
||||
}
|
||||
|
||||
const std::shared_ptr<const DataTypeAggregateFunction> aggregate_function_data_type;
|
||||
const AggregateFunctionPtr aggregate_function;
|
||||
ConstAggregateFunctionPtr aggregate_function;
|
||||
String text_buffer;
|
||||
};
|
||||
|
||||
|
@ -895,7 +895,7 @@ class FunctionBinaryArithmetic : public IFunction
|
||||
const ColumnAggregateFunction & column = typeid_cast<const ColumnAggregateFunction &>(
|
||||
agg_state_is_const ? assert_cast<const ColumnConst &>(agg_state_column).getDataColumn() : agg_state_column);
|
||||
|
||||
AggregateFunctionPtr function = column.getAggregateFunction();
|
||||
ConstAggregateFunctionPtr function = column.getAggregateFunction();
|
||||
|
||||
size_t size = agg_state_is_const ? 1 : input_rows_count;
|
||||
|
||||
@ -960,7 +960,7 @@ class FunctionBinaryArithmetic : public IFunction
|
||||
const ColumnAggregateFunction & rhs = typeid_cast<const ColumnAggregateFunction &>(
|
||||
rhs_is_const ? assert_cast<const ColumnConst &>(rhs_column).getDataColumn() : rhs_column);
|
||||
|
||||
AggregateFunctionPtr function = lhs.getAggregateFunction();
|
||||
ConstAggregateFunctionPtr function = lhs.getAggregateFunction();
|
||||
|
||||
size_t size = (lhs_is_const && rhs_is_const) ? 1 : input_rows_count;
|
||||
|
||||
|
@ -3,6 +3,8 @@
|
||||
#include <Core/ColumnNumbers.h>
|
||||
#include <Core/ColumnsWithTypeAndName.h>
|
||||
#include <Core/Names.h>
|
||||
#include <Core/IResolvedFunction.h>
|
||||
#include <Common/Exception.h>
|
||||
#include <DataTypes/IDataType.h>
|
||||
|
||||
#include "config.h"
|
||||
@ -122,11 +124,11 @@ using Values = std::vector<llvm::Value *>;
|
||||
/** Function with known arguments and return type (when the specific overload was chosen).
|
||||
* It is also the point where all function-specific properties are known.
|
||||
*/
|
||||
class IFunctionBase
|
||||
class IFunctionBase : public IResolvedFunction
|
||||
{
|
||||
public:
|
||||
|
||||
virtual ~IFunctionBase() = default;
|
||||
~IFunctionBase() override = default;
|
||||
|
||||
virtual ColumnPtr execute( /// NOLINT
|
||||
const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count, bool dry_run = false) const
|
||||
@ -137,8 +139,10 @@ public:
|
||||
/// Get the main function name.
|
||||
virtual String getName() const = 0;
|
||||
|
||||
virtual const DataTypes & getArgumentTypes() const = 0;
|
||||
virtual const DataTypePtr & getResultType() const = 0;
|
||||
const Array & getParameters() const final
|
||||
{
|
||||
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "IFunctionBase doesn't support getParameters method");
|
||||
}
|
||||
|
||||
/// Do preparations and return executable.
|
||||
/// sample_columns should contain data types of arguments and values of constants, if relevant.
|
||||
|
@ -51,6 +51,8 @@ public:
|
||||
const DataTypes & getArgumentTypes() const override { return arguments; }
|
||||
const DataTypePtr & getResultType() const override { return result_type; }
|
||||
|
||||
const FunctionPtr & getFunction() const { return function; }
|
||||
|
||||
#if USE_EMBEDDED_COMPILER
|
||||
|
||||
bool isCompilable() const override { return function->isCompilable(getArgumentTypes()); }
|
||||
|
@ -104,7 +104,7 @@ DataTypePtr FunctionArrayReduce::getReturnTypeImpl(const ColumnsWithTypeAndName
|
||||
aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name, argument_types, params_row, properties);
|
||||
}
|
||||
|
||||
return aggregate_function->getReturnType();
|
||||
return aggregate_function->getResultType();
|
||||
}
|
||||
|
||||
|
||||
|
@ -122,7 +122,7 @@ DataTypePtr FunctionArrayReduceInRanges::getReturnTypeImpl(const ColumnsWithType
|
||||
aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name, argument_types, params_row, properties);
|
||||
}
|
||||
|
||||
return std::make_shared<DataTypeArray>(aggregate_function->getReturnType());
|
||||
return std::make_shared<DataTypeArray>(aggregate_function->getResultType());
|
||||
}
|
||||
|
||||
|
||||
|
@ -87,7 +87,7 @@ DataTypePtr FunctionInitializeAggregation::getReturnTypeImpl(const ColumnsWithTy
|
||||
aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name, argument_types, params_row, properties);
|
||||
}
|
||||
|
||||
return aggregate_function->getReturnType();
|
||||
return aggregate_function->getResultType();
|
||||
}
|
||||
|
||||
|
||||
|
@ -91,7 +91,7 @@ public:
|
||||
if (arguments.size() == 2)
|
||||
column_with_groups = arguments[1].column;
|
||||
|
||||
AggregateFunctionPtr aggregate_function_ptr = column_with_states->getAggregateFunction();
|
||||
ConstAggregateFunctionPtr aggregate_function_ptr = column_with_states->getAggregateFunction();
|
||||
const IAggregateFunction & agg_func = *aggregate_function_ptr;
|
||||
|
||||
AlignedBuffer place(agg_func.sizeOfData(), agg_func.alignOfData());
|
||||
@ -99,7 +99,7 @@ public:
|
||||
/// Will pass empty arena if agg_func does not allocate memory in arena
|
||||
std::unique_ptr<Arena> arena = agg_func.allocatesMemoryInArena() ? std::make_unique<Arena>() : nullptr;
|
||||
|
||||
auto result_column_ptr = agg_func.getReturnType()->createColumn();
|
||||
auto result_column_ptr = agg_func.getResultType()->createColumn();
|
||||
IColumn & result_column = *result_column_ptr;
|
||||
result_column.reserve(column_with_states->size());
|
||||
|
||||
|
@ -47,8 +47,6 @@ void ActionsDAG::Node::toTree(JSONBuilder::JSONMap & map) const
|
||||
|
||||
if (function_base)
|
||||
map.add("Function", function_base->getName());
|
||||
else if (function_builder)
|
||||
map.add("Function", function_builder->getName());
|
||||
|
||||
if (type == ActionType::FUNCTION)
|
||||
map.add("Compiled", is_function_compiled);
|
||||
@ -166,7 +164,6 @@ const ActionsDAG::Node & ActionsDAG::addFunction(
|
||||
|
||||
Node node;
|
||||
node.type = ActionType::FUNCTION;
|
||||
node.function_builder = function;
|
||||
node.children = std::move(children);
|
||||
|
||||
bool all_const = true;
|
||||
@ -238,6 +235,86 @@ const ActionsDAG::Node & ActionsDAG::addFunction(
|
||||
return addNode(std::move(node));
|
||||
}
|
||||
|
||||
const ActionsDAG::Node & ActionsDAG::addFunction(
|
||||
const FunctionBasePtr & function_base,
|
||||
NodeRawConstPtrs children,
|
||||
std::string result_name)
|
||||
{
|
||||
size_t num_arguments = children.size();
|
||||
|
||||
Node node;
|
||||
node.type = ActionType::FUNCTION;
|
||||
node.children = std::move(children);
|
||||
|
||||
bool all_const = true;
|
||||
ColumnsWithTypeAndName arguments(num_arguments);
|
||||
|
||||
for (size_t i = 0; i < num_arguments; ++i)
|
||||
{
|
||||
const auto & child = *node.children[i];
|
||||
|
||||
ColumnWithTypeAndName argument;
|
||||
argument.column = child.column;
|
||||
argument.type = child.result_type;
|
||||
argument.name = child.result_name;
|
||||
|
||||
if (!argument.column || !isColumnConst(*argument.column))
|
||||
all_const = false;
|
||||
|
||||
arguments[i] = std::move(argument);
|
||||
}
|
||||
|
||||
node.function_base = function_base;
|
||||
node.result_type = node.function_base->getResultType();
|
||||
node.function = node.function_base->prepare(arguments);
|
||||
node.is_deterministic = node.function_base->isDeterministic();
|
||||
|
||||
/// If all arguments are constants, and function is suitable to be executed in 'prepare' stage - execute function.
|
||||
if (node.function_base->isSuitableForConstantFolding())
|
||||
{
|
||||
ColumnPtr column;
|
||||
|
||||
if (all_const)
|
||||
{
|
||||
size_t num_rows = arguments.empty() ? 0 : arguments.front().column->size();
|
||||
column = node.function->execute(arguments, node.result_type, num_rows, true);
|
||||
}
|
||||
else
|
||||
{
|
||||
column = node.function_base->getConstantResultForNonConstArguments(arguments, node.result_type);
|
||||
}
|
||||
|
||||
/// If the result is not a constant, just in case, we will consider the result as unknown.
|
||||
if (column && isColumnConst(*column))
|
||||
{
|
||||
/// All constant (literal) columns in block are added with size 1.
|
||||
/// But if there was no columns in block before executing a function, the result has size 0.
|
||||
/// Change the size to 1.
|
||||
|
||||
if (column->empty())
|
||||
column = column->cloneResized(1);
|
||||
|
||||
node.column = std::move(column);
|
||||
}
|
||||
}
|
||||
|
||||
if (result_name.empty())
|
||||
{
|
||||
result_name = function_base->getName() + "(";
|
||||
for (size_t i = 0; i < num_arguments; ++i)
|
||||
{
|
||||
if (i)
|
||||
result_name += ", ";
|
||||
result_name += node.children[i]->result_name;
|
||||
}
|
||||
result_name += ")";
|
||||
}
|
||||
|
||||
node.result_name = std::move(result_name);
|
||||
|
||||
return addNode(std::move(node));
|
||||
}
|
||||
|
||||
const ActionsDAG::Node & ActionsDAG::findInOutputs(const std::string & name) const
|
||||
{
|
||||
if (const auto * node = tryFindInOutputs(name))
|
||||
@ -1927,8 +2004,7 @@ ActionsDAGPtr ActionsDAG::cloneActionsForFilterPushDown(
|
||||
|
||||
FunctionOverloadResolverPtr func_builder_cast = CastInternalOverloadResolver<CastType::nonAccurate>::createImpl();
|
||||
|
||||
predicate->function_builder = func_builder_cast;
|
||||
predicate->function_base = predicate->function_builder->build(arguments);
|
||||
predicate->function_base = func_builder_cast->build(arguments);
|
||||
predicate->function = predicate->function_base->prepare(arguments);
|
||||
}
|
||||
}
|
||||
@ -1939,7 +2015,9 @@ ActionsDAGPtr ActionsDAG::cloneActionsForFilterPushDown(
|
||||
predicate->children.swap(new_children);
|
||||
auto arguments = prepareFunctionArguments(predicate->children);
|
||||
|
||||
predicate->function_base = predicate->function_builder->build(arguments);
|
||||
FunctionOverloadResolverPtr func_builder_and = std::make_unique<FunctionToOverloadResolverAdaptor>(std::make_shared<FunctionAnd>());
|
||||
|
||||
predicate->function_base = func_builder_and->build(arguments);
|
||||
predicate->function = predicate->function_base->prepare(arguments);
|
||||
}
|
||||
}
|
||||
|
@ -74,7 +74,6 @@ public:
|
||||
std::string result_name;
|
||||
DataTypePtr result_type;
|
||||
|
||||
FunctionOverloadResolverPtr function_builder;
|
||||
/// Can be used to get function signature or properties like monotonicity.
|
||||
FunctionBasePtr function_base;
|
||||
/// Prepared function which is used in function execution.
|
||||
@ -139,6 +138,10 @@ public:
|
||||
const FunctionOverloadResolverPtr & function,
|
||||
NodeRawConstPtrs children,
|
||||
std::string result_name);
|
||||
const Node & addFunction(
|
||||
const FunctionBasePtr & function_base,
|
||||
NodeRawConstPtrs children,
|
||||
std::string result_name);
|
||||
|
||||
/// Find first column by name in output nodes. This search is linear.
|
||||
const Node & findInOutputs(const std::string & name) const;
|
||||
|
@ -53,7 +53,7 @@ void AggregateDescription::explain(WriteBuffer & out, size_t indent) const
|
||||
out << type->getName();
|
||||
}
|
||||
|
||||
out << ") → " << function->getReturnType()->getName() << "\n";
|
||||
out << ") → " << function->getResultType()->getName() << "\n";
|
||||
}
|
||||
else
|
||||
out << prefix << " Function: nullptr\n";
|
||||
@ -109,7 +109,7 @@ void AggregateDescription::explain(JSONBuilder::JSONMap & map) const
|
||||
args_array->add(type->getName());
|
||||
|
||||
function_map->add("Argument Types", std::move(args_array));
|
||||
function_map->add("Result Type", function->getReturnType()->getName());
|
||||
function_map->add("Result Type", function->getResultType()->getName());
|
||||
|
||||
map.add("Function", std::move(function_map));
|
||||
}
|
||||
|
@ -45,7 +45,7 @@ OutputBlockColumns prepareOutputBlockColumns(
|
||||
}
|
||||
else
|
||||
{
|
||||
final_aggregate_columns[i] = aggregate_functions[i]->getReturnType()->createColumn();
|
||||
final_aggregate_columns[i] = aggregate_functions[i]->getResultType()->createColumn();
|
||||
final_aggregate_columns[i]->reserve(rows);
|
||||
|
||||
if (aggregate_functions[i]->isState())
|
||||
|
@ -433,7 +433,7 @@ Block Aggregator::Params::getHeader(
|
||||
{
|
||||
auto & elem = res.getByName(aggregate.column_name);
|
||||
|
||||
elem.type = aggregate.function->getReturnType();
|
||||
elem.type = aggregate.function->getResultType();
|
||||
elem.column = elem.type->createColumn();
|
||||
}
|
||||
}
|
||||
@ -452,7 +452,7 @@ Block Aggregator::Params::getHeader(
|
||||
|
||||
DataTypePtr type;
|
||||
if (final)
|
||||
type = aggregate.function->getReturnType();
|
||||
type = aggregate.function->getResultType();
|
||||
else
|
||||
type = std::make_shared<DataTypeAggregateFunction>(aggregate.function, argument_types, aggregate.parameters);
|
||||
|
||||
|
@ -423,7 +423,7 @@ void ExpressionAnalyzer::analyzeAggregation(ActionsDAGPtr & temp_actions)
|
||||
aggregated_columns = temp_actions->getNamesAndTypesList();
|
||||
|
||||
for (const auto & desc : aggregate_descriptions)
|
||||
aggregated_columns.emplace_back(desc.column_name, desc.function->getReturnType());
|
||||
aggregated_columns.emplace_back(desc.column_name, desc.function->getResultType());
|
||||
}
|
||||
|
||||
|
||||
@ -2021,7 +2021,7 @@ ExpressionAnalysisResult::ExpressionAnalysisResult(
|
||||
for (const auto & f : w.window_functions)
|
||||
{
|
||||
query_analyzer.columns_after_window.push_back(
|
||||
{f.column_name, f.aggregate_function->getReturnType()});
|
||||
{f.column_name, f.aggregate_function->getResultType()});
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -403,7 +403,7 @@ static void compileInsertAggregatesIntoResultColumns(llvm::Module & module, cons
|
||||
std::vector<ColumnDataPlaceholder> columns(functions.size());
|
||||
for (size_t i = 0; i < functions.size(); ++i)
|
||||
{
|
||||
auto return_type = functions[i].function->getReturnType();
|
||||
auto return_type = functions[i].function->getResultType();
|
||||
auto * data = b.CreateLoad(column_type, b.CreateConstInBoundsGEP1_64(column_type, columns_arg, i));
|
||||
|
||||
auto * column_data_type = toNativeType(b, removeNullable(return_type));
|
||||
|
@ -365,8 +365,8 @@ void Planner::buildQueryPlanIfNeeded()
|
||||
{
|
||||
auto function_node = std::make_shared<FunctionNode>("and");
|
||||
auto and_function = FunctionFactory::instance().get("and", query_context);
|
||||
function_node->resolveAsFunction(std::move(and_function), std::make_shared<DataTypeUInt8>());
|
||||
function_node->getArguments().getNodes() = {query_node.getPrewhere(), query_node.getWhere()};
|
||||
function_node->resolveAsFunction(and_function->build(function_node->getArgumentTypes()));
|
||||
query_node.getWhere() = std::move(function_node);
|
||||
query_node.getPrewhere() = {};
|
||||
}
|
||||
|
@ -121,7 +121,7 @@ public:
|
||||
return node;
|
||||
}
|
||||
|
||||
const ActionsDAG::Node * addFunctionIfNecessary(const std::string & node_name, ActionsDAG::NodeRawConstPtrs children, FunctionOverloadResolverPtr function)
|
||||
const ActionsDAG::Node * addFunctionIfNecessary(const std::string & node_name, ActionsDAG::NodeRawConstPtrs children, FunctionBasePtr function)
|
||||
{
|
||||
auto it = node_name_to_node.find(node_name);
|
||||
if (it != node_name_to_node.end())
|
||||
@ -325,7 +325,7 @@ PlannerActionsVisitorImpl::NodeNameAndNodeMinLevel PlannerActionsVisitorImpl::vi
|
||||
lambda_actions, captured_column_names, lambda_arguments_names_and_types, result_type, lambda_expression_node_name);
|
||||
actions_stack.pop_back();
|
||||
|
||||
actions_stack[level].addFunctionIfNecessary(lambda_node_name, std::move(lambda_children), std::move(function_capture));
|
||||
actions_stack[level].addFunctionIfNecessary(lambda_node_name, std::move(lambda_children), function_capture->build({}));
|
||||
|
||||
size_t actions_stack_size = actions_stack.size();
|
||||
for (size_t i = level + 1; i < actions_stack_size; ++i)
|
||||
|
@ -101,14 +101,14 @@ public:
|
||||
{
|
||||
auto grouping_ordinary_function = std::make_shared<FunctionGroupingOrdinary>(arguments_indexes, force_grouping_standard_compatibility);
|
||||
auto grouping_ordinary_function_adaptor = std::make_shared<FunctionToOverloadResolverAdaptor>(std::move(grouping_ordinary_function));
|
||||
function_node->resolveAsFunction(std::move(grouping_ordinary_function_adaptor), std::make_shared<DataTypeUInt64>());
|
||||
function_node->resolveAsFunction(grouping_ordinary_function_adaptor->build({}));
|
||||
break;
|
||||
}
|
||||
case GroupByKind::ROLLUP:
|
||||
{
|
||||
auto grouping_rollup_function = std::make_shared<FunctionGroupingForRollup>(arguments_indexes, aggregation_keys_size, force_grouping_standard_compatibility);
|
||||
auto grouping_rollup_function_adaptor = std::make_shared<FunctionToOverloadResolverAdaptor>(std::move(grouping_rollup_function));
|
||||
function_node->resolveAsFunction(std::move(grouping_rollup_function_adaptor), std::make_shared<DataTypeUInt64>());
|
||||
function_node->resolveAsFunction(grouping_rollup_function_adaptor->build({}));
|
||||
function_node->getArguments().getNodes().push_back(std::move(grouping_set_argument_column));
|
||||
break;
|
||||
}
|
||||
@ -116,7 +116,7 @@ public:
|
||||
{
|
||||
auto grouping_cube_function = std::make_shared<FunctionGroupingForCube>(arguments_indexes, aggregation_keys_size, force_grouping_standard_compatibility);
|
||||
auto grouping_cube_function_adaptor = std::make_shared<FunctionToOverloadResolverAdaptor>(std::move(grouping_cube_function));
|
||||
function_node->resolveAsFunction(std::move(grouping_cube_function_adaptor), std::make_shared<DataTypeUInt64>());
|
||||
function_node->resolveAsFunction(grouping_cube_function_adaptor->build({}));
|
||||
function_node->getArguments().getNodes().push_back(std::move(grouping_set_argument_column));
|
||||
break;
|
||||
}
|
||||
@ -124,7 +124,7 @@ public:
|
||||
{
|
||||
auto grouping_grouping_sets_function = std::make_shared<FunctionGroupingForGroupingSets>(arguments_indexes, grouping_sets_keys_indices, force_grouping_standard_compatibility);
|
||||
auto grouping_grouping_sets_function_adaptor = std::make_shared<FunctionToOverloadResolverAdaptor>(std::move(grouping_grouping_sets_function));
|
||||
function_node->resolveAsFunction(std::move(grouping_grouping_sets_function_adaptor), std::make_shared<DataTypeUInt64>());
|
||||
function_node->resolveAsFunction(grouping_grouping_sets_function_adaptor->build({}));
|
||||
function_node->getArguments().getNodes().push_back(std::move(grouping_set_argument_column));
|
||||
break;
|
||||
}
|
||||
|
@ -65,7 +65,7 @@ std::optional<AggregationAnalysisResult> analyzeAggregation(QueryTreeNodePtr & q
|
||||
ColumnsWithTypeAndName aggregates_columns;
|
||||
aggregates_columns.reserve(aggregates_descriptions.size());
|
||||
for (auto & aggregate_description : aggregates_descriptions)
|
||||
aggregates_columns.emplace_back(nullptr, aggregate_description.function->getReturnType(), aggregate_description.column_name);
|
||||
aggregates_columns.emplace_back(nullptr, aggregate_description.function->getResultType(), aggregate_description.column_name);
|
||||
|
||||
Names aggregation_keys;
|
||||
|
||||
@ -284,7 +284,7 @@ std::optional<WindowAnalysisResult> analyzeWindow(QueryTreeNodePtr & query_tree,
|
||||
|
||||
for (auto & window_description : window_descriptions)
|
||||
for (auto & window_function : window_description.window_functions)
|
||||
window_functions_additional_columns.emplace_back(nullptr, window_function.aggregate_function->getReturnType(), window_function.column_name);
|
||||
window_functions_additional_columns.emplace_back(nullptr, window_function.aggregate_function->getResultType(), window_function.column_name);
|
||||
|
||||
auto before_window_step = std::make_unique<ActionsChainStep>(before_window_actions,
|
||||
ActionsChainStep::AvailableOutputColumnsStrategy::ALL_NODES,
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user