From d8370116c1cabcb9d46ad2f3a7cc290a0a180b7c Mon Sep 17 00:00:00 2001 From: myrrc Date: Sun, 25 Oct 2020 23:33:01 +0300 Subject: [PATCH] simplified the functions (agreement to cast to Float64) --- .../aggregate-functions/reference/avg.md | 3 +- .../reference/avgweighted.md | 6 +- .../AggregateFunctionAvg.cpp | 22 +-- src/AggregateFunctions/AggregateFunctionAvg.h | 111 ++++----------- .../AggregateFunctionAvgWeighted.cpp | 42 +----- .../AggregateFunctionAvgWeighted.h | 128 +----------------- .../AggregateFunctionFactory.h | 4 +- 7 files changed, 49 insertions(+), 267 deletions(-) diff --git a/docs/en/sql-reference/aggregate-functions/reference/avg.md b/docs/en/sql-reference/aggregate-functions/reference/avg.md index 4ebae95b79d..1741bbb744b 100644 --- a/docs/en/sql-reference/aggregate-functions/reference/avg.md +++ b/docs/en/sql-reference/aggregate-functions/reference/avg.md @@ -4,4 +4,5 @@ toc_priority: 5 # avg {#agg_function-avg} -Calculates the average. Only works for numbers. The result is always Float64. +Calculates the average. Only works for numbers (Integral, floating-point, or Decimals). +The result is always Float64. diff --git a/docs/en/sql-reference/aggregate-functions/reference/avgweighted.md b/docs/en/sql-reference/aggregate-functions/reference/avgweighted.md index a6fb5999fb8..22993f93e16 100644 --- a/docs/en/sql-reference/aggregate-functions/reference/avgweighted.md +++ b/docs/en/sql-reference/aggregate-functions/reference/avgweighted.md @@ -28,11 +28,7 @@ but may have different types. - `NaN`. If all the weights are equal to 0. - Weighted mean otherwise. -**Return type** - -- `Decimal` if both types are [Decimal](../../../sql-reference/data-types/decimal.md) - or if one type is Decimal and other is Integer. -- [Float64](../../../sql-reference/data-types/float.md) otherwise. +**Return type** is always [Float64](../../../sql-reference/data-types/float.md). **Example** diff --git a/src/AggregateFunctions/AggregateFunctionAvg.cpp b/src/AggregateFunctions/AggregateFunctionAvg.cpp index cf35e99dafb..4d1b01b25fc 100644 --- a/src/AggregateFunctions/AggregateFunctionAvg.cpp +++ b/src/AggregateFunctions/AggregateFunctionAvg.cpp @@ -1,3 +1,4 @@ +#include #include #include #include @@ -13,23 +14,22 @@ namespace ErrorCodes namespace { +constexpr bool allowType(const DataTypePtr& type) noexcept +{ + const WhichDataType t(type); + return t.isInt() || t.isUInt() || t.isFloat() || t.isDecimal(); +} + AggregateFunctionPtr createAggregateFunctionAvg(const std::string & name, const DataTypes & argument_types, const Array & parameters) { assertNoParameters(name, parameters); assertUnary(name, argument_types); - AggregateFunctionPtr res; - DataTypePtr data_type = argument_types[0]; - - if (isDecimal(data_type)) - res.reset(createWithDecimalType(*data_type, *data_type, argument_types)); - else - res.reset(createWithNumericType(*data_type, argument_types)); - - if (!res) + if (!allowType(argument_types[0])) throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, - ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); - return res; + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + return std::make_shared(argument_types); } } diff --git a/src/AggregateFunctions/AggregateFunctionAvg.h b/src/AggregateFunctions/AggregateFunctionAvg.h index 0b77aa0c537..66ab20cec73 100644 --- a/src/AggregateFunctions/AggregateFunctionAvg.h +++ b/src/AggregateFunctions/AggregateFunctionAvg.h @@ -10,10 +10,7 @@ namespace DB { -template -using DecimalOrVectorCol = std::conditional_t, ColumnDecimal, ColumnVector>; - -/// A type-fixed rational fraction represented by a pair of #Numerator and #Denominator. +/// A type-fixed fraction represented by a pair of #Numerator and #Denominator. template struct RationalFraction { @@ -26,77 +23,42 @@ struct RationalFraction template Result NO_SANITIZE_UNDEFINED result() const { - if constexpr (std::is_floating_point_v) - if constexpr (std::numeric_limits::is_iec559) - { - if constexpr (is_big_int_v) - return static_cast(numerator) / static_cast(denominator); - else - return static_cast(numerator) / denominator; /// allow division by zero - } + if constexpr (std::is_floating_point_v && std::numeric_limits::is_iec559) + return static_cast(numerator) / denominator; /// allow division by zero if (denominator == static_cast(0)) return static_cast(0); - if constexpr (std::is_same_v) - return static_cast(numerator / static_cast(denominator)); - else - return static_cast(numerator / denominator); + return static_cast(numerator / denominator); } }; /** - * Motivation: ClickHouse has added the Decimal data type, which basically represents a fraction that stores - * the precise (unlike floating-point) result with respect to some scale. + * The discussion showed that the easiest (and simplest) way is to cast both the columns of numerator and denominator + * to Float64. Another way would be to write some template magic that figures out the appropriate numerator + * and denominator (and the resulting type) in favour of extended integral types (UInt128 e.g.) and Decimals ( + * which are a mess themselves). The second way is also a bit useless because now Decimals are not used in functions + * like avg. * - * These decimal types can't be divided by floating point data types, so functions like avg or avgWeighted - * can't return the Floa64 column as a result when of the input columns is Decimal (because that would, in case of - * avgWeighted, involve division numerator (Decimal) / denominator (Float64)). - * - * The rules for determining the output and intermediate storage types for these functions are different, so - * the struct representing the deduction guide is presented. - * - * Given the initial Columns types (e.g. values and weights for avgWeighted, values for avg), - * the struct calculated the output type and the intermediate storage type (that's used by the RationalFraction). - */ -template -struct AvgFunctionTypesDeductionTemplate -{ - using Numerator = int; - using Denominator = int; - using Fraction = RationalFraction; - - using ResultType = bool; - using ResultDataType = bool; - using ResultVectorType = bool; -}; - -/** - * @tparam InitialNumerator The type that the initial numerator column would have (needed to cast the input IColumn to - * appropriate type). - * @tparam InitialDenominator The type that the initial denominator column would have. - * - * @tparam Deduction Function template that, given the numerator and the denominator, finds the actual - * suitable storage and the resulting column type. + * The ability to explicitly specify the denominator is made for avg (it uses the integral value as the denominator is + * simply the length of the supplied list). * * @tparam Derived When deriving from this class, use the child class name as in CRTP, e.g. * class Self : Agg. */ -template class Deduction, class Derived> +template class AggregateFunctionAvgBase : public - IAggregateFunctionDataHelper::Fraction, Derived> + IAggregateFunctionDataHelper, Derived> { public: - using Deducted = Deduction; + using Numerator = Float64; + using Fraction = RationalFraction; - using ResultType = typename Deducted::ResultType; - using ResultDataType = typename Deducted::ResultDataType; - using ResultVectorType = typename Deducted::ResultVectorType; + using ResultType = Float64; + using ResultDataType = DataTypeNumber; + using ResultVectorType = ColumnVector; - using Numerator = typename Deducted::Numerator; - using Denominator = typename Deducted::Denominator; - - using Base = IAggregateFunctionDataHelper; + using Base = IAggregateFunctionDataHelper; /// ctor for native types explicit AggregateFunctionAvgBase(const DataTypes & argument_types_): Base(argument_types_, {}), scale(0) {} @@ -107,10 +69,7 @@ public: DataTypePtr getReturnType() const override { - if constexpr (IsDecimalNumber) - return std::make_shared(ResultDataType::maxPrecision(), scale); - else - return std::make_shared(); + return std::make_shared(); } void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override @@ -148,38 +107,16 @@ protected: UInt32 scale; }; -template -struct AvgFunctionTypesDeduction -{ - using Numerator = std::conditional_t, - std::conditional_t, - Decimal256, - Decimal128>, - NearestFieldType>; - - using Denominator = V; - using Fraction = RationalFraction; - - using ResultType = std::conditional_t, T, Float64>; - using ResultDataType = std::conditional_t, DataTypeDecimal, DataTypeNumber>; - using ResultVectorType = std::conditional_t, ColumnDecimal, ColumnVector>; -}; - -template -class AggregateFunctionAvg final : - public AggregateFunctionAvgBase> +class AggregateFunctionAvg final : public AggregateFunctionAvgBase { public: - using Base = - AggregateFunctionAvgBase>; - - using Base::Base; + using AggregateFunctionAvgBase::AggregateFunctionAvgBase; void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const final { - const auto & column = static_cast &>(*columns[0]); + const auto & column = static_cast &>(*columns[0]); this->data(place).numerator += column.getData()[row_num]; - this->data(place).denominator += 1; + ++this->data(place).denominator; } String getName() const final { return "avg"; } diff --git a/src/AggregateFunctions/AggregateFunctionAvgWeighted.cpp b/src/AggregateFunctions/AggregateFunctionAvgWeighted.cpp index 5b43aa19a5c..6b677414d87 100644 --- a/src/AggregateFunctions/AggregateFunctionAvgWeighted.cpp +++ b/src/AggregateFunctions/AggregateFunctionAvgWeighted.cpp @@ -1,3 +1,4 @@ +#include #include #include #include @@ -14,7 +15,7 @@ namespace ErrorCodes namespace { -constexpr bool allowTypes(const DataTypePtr& left, const DataTypePtr& right) +constexpr bool allowTypes(const DataTypePtr& left, const DataTypePtr& right) noexcept { const WhichDataType l_dt(left), r_dt(right); @@ -26,39 +27,6 @@ constexpr bool allowTypes(const DataTypePtr& left, const DataTypePtr& right) return allow(l_dt) && allow(r_dt); } -#define AT_SWITCH(LINE) \ - switch (which.idx) \ - { \ - LINE(Int8); LINE(Int16); LINE(Int32); LINE(Int64); LINE(Int128); LINE(Int256); \ - LINE(UInt8); LINE(UInt16); LINE(UInt32); LINE(UInt64); LINE(UInt128); LINE(UInt256); \ - LINE(Decimal32); LINE(Decimal64); LINE(Decimal128); LINE(Decimal256); \ - LINE(Float32); LINE(Float64); \ - default: return nullptr; \ - } - -template -static IAggregateFunction * create(const IDataType & second_type, TArgs && ... args) -{ - const WhichDataType which(second_type); - -#define LINE(Type) \ - case TypeIndex::Type: return new AggregateFunctionAvgWeighted(std::forward(args)...) - AT_SWITCH(LINE) -#undef LINE -} - -// Not using helper functions because there are no templates for binary decimal/numeric function. -template -static IAggregateFunction * create(const IDataType & first_type, const IDataType & second_type, TArgs && ... args) -{ - const WhichDataType which(first_type); - -#define LINE(Type) \ - case TypeIndex::Type: return create(second_type, std::forward(args)...) - AT_SWITCH(LINE) -#undef LINE -} - AggregateFunctionPtr createAggregateFunctionAvgWeighted(const std::string & name, const DataTypes & argument_types, const Array & parameters) { assertNoParameters(name, parameters); @@ -74,11 +42,7 @@ AggregateFunctionPtr createAggregateFunctionAvgWeighted(const std::string & name " are non-conforming as arguments for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); - AggregateFunctionPtr res; - res.reset(create(*data_type, *data_type_weight, argument_types)); - - assert(res); // type checking should be done in allowTypes. - return res; + return std::make_shared(argument_types); } } diff --git a/src/AggregateFunctions/AggregateFunctionAvgWeighted.h b/src/AggregateFunctions/AggregateFunctionAvgWeighted.h index 6b1c41d2748..c9b60cf9f50 100644 --- a/src/AggregateFunctions/AggregateFunctionAvgWeighted.h +++ b/src/AggregateFunctions/AggregateFunctionAvgWeighted.h @@ -5,134 +5,18 @@ namespace DB { -template -struct AvgWeightedFunctionTypesDeduction -{ - template struct NextAvgType { }; - template <> struct NextAvgType { using Type = Int16; }; - template <> struct NextAvgType { using Type = Int32; }; - template <> struct NextAvgType { using Type = Int64; }; - template <> struct NextAvgType { using Type = Int128; }; - - template <> struct NextAvgType { using Type = UInt16; }; - template <> struct NextAvgType { using Type = UInt32; }; - template <> struct NextAvgType { using Type = UInt64; }; - template <> struct NextAvgType { using Type = UInt128; }; - - // Promoted to Float as these types don't go well when operating with above ones - template <> struct NextAvgType { using Type = Float64; }; - template <> struct NextAvgType { using Type = Float64; }; - template <> struct NextAvgType { using Type = Float64; }; - template <> struct NextAvgType { using Type = Float64; }; - - template <> struct NextAvgType { using Type = Decimal128; }; - template <> struct NextAvgType { using Type = Decimal128; }; - template <> struct NextAvgType { using Type = Decimal128; }; - template <> struct NextAvgType { using Type = Decimal256; }; - - template <> struct NextAvgType { using Type = Float64; }; - template <> struct NextAvgType { using Type = Float64; }; - - template using NextAvgTypeT = typename NextAvgType::Type; - template using Largest = std::conditional_t<(sizeof(T) > sizeof(U)), T, U>; - - struct GetNumDenom - { - using U = Values; - using V = Weights; - static constexpr bool UDecimal = IsDecimalNumber; - static constexpr bool VDecimal = IsDecimalNumber; - static constexpr bool BothDecimal = UDecimal && VDecimal; - static constexpr bool NoneDecimal = !UDecimal && !VDecimal; - - /// we do not include extended integral types here as they produce errors while diving on Decimals. - template static constexpr bool IsIntegral = std::is_integral_v; - template static constexpr bool IsExtendedIntegral = - std::is_same_v || std::is_same_v - || std::is_same_v || std::is_same_v; - - static constexpr bool BothOrNoneDecimal = BothDecimal || NoneDecimal; - - using Num = std::conditional_t>, - - std::conditional_t, - NextAvgTypeT, - Float64>, - /// When the denominator only is Decimal, it would be converted to Float64 (as integral / Decimal - /// produces a compile error, vice versa allowed), so we just cast the numerator to Flaoat64; - Float64>>; - - /** - * When both types are Decimal, we can perform computations in the Decimals only. - * When none of the types is Decimal, the result is always correct, the numerator is the next largest type up to - * Float64. - * We use #V only as the denominator accumulates the sum of the weights. - * - * When the numerator only is Decimal, we set the denominator to next Largest type. - * - If the denominator was floating-point, the numerator would be Float64. - * - If not, the numerator would be Decimal (as the denominator is integral). - * - * When the denominator only is Decimal, it will be casted to Float64 as integral / Decimal produces a compile - * time error. - * - * Extended integer types can't be multiplied by doubles (I don't know, why), so we also convert them to - * double. - */ - using Denom = std::conditional_t<(VDecimal && !UDecimal) || IsExtendedIntegral, - Float64, - NextAvgTypeT>; - }; - - using Numerator = typename GetNumDenom::Num; - using Denominator = typename GetNumDenom::Denom; - using Fraction = RationalFraction; - - /// If either Numerator or Denominator are Decimal, the result is also Decimal as everything was checked in - /// GetNumDenom. - using T = std::conditional_t && IsDecimalNumber, - Largest, - std::conditional_t, - Numerator, - std::conditional_t, - Denominator, - bool>>>; // both numerator and denominator are non-decimal. - - using ResultType = std::conditional_t, T, Float64>; - using ResultDataType = std::conditional_t, DataTypeDecimal, DataTypeNumber>; - using ResultVectorType = std::conditional_t, ColumnDecimal, ColumnVector>; -}; - -/** - * @tparam Values The values column type. - * @tparam Weights The weights column type. - */ -template -class AggregateFunctionAvgWeighted final : public AggregateFunctionAvgBase< - Values, Weights, AvgWeightedFunctionTypesDeduction, AggregateFunctionAvgWeighted> +class AggregateFunctionAvgWeighted final : public AggregateFunctionAvgBase { public: - using Base = AggregateFunctionAvgBase< - Values, Weights, AvgWeightedFunctionTypesDeduction, AggregateFunctionAvgWeighted>; - using Base::Base; + using AggregateFunctionAvgBase::AggregateFunctionAvgBase; void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override { - const auto & values = static_cast &>(*columns[0]); - const auto & weights = static_cast &>(*columns[1]); + const auto & values = static_cast &>(*columns[0]); + const auto & weights = static_cast &>(*columns[1]); - using Numerator = typename Base::Numerator; - using Denominator = typename Base::Denominator; - - const Numerator value = Numerator(values.getData()[row_num]); - const Denominator weight = Denominator(weights.getData()[row_num]); + const auto value = values.getData()[row_num]; + const auto weight = weights.getData()[row_num]; this->data(place).numerator += value * weight; this->data(place).denominator += weight; diff --git a/src/AggregateFunctions/AggregateFunctionFactory.h b/src/AggregateFunctions/AggregateFunctionFactory.h index 143e6562a30..07db76d8dd1 100644 --- a/src/AggregateFunctions/AggregateFunctionFactory.h +++ b/src/AggregateFunctions/AggregateFunctionFactory.h @@ -21,7 +21,8 @@ class IDataType; using DataTypePtr = std::shared_ptr; using DataTypes = std::vector; -/** Creator have arguments: name of aggregate function, types of arguments, values of parameters. +/** + * The invoker has arguments: name of aggregate function, types of arguments, values of parameters. * Parameters are for "parametric" aggregate functions. * For example, in quantileWeighted(0.9)(x, weight), 0.9 is "parameter" and x, weight are "arguments". */ @@ -89,7 +90,6 @@ private: std::optional tryGetPropertiesImpl(const String & name, int recursion_level) const; -private: using AggregateFunctions = std::unordered_map; AggregateFunctions aggregate_functions;