mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-10 01:25:21 +00:00
updated the decimal template magic
This commit is contained in:
parent
3d70ab7f3b
commit
43b2d20314
@ -34,7 +34,8 @@ AggregateFunctionPtr createAggregateFunctionAvg(const std::string & name, const
|
||||
AggregateFunctionPtr res;
|
||||
|
||||
if (isDecimal(data_type))
|
||||
res.reset(createWithDecimalType<AggregateFunctionAvg>(*data_type, argument_types));
|
||||
res.reset(createWithDecimalType<AggregateFunctionAvg>(
|
||||
*data_type, getDecimalScale(*data_type), argument_types));
|
||||
else
|
||||
res.reset(createWithNumericType<AggregateFunctionAvg>(*data_type, argument_types));
|
||||
|
||||
|
@ -1,25 +1,19 @@
|
||||
#pragma once
|
||||
|
||||
#include <type_traits>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <Columns/ColumnsNumber.h>
|
||||
#include <DataTypes/DataTypesDecimal.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include "Core/DecimalFunctions.h"
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
template <class Denominator>
|
||||
struct RationalFraction
|
||||
{
|
||||
Float64 numerator{0};
|
||||
Denominator denominator{0};
|
||||
|
||||
/// Allow division by zero as sometimes we need to return NaN.
|
||||
Float64 NO_SANITIZE_UNDEFINED result() const { return numerator / denominator; }
|
||||
};
|
||||
template <class T>
|
||||
using DecimalOrVectorCol = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
|
||||
|
||||
template <class T> constexpr bool DecimalOrExtendedInt =
|
||||
IsDecimalNumber<T>
|
||||
@ -29,29 +23,78 @@ template <class T> constexpr bool DecimalOrExtendedInt =
|
||||
|| std::is_same_v<T, UInt256>;
|
||||
|
||||
/**
|
||||
* The discussion showed that the easiest (and simplest) way is to cast both the columns of numerator and denominator
|
||||
* to Float64. Another way would be to write some template magic that figures out the appropriate numerator
|
||||
* and denominator (and the resulting type) in favour of extended integral types (UInt128 e.g.) and Decimals (
|
||||
* which are a mess themselves). The second way is also a bit useless because now Decimals are not used in functions
|
||||
* like avg.
|
||||
*
|
||||
* For avg the denominator type is specified explicitly: it is an integral value, because the denominator is
|
||||
* simply the number of values supplied.
|
||||
*
|
||||
* Helper class to encapsulate values conversion for avg and avgWeighted.
|
||||
*/
|
||||
template <class Numerator, class Denominator>
|
||||
struct AvgFraction
|
||||
{
|
||||
Numerator numerator{0};
|
||||
Denominator denominator{0};
|
||||
|
||||
static constexpr bool any_is_decimal = IsDecimalNumber<Numerator> || IsDecimalNumber<Denominator>;
|
||||
|
||||
/// Allow division by zero as sometimes we need to return NaN.
|
||||
/// Invoked only if either Numerator or Denominator is Decimal.
|
||||
std::enable_if_t<any_is_decimal, Float64> NO_SANITIZE_UNDEFINED divide(UInt32 scale) const
|
||||
{
|
||||
if constexpr (IsDecimalNumber<Numerator> && IsDecimalNumber<Denominator>)
|
||||
return DecimalUtils::convertTo<Float64>(numerator / denominator, scale);
|
||||
|
||||
/// Numerator is always cast to Float64 to divide correctly if the denominator is not Float64.
|
||||
const Float64 num_converted = [scale](Numerator n)
|
||||
{
|
||||
if constexpr (IsDecimalNumber<Numerator>)
|
||||
return DecimalUtils::convertTo<Float64>(n, scale);
|
||||
else
|
||||
return static_cast<Float64>(n); /// all other types, including extended integral.
|
||||
} (numerator);
|
||||
|
||||
const auto denom_converted = [scale](Denominator d) ->
|
||||
std::conditional_t<DecimalOrExtendedInt<Denominator>, Float64, Denominator>
|
||||
{
|
||||
if constexpr (IsDecimalNumber<Denominator>)
|
||||
return DecimalUtils::convertTo<Float64>(d, scale);
|
||||
else if constexpr (DecimalOrExtendedInt<Denominator>)
|
||||
/// There is no way to divide a Float64 by an extended integral type without an explicit cast.
|
||||
return static_cast<Float64>(d);
|
||||
else
|
||||
return d; /// can divide on float, no cast required.
|
||||
} (denominator);
|
||||
|
||||
return num_converted / denom_converted;
|
||||
}
|
||||
|
||||
std::enable_if_t<!any_is_decimal, Float64> NO_SANITIZE_UNDEFINED divide() const
|
||||
{
|
||||
if constexpr (DecimalOrExtendedInt<Denominator>) /// if extended int
|
||||
return static_cast<Float64>(numerator) / static_cast<Float64>(denominator);
|
||||
else
|
||||
return static_cast<Float64>(numerator) / denominator;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/**
|
||||
* @tparam Derived When deriving from this class, use the child class name as in CRTP, e.g.
|
||||
* class Self : Agg<char, bool, bool, Self>.
|
||||
*/
|
||||
template <class Denominator, class Derived>
|
||||
template <class Numerator, class Denominator, class Derived>
|
||||
class AggregateFunctionAvgBase : public
|
||||
IAggregateFunctionDataHelper<RationalFraction<Denominator>, Derived>
|
||||
IAggregateFunctionDataHelper<AvgFraction<Numerator, Denominator>, Derived>
|
||||
{
|
||||
public:
|
||||
using Fraction = RationalFraction<Denominator>;
|
||||
using Fraction = AvgFraction<Numerator, Denominator>;
|
||||
using Base = IAggregateFunctionDataHelper<Fraction, Derived>;
|
||||
|
||||
explicit AggregateFunctionAvgBase(const DataTypes & argument_types_): Base(argument_types_, {}) {}
|
||||
/// ctor for native types
|
||||
explicit AggregateFunctionAvgBase(const DataTypes & argument_types_)
|
||||
: Base(argument_types_, {}), scale(0) {}
|
||||
|
||||
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeNumber<Float64>>(); }
|
||||
/// ctor for Decimals
|
||||
AggregateFunctionAvgBase(UInt32 scale_, const DataTypes & argument_types_)
|
||||
: Base(argument_types_, {}), scale(scale_) {}
|
||||
|
||||
DataTypePtr getReturnType() const final { return std::make_shared<DataTypeNumber<Float64>>(); }
|
||||
|
||||
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
@ -81,26 +124,24 @@ public:
|
||||
|
||||
void insertResultInto(AggregateDataPtr place, IColumn & to, Arena *) const override
|
||||
{
|
||||
static_cast<ColumnVector<Float64> &>(to).getData().push_back(this->data(place).result());
|
||||
if constexpr (IsDecimalNumber<Numerator> || IsDecimalNumber<Denominator>)
|
||||
static_cast<ColumnVector<Float64> &>(to).getData().push_back(this->data(place).divide(scale));
|
||||
else
|
||||
static_cast<ColumnVector<Float64> &>(to).getData().push_back(this->data(place).divide());
|
||||
}
|
||||
private:
|
||||
UInt32 scale;
|
||||
};
|
||||
|
||||
template <class T>
|
||||
class AggregateFunctionAvg final : public AggregateFunctionAvgBase<UInt64, AggregateFunctionAvg<T>>
|
||||
class AggregateFunctionAvg final : public AggregateFunctionAvgBase<T, UInt64, AggregateFunctionAvg<T>>
|
||||
{
|
||||
public:
|
||||
using AggregateFunctionAvgBase<UInt64, AggregateFunctionAvg<T>>::AggregateFunctionAvgBase;
|
||||
using AggregateFunctionAvgBase<T, UInt64, AggregateFunctionAvg<T>>::AggregateFunctionAvgBase;
|
||||
|
||||
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const final
|
||||
{
|
||||
if constexpr(IsDecimalNumber<T>)
|
||||
this->data(place).numerator += columns[0]->getFloat64(row_num);
|
||||
else if constexpr(DecimalOrExtendedInt<T>)
|
||||
this->data(place).numerator += static_cast<Float64>(
|
||||
static_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num]);
|
||||
else
|
||||
this->data(place).numerator += static_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num];
|
||||
|
||||
this->data(place).numerator += static_cast<const DecimalOrVectorCol<T> &>(*columns[0]).getData()[row_num];
|
||||
++this->data(place).denominator;
|
||||
}
|
||||
|
||||
|
@ -76,7 +76,21 @@ AggregateFunctionPtr createAggregateFunctionAvgWeighted(const std::string & name
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
AggregateFunctionPtr ptr;
|
||||
ptr.reset(create(*data_type, *data_type_weight, argument_types));
|
||||
|
||||
const bool left_decimal = isDecimal(data_type);
|
||||
const bool right_decimal = isDecimal(data_type_weight);
|
||||
|
||||
if (left_decimal && right_decimal)
|
||||
ptr.reset(create(*data_type, *data_type_weight,
|
||||
getDecimalScale((sizeof(*data_type) > sizeof(*data_type_weight)) ? *data_type : *data_type_weight),
|
||||
argument_types));
|
||||
else if (left_decimal)
|
||||
ptr.reset(create(*data_type, *data_type_weight, getDecimalScale(*data_type), argument_types));
|
||||
else if (right_decimal)
|
||||
ptr.reset(create(*data_type, *data_type_weight, getDecimalScale(*data_type_weight), argument_types));
|
||||
else
|
||||
ptr.reset(create(*data_type, *data_type_weight, argument_types));
|
||||
|
||||
return ptr;
|
||||
}
|
||||
}
|
||||
|
@ -5,35 +5,28 @@
|
||||
|
||||
namespace DB
|
||||
{
|
||||
template <class T>
|
||||
using FieldType = std::conditional_t<IsDecimalNumber<T>,
|
||||
std::conditional_t<std::is_same_v<T, Decimal256>,
|
||||
Decimal256, Decimal128>,
|
||||
NearestFieldType<T>>;
|
||||
|
||||
template <class Value, class Weight>
|
||||
class AggregateFunctionAvgWeighted final :
|
||||
public AggregateFunctionAvgBase<Float64, AggregateFunctionAvgWeighted<Value, Weight>>
|
||||
public AggregateFunctionAvgBase<FieldType<Value>, FieldType<Weight>, AggregateFunctionAvgWeighted<Value, Weight>>
|
||||
{
|
||||
public:
|
||||
using Base = AggregateFunctionAvgBase<Float64, AggregateFunctionAvgWeighted<Value, Weight>>;
|
||||
using Base = AggregateFunctionAvgBase<
|
||||
FieldType<Value>, FieldType<Weight>, AggregateFunctionAvgWeighted<Value, Weight>>;
|
||||
using Base::Base;
|
||||
|
||||
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
const Float64 value = [&columns, row_num] {
|
||||
if constexpr(IsDecimalNumber<Value>)
|
||||
return columns[0]->getFloat64(row_num);
|
||||
else
|
||||
return static_cast<Float64>(static_cast<const ColumnVector<Value>&>(*columns[0]).getData()[row_num]);
|
||||
}();
|
||||
const Value value = static_cast<const DecimalOrVectorCol<Value> &>(*columns[0]).getData()[row_num];
|
||||
const Weight weight = static_cast<const DecimalOrVectorCol<Weight> &>(*columns[1]).getData()[row_num];
|
||||
|
||||
using WeightRet = std::conditional_t<DecimalOrExtendedInt<Weight>, Float64, Weight>;
|
||||
const WeightRet weight = [&columns, row_num]() -> WeightRet {
|
||||
if constexpr(IsDecimalNumber<Weight>) /// Unable to cast to double -> use the virtual method
|
||||
return columns[1]->getFloat64(row_num);
|
||||
else if constexpr(DecimalOrExtendedInt<Weight>) /// Casting to double, otherwise += would be ambiguous.
|
||||
return static_cast<Float64>(static_cast<const ColumnVector<Weight>&>(*columns[1]).getData()[row_num]);
|
||||
else
|
||||
return static_cast<const ColumnVector<Weight>&>(*columns[1]).getData()[row_num];
|
||||
}();
|
||||
|
||||
this->data(place).numerator += weight * value;
|
||||
this->data(place).denominator += weight;
|
||||
this->data(place).numerator += static_cast<FieldType<Value>>(value) * weight;
|
||||
this->data(place).denominator += static_cast<FieldType<Weight>>(weight);
|
||||
}
|
||||
|
||||
String getName() const override { return "avgWeighted"; }
|
||||
|
@ -4,7 +4,7 @@
|
||||
|
||||
<create_query>CREATE TABLE perf_avg(
|
||||
num UInt64,
|
||||
num_u Decimal256(75) DEFAULT toDecimal256(num / 100000, 75),
|
||||
num_u Decimal256(75) DEFAULT toDecimal256(num / 400000, 75),
|
||||
num_f Float64 DEFAULT num
|
||||
) ENGINE = MergeTree() ORDER BY tuple()
|
||||
</create_query>
|
||||
@ -13,11 +13,12 @@
|
||||
INSERT INTO perf_avg(num)
|
||||
SELECT number / r
|
||||
FROM system.numbers
|
||||
ARRAY JOIN range(1, 10000) AS r
|
||||
LIMIT 5000000
|
||||
ARRAY JOIN range(1, 400000) AS r
|
||||
LIMIT 200000000
|
||||
</fill_query>
|
||||
|
||||
<query>SELECT avg(num) FROM perf_avg</query>
|
||||
<query>SELECT avg(2 * num) FROM perf_avg</query>
|
||||
<query>SELECT avg(num_u) FROM perf_avg</query>
|
||||
<query>SELECT avg(num_f) FROM perf_avg</query>
|
||||
<query>SELECT avgWeighted(num_f, num) FROM perf_avg</query>
|
||||
|
Loading…
Reference in New Issue
Block a user