updated the decimal template magic

This commit is contained in:
myrrc 2020-11-04 16:14:07 +03:00
parent 3d70ab7f3b
commit 43b2d20314
5 changed files with 110 additions and 60 deletions

View File

@ -34,7 +34,8 @@ AggregateFunctionPtr createAggregateFunctionAvg(const std::string & name, const
AggregateFunctionPtr res;
if (isDecimal(data_type))
res.reset(createWithDecimalType<AggregateFunctionAvg>(*data_type, argument_types));
res.reset(createWithDecimalType<AggregateFunctionAvg>(
*data_type, getDecimalScale(*data_type), argument_types));
else
res.reset(createWithNumericType<AggregateFunctionAvg>(*data_type, argument_types));

View File

@ -1,25 +1,19 @@
#pragma once
#include <type_traits>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <Columns/ColumnsNumber.h>
#include <DataTypes/DataTypesDecimal.h>
#include <DataTypes/DataTypesNumber.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include "Core/DecimalFunctions.h"
namespace DB
{
/// A numerator/denominator pair; result() yields their quotient as Float64.
/// The numerator is always a Float64, the denominator type is chosen by the user of the struct.
template <class Denominator>
struct RationalFraction
{
    Float64 numerator{0};
    Denominator denominator{0};

    /// A zero denominator is deliberately not rejected: sometimes the caller needs NaN as the answer,
    /// hence the sanitizer opt-out on this function.
    Float64 NO_SANITIZE_UNDEFINED result() const
    {
        const Float64 quotient = numerator / denominator;
        return quotient;
    }
};
template <class T>
using DecimalOrVectorCol = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
template <class T> constexpr bool DecimalOrExtendedInt =
IsDecimalNumber<T>
@ -29,29 +23,78 @@ template <class T> constexpr bool DecimalOrExtendedInt =
|| std::is_same_v<T, UInt256>;
/**
* The discussion showed that the easiest (and simplest) way is to cast both the columns of numerator and denominator
* to Float64. Another way would be to write some template magic that figures out the appropriate numerator
* and denominator (and the resulting type) in favour of extended integral types (UInt128 e.g.) and Decimals (
* which are a mess themselves). The second way is also a bit useless because now Decimals are not used in functions
* like avg.
*
* The denominator type can be specified explicitly for the sake of avg: it uses an integral denominator because the
* denominator is simply the length of the supplied list.
*
* Helper class to encapsulate values conversion for avg and avgWeighted.
*/
/**
  * Numerator/denominator accumulator that keeps both values in their own types and converts them
  * to Float64 only when the quotient is requested.
  *
  * Two entry points exist:
  *  - divide(scale) -- for instantiations where Numerator and/or Denominator is a Decimal;
  *  - divide()      -- for instantiations without Decimals.
  *
  * The applicability of each overload is checked with a static_assert inside the body instead of
  * enable_if on the return type. These are non-template members, so their declarations are
  * instantiated together with the class; an enable_if with a false condition there is a hard error
  * (not SFINAE) and would make every instantiation of the struct ill-formed. A body, on the other
  * hand, is instantiated only if the function is actually called.
  */
template <class Numerator, class Denominator>
struct AvgFraction
{
    Numerator numerator{0};
    Denominator denominator{0};

    static constexpr bool any_is_decimal = IsDecimalNumber<Numerator> || IsDecimalNumber<Denominator>;

    /// Quotient for the case when either Numerator or Denominator (or both) is a Decimal.
    /// @param scale  Decimal scale used to convert the Decimal operand(s) to Float64.
    ///               NOTE(review): when both operands are Decimals the same scale is applied to each
    ///               of them -- confirm against the callers that they really share it.
    /// Division by zero is allowed as sometimes we need to return NaN.
    Float64 NO_SANITIZE_UNDEFINED divide(UInt32 scale) const
    {
        static_assert(any_is_decimal, "divide(scale) is only for Decimal numerator and/or denominator");

        /// Both operands are converted to floating point BEFORE the division. Dividing two Decimals
        /// directly would divide the underlying fixed-point integers (truncating the quotient), and
        /// rescaling that quotient afterwards would additionally shrink it by 10^scale.

        /// The numerator is always cast to Float64 to divide correctly if the denominator is not Float64.
        Float64 num_converted;
        if constexpr (IsDecimalNumber<Numerator>)
            num_converted = DecimalUtils::convertTo<Float64>(numerator, scale);
        else
            num_converted = static_cast<Float64>(numerator); /// all other types, including extended integral.

        if constexpr (IsDecimalNumber<Denominator>)
            return num_converted / DecimalUtils::convertTo<Float64>(denominator, scale);
        else if constexpr (DecimalOrExtendedInt<Denominator>)
            /// no way to divide Float64 by an extended integral type without an explicit cast.
            return num_converted / static_cast<Float64>(denominator);
        else
            return num_converted / denominator; /// can divide by a native type, no cast required.
    }

    /// Quotient for the case when neither Numerator nor Denominator is a Decimal.
    /// Division by zero is allowed as sometimes we need to return NaN.
    Float64 NO_SANITIZE_UNDEFINED divide() const
    {
        static_assert(!any_is_decimal, "divide() without a scale is only for non-Decimal types");

        if constexpr (DecimalOrExtendedInt<Denominator>) /// extended integral denominator needs an explicit cast
            return static_cast<Float64>(numerator) / static_cast<Float64>(denominator);
        else
            return static_cast<Float64>(numerator) / denominator;
    }
};
/**
* @tparam Derived When deriving from this class, use the child class name as in CRTP, e.g.
* class Self : Agg<char, bool, bool, Self>.
*/
template <class Denominator, class Derived>
template <class Numerator, class Denominator, class Derived>
class AggregateFunctionAvgBase : public
IAggregateFunctionDataHelper<RationalFraction<Denominator>, Derived>
IAggregateFunctionDataHelper<AvgFraction<Numerator, Denominator>, Derived>
{
public:
using Fraction = RationalFraction<Denominator>;
using Fraction = AvgFraction<Numerator, Denominator>;
using Base = IAggregateFunctionDataHelper<Fraction, Derived>;
explicit AggregateFunctionAvgBase(const DataTypes & argument_types_): Base(argument_types_, {}) {}
/// ctor for native types
explicit AggregateFunctionAvgBase(const DataTypes & argument_types_)
: Base(argument_types_, {}), scale(0) {}
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeNumber<Float64>>(); }
/// ctor for Decimals
AggregateFunctionAvgBase(UInt32 scale_, const DataTypes & argument_types_)
: Base(argument_types_, {}), scale(scale_) {}
DataTypePtr getReturnType() const final { return std::make_shared<DataTypeNumber<Float64>>(); }
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
{
@ -81,26 +124,24 @@ public:
void insertResultInto(AggregateDataPtr place, IColumn & to, Arena *) const override
{
static_cast<ColumnVector<Float64> &>(to).getData().push_back(this->data(place).result());
if constexpr (IsDecimalNumber<Numerator> || IsDecimalNumber<Denominator>)
static_cast<ColumnVector<Float64> &>(to).getData().push_back(this->data(place).divide(scale));
else
static_cast<ColumnVector<Float64> &>(to).getData().push_back(this->data(place).divide());
}
private:
UInt32 scale;
};
template <class T>
class AggregateFunctionAvg final : public AggregateFunctionAvgBase<UInt64, AggregateFunctionAvg<T>>
class AggregateFunctionAvg final : public AggregateFunctionAvgBase<T, UInt64, AggregateFunctionAvg<T>>
{
public:
using AggregateFunctionAvgBase<UInt64, AggregateFunctionAvg<T>>::AggregateFunctionAvgBase;
using AggregateFunctionAvgBase<T, UInt64, AggregateFunctionAvg<T>>::AggregateFunctionAvgBase;
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const final
{
if constexpr(IsDecimalNumber<T>)
this->data(place).numerator += columns[0]->getFloat64(row_num);
else if constexpr(DecimalOrExtendedInt<T>)
this->data(place).numerator += static_cast<Float64>(
static_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num]);
else
this->data(place).numerator += static_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num];
this->data(place).numerator += static_cast<const DecimalOrVectorCol<T> &>(*columns[0]).getData()[row_num];
++this->data(place).denominator;
}

View File

@ -76,7 +76,21 @@ AggregateFunctionPtr createAggregateFunctionAvgWeighted(const std::string & name
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
AggregateFunctionPtr ptr;
ptr.reset(create(*data_type, *data_type_weight, argument_types));
const bool left_decimal = isDecimal(data_type);
const bool right_decimal = isDecimal(data_type_weight);
if (left_decimal && right_decimal)
ptr.reset(create(*data_type, *data_type_weight,
getDecimalScale((sizeof(*data_type) > sizeof(*data_type_weight)) ? *data_type : *data_type_weight),
argument_types));
else if (left_decimal)
ptr.reset(create(*data_type, *data_type_weight, getDecimalScale(*data_type), argument_types));
else if (right_decimal)
ptr.reset(create(*data_type, *data_type_weight, getDecimalScale(*data_type_weight), argument_types));
else
ptr.reset(create(*data_type, *data_type_weight, argument_types));
return ptr;
}
}

View File

@ -5,35 +5,28 @@
namespace DB
{
template <class T>
using FieldType = std::conditional_t<IsDecimalNumber<T>,
std::conditional_t<std::is_same_v<T, Decimal256>,
Decimal256, Decimal128>,
NearestFieldType<T>>;
template <class Value, class Weight>
class AggregateFunctionAvgWeighted final :
public AggregateFunctionAvgBase<Float64, AggregateFunctionAvgWeighted<Value, Weight>>
public AggregateFunctionAvgBase<FieldType<Value>, FieldType<Weight>, AggregateFunctionAvgWeighted<Value, Weight>>
{
public:
using Base = AggregateFunctionAvgBase<Float64, AggregateFunctionAvgWeighted<Value, Weight>>;
using Base = AggregateFunctionAvgBase<
FieldType<Value>, FieldType<Weight>, AggregateFunctionAvgWeighted<Value, Weight>>;
using Base::Base;
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
{
const Float64 value = [&columns, row_num] {
if constexpr(IsDecimalNumber<Value>)
return columns[0]->getFloat64(row_num);
else
return static_cast<Float64>(static_cast<const ColumnVector<Value>&>(*columns[0]).getData()[row_num]);
}();
const Value value = static_cast<const DecimalOrVectorCol<Value> &>(*columns[0]).getData()[row_num];
const Weight weight = static_cast<const DecimalOrVectorCol<Weight> &>(*columns[1]).getData()[row_num];
using WeightRet = std::conditional_t<DecimalOrExtendedInt<Weight>, Float64, Weight>;
const WeightRet weight = [&columns, row_num]() -> WeightRet {
if constexpr(IsDecimalNumber<Weight>) /// Unable to cast to double -> use the virtual method
return columns[1]->getFloat64(row_num);
else if constexpr(DecimalOrExtendedInt<Weight>) /// Casting to double, otherwise += would be ambitious.
return static_cast<Float64>(static_cast<const ColumnVector<Weight>&>(*columns[1]).getData()[row_num]);
else
return static_cast<const ColumnVector<Weight>&>(*columns[1]).getData()[row_num];
}();
this->data(place).numerator += weight * value;
this->data(place).denominator += weight;
this->data(place).numerator += static_cast<FieldType<Value>>(value) * weight;
this->data(place).denominator += static_cast<FieldType<Weight>>(weight);
}
String getName() const override { return "avgWeighted"; }

View File

@ -4,7 +4,7 @@
<create_query>CREATE TABLE perf_avg(
num UInt64,
num_u Decimal256(75) DEFAULT toDecimal256(num / 100000, 75),
num_u Decimal256(75) DEFAULT toDecimal256(num / 400000, 75),
num_f Float64 DEFAULT num
) ENGINE = MergeTree() ORDER BY tuple()
</create_query>
@ -13,11 +13,12 @@
INSERT INTO perf_avg(num)
SELECT number / r
FROM system.numbers
ARRAY JOIN range(1, 10000) AS r
LIMIT 5000000
ARRAY JOIN range(1, 400000) AS r
LIMIT 200000000
</fill_query>
<query>SELECT avg(num) FROM perf_avg</query>
<query>SELECT avg(2 * num) FROM perf_avg</query>
<query>SELECT avg(num_u) FROM perf_avg</query>
<query>SELECT avg(num_f) FROM perf_avg</query>
<query>SELECT avgWeighted(num_f, num) FROM perf_avg</query>