fixed decimal scales calc, updated the tests

This commit is contained in:
myrrc 2020-11-24 17:07:59 +03:00
parent fbb0e6e6aa
commit 420f2489a7
5 changed files with 25 additions and 12 deletions

View File

@ -33,24 +33,26 @@ struct AvgFraction
/// Allow division by zero as sometimes we need to return NaN. /// Allow division by zero as sometimes we need to return NaN.
/// Invoked only is either Numerator or Denominator are Decimal. /// Invoked only is either Numerator or Denominator are Decimal.
Float64 NO_SANITIZE_UNDEFINED divideIfAnyDecimal(UInt32 scale) const Float64 NO_SANITIZE_UNDEFINED divideIfAnyDecimal(UInt32 num_scale, UInt32 denom_scale) const
{ {
if constexpr (IsDecimalNumber<Numerator> && IsDecimalNumber<Denominator>) if constexpr (IsDecimalNumber<Numerator> && IsDecimalNumber<Denominator>)
{ {
const UInt32 result_scale = std::max(num_scale, denom_scale);
if constexpr (std::is_same_v<Numerator, Decimal256> && std::is_same_v<Denominator, Decimal128>) if constexpr (std::is_same_v<Numerator, Decimal256> && std::is_same_v<Denominator, Decimal128>)
///Special case as Decimal256 / Decimal128 = compile error (as Decimal128 is not parametrized by a wide ///Special case as Decimal256 / Decimal128 = compile error (as Decimal128 is not parametrized by a wide
///int), but an __int128 instead ///int), but an __int128 instead
return DecimalUtils::convertTo<Float64>( return DecimalUtils::convertTo<Float64>(
numerator / (denominator.template convertTo<Decimal256>()), scale); numerator / (denominator.template convertTo<Decimal256>()), result_scale);
else else
return DecimalUtils::convertTo<Float64>(numerator / denominator, scale); return DecimalUtils::convertTo<Float64>(numerator / denominator, result_scale);
} }
/// Numerator is always casted to Float64 to divide correctly if the denominator is not Float64. /// Numerator is always casted to Float64 to divide correctly if the denominator is not Float64.
Float64 num_converted; Float64 num_converted;
if constexpr (IsDecimalNumber<Numerator>) if constexpr (IsDecimalNumber<Numerator>)
num_converted = DecimalUtils::convertTo<Float64>(numerator, scale); num_converted = DecimalUtils::convertTo<Float64>(numerator, num_scale);
else else
num_converted = static_cast<Float64>(numerator); /// all other types, including extended integral. num_converted = static_cast<Float64>(numerator); /// all other types, including extended integral.
@ -58,7 +60,7 @@ struct AvgFraction
Float64, Denominator> denom_converted; Float64, Denominator> denom_converted;
if constexpr (IsDecimalNumber<Denominator>) if constexpr (IsDecimalNumber<Denominator>)
denom_converted = DecimalUtils::convertTo<Float64>(denominator, scale); denom_converted = DecimalUtils::convertTo<Float64>(denominator, denom_scale);
else if constexpr (DecimalOrExtendedInt<Denominator>) else if constexpr (DecimalOrExtendedInt<Denominator>)
/// no way to divide Float64 and extended integral type without an explicit cast. /// no way to divide Float64 and extended integral type without an explicit cast.
denom_converted = static_cast<Float64>(denominator); denom_converted = static_cast<Float64>(denominator);
@ -90,8 +92,9 @@ public:
using Fraction = AvgFraction<Numerator, Denominator>; using Fraction = AvgFraction<Numerator, Denominator>;
using Base = IAggregateFunctionDataHelper<Fraction, Derived>; using Base = IAggregateFunctionDataHelper<Fraction, Derived>;
explicit AggregateFunctionAvgBase(const DataTypes & argument_types_, UInt32 scale_ = 0) explicit AggregateFunctionAvgBase(const DataTypes & argument_types_,
: Base(argument_types_, {}), scale(scale_) {} UInt32 num_scale_ = 0, UInt32 denom_scale_ = 0)
: Base(argument_types_, {}), num_scale(num_scale_), denom_scale(denom_scale_) {}
DataTypePtr getReturnType() const final { return std::make_shared<DataTypeNumber<Float64>>(); } DataTypePtr getReturnType() const final { return std::make_shared<DataTypeNumber<Float64>>(); }
@ -124,12 +127,14 @@ public:
void insertResultInto(AggregateDataPtr place, IColumn & to, Arena *) const override void insertResultInto(AggregateDataPtr place, IColumn & to, Arena *) const override
{ {
if constexpr (IsDecimalNumber<Numerator> || IsDecimalNumber<Denominator>) if constexpr (IsDecimalNumber<Numerator> || IsDecimalNumber<Denominator>)
static_cast<ColumnVector<Float64> &>(to).getData().push_back(this->data(place).divideIfAnyDecimal(scale)); static_cast<ColumnVector<Float64> &>(to).getData().push_back(
this->data(place).divideIfAnyDecimal(num_scale, denom_scale));
else else
static_cast<ColumnVector<Float64> &>(to).getData().push_back(this->data(place).divide()); static_cast<ColumnVector<Float64> &>(to).getData().push_back(this->data(place).divide());
} }
private: private:
UInt32 scale; UInt32 num_scale;
UInt32 denom_scale;
}; };
template <class T> template <class T>

View File

@ -83,11 +83,14 @@ AggregateFunctionPtr createAggregateFunctionAvgWeighted(const std::string & name
if (left_decimal && right_decimal) if (left_decimal && right_decimal)
ptr.reset(create(*data_type, *data_type_weight, ptr.reset(create(*data_type, *data_type_weight,
argument_types, argument_types,
getDecimalScale((sizeof(*data_type) > sizeof(*data_type_weight)) ? *data_type : *data_type_weight))); getDecimalScale(*data_type), getDecimalScale(*data_type_weight)));
else if (left_decimal) else if (left_decimal)
ptr.reset(create(*data_type, *data_type_weight, argument_types, getDecimalScale(*data_type))); ptr.reset(create(*data_type, *data_type_weight, argument_types,
getDecimalScale(*data_type)));
else if (right_decimal) else if (right_decimal)
ptr.reset(create(*data_type, *data_type_weight, argument_types, getDecimalScale(*data_type_weight))); ptr.reset(create(*data_type, *data_type_weight, argument_types,
// numerator is not decimal, so its scale is 0
0, getDecimalScale(*data_type_weight)));
else else
ptr.reset(create(*data_type, *data_type_weight, argument_types)); ptr.reset(create(*data_type, *data_type_weight, argument_types));

View File

@ -20,6 +20,7 @@
INSERT INTO perf_avg(num) INSERT INTO perf_avg(num)
SELECT toUInt64(UserID / (WatchID + 1) * 1000000) SELECT toUInt64(UserID / (WatchID + 1) * 1000000)
FROM hits_100m_single FROM hits_100m_single
LIMIT 50000000
</fill_query> </fill_query>
<query>SELECT avg(num) FROM perf_avg FORMAT Null</query> <query>SELECT avg(num) FROM perf_avg FORMAT Null</query>

View File

@ -1,5 +1,7 @@
2.3333333333333335 2.3333333333333335
nan nan
1.0
1.0
8 8
nan nan
8 8

View File

@ -5,6 +5,8 @@ CUR_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
${CLICKHOUSE_CLIENT} --query="SELECT avgWeighted(x, weight) FROM (SELECT t.1 AS x, t.2 AS weight FROM (SELECT arrayJoin([(1, 5), (2, 4), (3, 3), (4, 2), (5, 1)]) AS t));" ${CLICKHOUSE_CLIENT} --query="SELECT avgWeighted(x, weight) FROM (SELECT t.1 AS x, t.2 AS weight FROM (SELECT arrayJoin([(1, 5), (2, 4), (3, 3), (4, 2), (5, 1)]) AS t));"
${CLICKHOUSE_CLIENT} --query="SELECT avgWeighted(x, weight) FROM (SELECT t.1 AS x, t.2 AS weight FROM (SELECT arrayJoin([(1, 0), (2, 0), (3, 0), (4, 0), (5, 0)]) AS t));" ${CLICKHOUSE_CLIENT} --query="SELECT avgWeighted(x, weight) FROM (SELECT t.1 AS x, t.2 AS weight FROM (SELECT arrayJoin([(1, 0), (2, 0), (3, 0), (4, 0), (5, 0)]) AS t));"
${CLICKHOUSE_CLIENT} --query="SELECT avgWeighted(x, y) FROM (select toDecimal256(1, 0) x, toDecimal256(1, 1) y);"
${CLICKHOUSE_CLIENT} --query="SELECT avgWeighted(x, y) FROM (select toDecimal32(1, 0) x, toDecimal256(1, 1) y);"
types=("Int8" "Int16" "Int32" "Int64" "UInt8" "UInt16" "UInt32" "UInt64" "Float32" "Float64") types=("Int8" "Int16" "Int32" "Int64" "UInt8" "UInt16" "UInt32" "UInt64" "Float32" "Float64")