Merge pull request #3129 from 4ertus2/decimal

var, stddev and math functions for decimal [CLICKHOUSE-3765]
This commit is contained in:
alexey-milovidov 2018-09-14 21:12:37 +03:00 committed by GitHub
commit 685560134a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 410 additions and 189 deletions

View File

@ -2,13 +2,17 @@
#include <cmath>
#include <common/arithmeticOverflow.h>
#include <IO/WriteHelpers.h>
#include <IO/ReadHelpers.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypesDecimal.h>
#include <Columns/ColumnVector.h>
#include <Columns/ColumnDecimal.h>
/** This is simple, not numerically stable
@ -26,17 +30,11 @@
namespace DB
{
enum class VarianceMode
namespace ErrorCodes
{
Population,
Sample
};
enum class VariancePower
{
Original,
Sqrt
};
extern const int LOGICAL_ERROR;
extern const int DECIMAL_OVERFLOW;
}
template <typename T>
@ -70,15 +68,74 @@ struct VarMoments
readPODBinary(*this, buf);
}
template <VarianceMode mode, VariancePower power>
T get() const
T getPopulation() const
{
if (m0 == 0 && mode == VarianceMode::Sample)
return (m2 - m1 * m1 / m0) / m0;
}
T getSample() const
{
if (m0 == 0)
return std::numeric_limits<T>::quiet_NaN();
return (m2 - m1 * m1 / m0) / (m0 - 1);
}
T get() const { throw Exception("Unexpected call", ErrorCodes::LOGICAL_ERROR); }
};
template <typename T>
struct VarMomentsDecimal
{
using NativeType = typename T::NativeType;
UInt64 m0{};
NativeType m1{};
NativeType m2{};
void add(NativeType x)
{
++m0;
m1 += x;
NativeType tmp; /// scale' = 2 * scale
if (common::mulOverflow(x, x, tmp) || common::addOverflow(m2, tmp, m2))
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
}
void merge(const VarMomentsDecimal & rhs)
{
m0 += rhs.m0;
m1 += rhs.m1;
if (common::addOverflow(m2, rhs.m2, m2))
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
}
void write(WriteBuffer & buf) const { writePODBinary(*this, buf); }
void read(ReadBuffer & buf) { readPODBinary(*this, buf); }
Float64 getPopulation(UInt32 scale) const
{
NativeType tmp;
if (common::mulOverflow(m1, m1, tmp) ||
common::subOverflow(m2, NativeType(tmp/m0), tmp))
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
return convertFromDecimal<DataTypeDecimal<T>, DataTypeNumber<Float64>>(tmp / m0, scale);
}
Float64 getSample(UInt32 scale) const
{
if (m0 == 0)
return std::numeric_limits<T>::quiet_NaN();
T res = (m2 - m1 * m1 / m0) / (m0 - (mode == VarianceMode::Sample));
return power == VariancePower::Original ? res : sqrt(res);
NativeType tmp;
if (common::mulOverflow(m1, m1, tmp) ||
common::subOverflow(m2, NativeType(tmp/m0), tmp))
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
return convertFromDecimal<DataTypeDecimal<T>, DataTypeNumber<Float64>>(tmp / (m0 - 1), scale);
}
Float64 get() const { throw Exception("Unexpected call", ErrorCodes::LOGICAL_ERROR); }
};
template <typename T>
@ -115,14 +172,19 @@ struct CovarMoments
readPODBinary(*this, buf);
}
template <VarianceMode mode>
T get() const
T getPopulation() const
{
if (m0 == 0 && mode == VarianceMode::Sample)
return std::numeric_limits<T>::quiet_NaN();
return (xy - x1 * y1 / m0) / (m0 - (mode == VarianceMode::Sample));
return (xy - x1 * y1 / m0) / m0;
}
T getSample() const
{
if (m0 == 0)
return std::numeric_limits<T>::quiet_NaN();
return (xy - x1 * y1 / m0) / (m0 - 1);
}
T get() const { throw Exception("Unexpected call", ErrorCodes::LOGICAL_ERROR); }
};
template <typename T>
@ -169,6 +231,9 @@ struct CorrMoments
{
return (m0 * xy - x1 * y1) / sqrt((m0 * x2 - x1 * x1) * (m0 * y2 - y1 * y1));
}
T getPopulation() const { throw Exception("Unexpected call", ErrorCodes::LOGICAL_ERROR); }
T getSample() const { throw Exception("Unexpected call", ErrorCodes::LOGICAL_ERROR); }
};
@ -181,20 +246,54 @@ enum class StatisticsFunctionKind
};
template <typename T1, typename T2>
using VarianceCalcType = std::conditional_t<std::is_same_v<T1, Float32> && std::is_same_v<T2, Float32>, Float32, Float64>;
template <typename T1, typename T2, typename Data, StatisticsFunctionKind Kind>
class AggregateFunctionVarianceSimple final
: public IAggregateFunctionDataHelper<Data, AggregateFunctionVarianceSimple<T1, T2, Data, Kind>>
template <typename T, StatisticsFunctionKind _kind>
struct StatFuncOneArg
{
using ResultType = VarianceCalcType<T1, T2>;
using Type1 = T;
using Type2 = T;
using ResultType = std::conditional_t<std::is_same_v<T, Float32>, Float32, Float64>;
using Data = std::conditional_t<IsDecimalNumber<T>, VarMomentsDecimal<Decimal128>, VarMoments<ResultType>>;
static constexpr StatisticsFunctionKind kind = _kind;
static constexpr UInt32 num_args = 1;
};
template <typename T1, typename T2, StatisticsFunctionKind _kind>
struct StatFuncTwoArg
{
using Type1 = T1;
using Type2 = T2;
using ResultType = std::conditional_t<std::is_same_v<T1, T2> && std::is_same_v<T1, Float32>, Float32, Float64>;
using Data = std::conditional_t<_kind == StatisticsFunctionKind::corr, CorrMoments<ResultType>, CovarMoments<ResultType>>;
static constexpr StatisticsFunctionKind kind = _kind;
static constexpr UInt32 num_args = 2;
};
template <typename StatFunc>
class AggregateFunctionVarianceSimple final
: public IAggregateFunctionDataHelper<typename StatFunc::Data, AggregateFunctionVarianceSimple<StatFunc>>
{
public:
using T1 = typename StatFunc::Type1;
using T2 = typename StatFunc::Type2;
using ColVecT1 = std::conditional_t<IsDecimalNumber<T1>, ColumnDecimal<T1>, ColumnVector<T1>>;
using ColVecT2 = std::conditional_t<IsDecimalNumber<T2>, ColumnDecimal<T2>, ColumnVector<T2>>;
using ResultType = typename StatFunc::ResultType;
using ColVecResult = ColumnVector<ResultType>;
AggregateFunctionVarianceSimple()
: src_scale(0)
{}
AggregateFunctionVarianceSimple(const IDataType & data_type)
: src_scale(getDecimalScale(data_type))
{}
String getName() const override
{
switch (Kind)
switch (StatFunc::kind)
{
case StatisticsFunctionKind::varPop: return "varPop";
case StatisticsFunctionKind::varSamp: return "varSamp";
@ -214,13 +313,13 @@ public:
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
{
if constexpr (Kind == StatisticsFunctionKind::covarPop || Kind == StatisticsFunctionKind::covarSamp || Kind == StatisticsFunctionKind::corr)
if constexpr (StatFunc::num_args == 2)
this->data(place).add(
static_cast<const ColumnVector<T1> &>(*columns[0]).getData()[row_num],
static_cast<const ColumnVector<T2> &>(*columns[1]).getData()[row_num]);
static_cast<const ColVecT1 &>(*columns[0]).getData()[row_num],
static_cast<const ColVecT2 &>(*columns[1]).getData()[row_num]);
else
this->data(place).add(
static_cast<const ColumnVector<T1> &>(*columns[0]).getData()[row_num]);
static_cast<const ColVecT1 &>(*columns[0]).getData()[row_num]);
}
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
@ -241,27 +340,48 @@ public:
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
const auto & data = this->data(place);
auto & dst = static_cast<ColumnVector<ResultType> &>(to).getData();
auto & dst = static_cast<ColVecResult &>(to).getData();
if constexpr (Kind == StatisticsFunctionKind::varPop) dst.push_back(data.template get<VarianceMode::Population, VariancePower::Original>());
else if constexpr (Kind == StatisticsFunctionKind::varSamp) dst.push_back(data.template get<VarianceMode::Sample, VariancePower::Original>());
else if constexpr (Kind == StatisticsFunctionKind::stddevPop) dst.push_back(data.template get<VarianceMode::Population, VariancePower::Sqrt>());
else if constexpr (Kind == StatisticsFunctionKind::stddevSamp) dst.push_back(data.template get<VarianceMode::Sample, VariancePower::Sqrt>());
else if constexpr (Kind == StatisticsFunctionKind::covarPop) dst.push_back(data.template get<VarianceMode::Population>());
else if constexpr (Kind == StatisticsFunctionKind::covarSamp) dst.push_back(data.template get<VarianceMode::Sample>());
else if constexpr (Kind == StatisticsFunctionKind::corr) dst.push_back(data.get());
if constexpr (IsDecimalNumber<T1>)
{
switch (StatFunc::kind)
{
case StatisticsFunctionKind::varPop: dst.push_back(data.getPopulation(src_scale * 2)); break;
case StatisticsFunctionKind::varSamp: dst.push_back(data.getSample(src_scale * 2)); break;
case StatisticsFunctionKind::stddevPop: dst.push_back(sqrt(data.getPopulation(src_scale * 2))); break;
case StatisticsFunctionKind::stddevSamp: dst.push_back(sqrt(data.getSample(src_scale * 2))); break;
default:
__builtin_unreachable();
}
}
else
{
switch (StatFunc::kind)
{
case StatisticsFunctionKind::varPop: dst.push_back(data.getPopulation()); break;
case StatisticsFunctionKind::varSamp: dst.push_back(data.getSample()); break;
case StatisticsFunctionKind::stddevPop: dst.push_back(sqrt(data.getPopulation())); break;
case StatisticsFunctionKind::stddevSamp: dst.push_back(sqrt(data.getSample())); break;
case StatisticsFunctionKind::covarPop: dst.push_back(data.getPopulation()); break;
case StatisticsFunctionKind::covarSamp: dst.push_back(data.getSample()); break;
case StatisticsFunctionKind::corr: dst.push_back(data.get()); break;
}
}
}
const char * getHeaderFilePath() const override { return __FILE__; }
private:
UInt32 src_scale;
};
template <typename T> using AggregateFunctionVarPopSimple = AggregateFunctionVarianceSimple<T, T, VarMoments<VarianceCalcType<T, T>>, StatisticsFunctionKind::varPop>;
template <typename T> using AggregateFunctionVarSampSimple = AggregateFunctionVarianceSimple<T, T, VarMoments<VarianceCalcType<T, T>>, StatisticsFunctionKind::varSamp>;
template <typename T> using AggregateFunctionStddevPopSimple = AggregateFunctionVarianceSimple<T, T, VarMoments<VarianceCalcType<T, T>>, StatisticsFunctionKind::stddevPop>;
template <typename T> using AggregateFunctionStddevSampSimple = AggregateFunctionVarianceSimple<T, T, VarMoments<VarianceCalcType<T, T>>, StatisticsFunctionKind::stddevSamp>;
template <typename T1, typename T2> using AggregateFunctionCovarPopSimple = AggregateFunctionVarianceSimple<T1, T2, CovarMoments<VarianceCalcType<T1, T2>>, StatisticsFunctionKind::covarPop>;
template <typename T1, typename T2> using AggregateFunctionCovarSampSimple = AggregateFunctionVarianceSimple<T1, T2, CovarMoments<VarianceCalcType<T1, T2>>, StatisticsFunctionKind::covarSamp>;
template <typename T1, typename T2> using AggregateFunctionCorrSimple = AggregateFunctionVarianceSimple<T1, T2, CorrMoments<VarianceCalcType<T1, T2>>, StatisticsFunctionKind::corr>;
template <typename T> using AggregateFunctionVarPopSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::varPop>>;
template <typename T> using AggregateFunctionVarSampSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::varSamp>>;
template <typename T> using AggregateFunctionStddevPopSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::stddevPop>>;
template <typename T> using AggregateFunctionStddevSampSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::stddevSamp>>;
template <typename T1, typename T2> using AggregateFunctionCovarPopSimple = AggregateFunctionVarianceSimple<StatFuncTwoArg<T1, T2, StatisticsFunctionKind::covarPop>>;
template <typename T1, typename T2> using AggregateFunctionCovarSampSimple = AggregateFunctionVarianceSimple<StatFuncTwoArg<T1, T2, StatisticsFunctionKind::covarSamp>>;
template <typename T1, typename T2> using AggregateFunctionCorrSimple = AggregateFunctionVarianceSimple<StatFuncTwoArg<T1, T2, StatisticsFunctionKind::corr>>;
}

View File

@ -21,11 +21,18 @@ AggregateFunctionPtr createAggregateFunctionStatisticsUnary(const std::string &
assertNoParameters(name, parameters);
assertUnary(name, argument_types);
AggregateFunctionPtr res(createWithNumericType<FunctionTemplate>(*argument_types[0]));
AggregateFunctionPtr res;
DataTypePtr data_type = argument_types[0];
if (isDecimal(data_type))
{
res.reset(createWithDecimalType<FunctionTemplate>(*data_type));
}
else
res.reset(createWithNumericType<FunctionTemplate>(*data_type));
if (!res)
throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name,
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return res;
}
@ -51,6 +58,7 @@ void registerAggregateFunctionsStatisticsSimple(AggregateFunctionFactory & facto
factory.registerFunction("varPop", createAggregateFunctionStatisticsUnary<AggregateFunctionVarPopSimple>);
factory.registerFunction("stddevSamp", createAggregateFunctionStatisticsUnary<AggregateFunctionStddevSampSimple>);
factory.registerFunction("stddevPop", createAggregateFunctionStatisticsUnary<AggregateFunctionStddevPopSimple>);
factory.registerFunction("covarSamp", createAggregateFunctionStatisticsBinary<AggregateFunctionCovarSampSimple>);
factory.registerFunction("covarPop", createAggregateFunctionStatisticsBinary<AggregateFunctionCovarPopSimple>);
factory.registerFunction("corr", createAggregateFunctionStatisticsBinary<AggregateFunctionCorrSimple>, AggregateFunctionFactory::CaseInsensitive);

View File

@ -14,7 +14,9 @@ struct TypePair
using RightType = U;
};
template <typename T, bool _int, bool _dec, bool _float, typename F>
template <typename T, bool _int, bool _float, bool _dec, typename F>
bool callOnBasicType(TypeIndex number, F && f)
{
if constexpr (_int)
@ -65,24 +67,24 @@ bool callOnBasicType(TypeIndex number, F && f)
}
/// Unroll template using TypeIndex
template <typename F, bool _int = true, bool _dec = true, bool _float = false>
template <bool _int, bool _float, bool _dec, typename F>
inline bool callOnBasicTypes(TypeIndex type_num1, TypeIndex type_num2, F && f)
{
if constexpr (_int)
{
switch (type_num1)
{
case TypeIndex::UInt8: return callOnBasicType<UInt8, _int, _dec, _float>(type_num2, std::forward<F>(f));
case TypeIndex::UInt16: return callOnBasicType<UInt16, _int, _dec, _float>(type_num2, std::forward<F>(f));
case TypeIndex::UInt32: return callOnBasicType<UInt32, _int, _dec, _float>(type_num2, std::forward<F>(f));
case TypeIndex::UInt64: return callOnBasicType<UInt64, _int, _dec, _float>(type_num2, std::forward<F>(f));
//case TypeIndex::UInt128: return callOnBasicType<UInt128, _int, _dec, _float>(type_num2, std::forward<F>(f));
case TypeIndex::UInt8: return callOnBasicType<UInt8, _int, _float, _dec>(type_num2, std::forward<F>(f));
case TypeIndex::UInt16: return callOnBasicType<UInt16, _int, _float, _dec>(type_num2, std::forward<F>(f));
case TypeIndex::UInt32: return callOnBasicType<UInt32, _int, _float, _dec>(type_num2, std::forward<F>(f));
case TypeIndex::UInt64: return callOnBasicType<UInt64, _int, _float, _dec>(type_num2, std::forward<F>(f));
//case TypeIndex::UInt128: return callOnBasicType<UInt128, _int, _float, _dec>(type_num2, std::forward<F>(f));
case TypeIndex::Int8: return callOnBasicType<Int8, _int, _dec, _float>(type_num2, std::forward<F>(f));
case TypeIndex::Int16: return callOnBasicType<Int16, _int, _dec, _float>(type_num2, std::forward<F>(f));
case TypeIndex::Int32: return callOnBasicType<Int32, _int, _dec, _float>(type_num2, std::forward<F>(f));
case TypeIndex::Int64: return callOnBasicType<Int64, _int, _dec, _float>(type_num2, std::forward<F>(f));
case TypeIndex::Int128: return callOnBasicType<Int128, _int, _dec, _float>(type_num2, std::forward<F>(f));
case TypeIndex::Int8: return callOnBasicType<Int8, _int, _float, _dec>(type_num2, std::forward<F>(f));
case TypeIndex::Int16: return callOnBasicType<Int16, _int, _float, _dec>(type_num2, std::forward<F>(f));
case TypeIndex::Int32: return callOnBasicType<Int32, _int, _float, _dec>(type_num2, std::forward<F>(f));
case TypeIndex::Int64: return callOnBasicType<Int64, _int, _float, _dec>(type_num2, std::forward<F>(f));
case TypeIndex::Int128: return callOnBasicType<Int128, _int, _float, _dec>(type_num2, std::forward<F>(f));
default:
break;
}
@ -92,9 +94,9 @@ inline bool callOnBasicTypes(TypeIndex type_num1, TypeIndex type_num2, F && f)
{
switch (type_num1)
{
case TypeIndex::Decimal32: return callOnBasicType<Decimal32, _int, _dec, _float>(type_num2, std::forward<F>(f));
case TypeIndex::Decimal64: return callOnBasicType<Decimal64, _int, _dec, _float>(type_num2, std::forward<F>(f));
case TypeIndex::Decimal128: return callOnBasicType<Decimal128, _int, _dec, _float>(type_num2, std::forward<F>(f));
case TypeIndex::Decimal32: return callOnBasicType<Decimal32, _int, _float, _dec>(type_num2, std::forward<F>(f));
case TypeIndex::Decimal64: return callOnBasicType<Decimal64, _int, _float, _dec>(type_num2, std::forward<F>(f));
case TypeIndex::Decimal128: return callOnBasicType<Decimal128, _int, _float, _dec>(type_num2, std::forward<F>(f));
default:
break;
}
@ -104,8 +106,8 @@ inline bool callOnBasicTypes(TypeIndex type_num1, TypeIndex type_num2, F && f)
{
switch (type_num1)
{
case TypeIndex::Float32: return callOnBasicType<Float32, _int, _dec, _float>(type_num2, std::forward<F>(f));
case TypeIndex::Float64: return callOnBasicType<Float64, _int, _dec, _float>(type_num2, std::forward<F>(f));
case TypeIndex::Float32: return callOnBasicType<Float32, _int, _float, _dec>(type_num2, std::forward<F>(f));
case TypeIndex::Float64: return callOnBasicType<Float64, _int, _float, _dec>(type_num2, std::forward<F>(f));
default:
break;
}

View File

@ -45,7 +45,7 @@ public:
/// Name of data type family (example: FixedString, Array).
virtual const char * getFamilyName() const = 0;
/// Unique type number or zero
/// Data type id. It's used for runtime type checks.
virtual TypeIndex getTypeId() const = 0;
/** Binary serialization for range of values in column - for writing to disk/network, etc.

View File

@ -725,7 +725,7 @@ private:
return true;
};
if (!callOnBasicTypes(left_number, right_number, call))
if (!callOnBasicTypes<true, false, true>(left_number, right_number, call))
throw Exception("Wrong call for " + getName() + " with " + col_left.type->getName() + " and " + col_right.type->getName(),
ErrorCodes::LOGICAL_ERROR);
}

View File

@ -1,13 +1,15 @@
#pragma once
#include <common/preciseExp10.h>
#include <Core/callOnTypeIndex.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypesDecimal.h>
#include <Columns/ColumnsNumber.h>
#include <Columns/ColumnDecimal.h>
#include <Columns/ColumnConst.h>
#include <Functions/IFunction.h>
#include <Functions/FunctionHelpers.h>
#include <Common/config.h>
#include <Common/typeid_cast.h>
/** More efficient implementations of mathematical functions are possible when using a separate library.
* Disabled due to licence compatibility limitations.
@ -78,68 +80,91 @@ private:
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!isNumber(arguments.front()))
throw Exception{"Illegal type " + arguments.front()->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
const auto & arg = arguments.front();
if (!isNumber(arg) && !isDecimal(arg))
throw Exception{"Illegal type " + arg->getName() + " of argument of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
return std::make_shared<DataTypeFloat64>();
}
template <typename FieldType>
bool execute(Block & block, const IColumn * arg, const size_t result)
template <typename T>
static void executeInIterations(const T * src_data, Float64 * dst_data, size_t size)
{
if (const auto col = checkAndGetColumn<ColumnVector<FieldType>>(arg))
const size_t rows_remaining = size % Impl::rows_per_iteration;
const size_t rows_size = size - rows_remaining;
for (size_t i = 0; i < rows_size; i += Impl::rows_per_iteration)
Impl::execute(&src_data[i], &dst_data[i]);
if (rows_remaining != 0)
{
auto dst = ColumnVector<Float64>::create();
T src_remaining[Impl::rows_per_iteration];
memcpy(src_remaining, &src_data[rows_size], rows_remaining * sizeof(T));
memset(src_remaining + rows_remaining, 0, (Impl::rows_per_iteration - rows_remaining) * sizeof(T));
Float64 dst_remaining[Impl::rows_per_iteration];
const auto & src_data = col->getData();
const auto src_size = src_data.size();
auto & dst_data = dst->getData();
dst_data.resize(src_size);
Impl::execute(src_remaining, dst_remaining);
const auto rows_remaining = src_size % Impl::rows_per_iteration;
const auto rows_size = src_size - rows_remaining;
for (size_t i = 0; i < rows_size; i += Impl::rows_per_iteration)
Impl::execute(&src_data[i], &dst_data[i]);
if (rows_remaining != 0)
{
FieldType src_remaining[Impl::rows_per_iteration];
memcpy(src_remaining, &src_data[rows_size], rows_remaining * sizeof(FieldType));
memset(src_remaining + rows_remaining, 0, (Impl::rows_per_iteration - rows_remaining) * sizeof(FieldType));
Float64 dst_remaining[Impl::rows_per_iteration];
Impl::execute(src_remaining, dst_remaining);
memcpy(&dst_data[rows_size], dst_remaining, rows_remaining * sizeof(Float64));
}
block.getByPosition(result).column = std::move(dst);
return true;
memcpy(&dst_data[rows_size], dst_remaining, rows_remaining * sizeof(Float64));
}
}
return false;
template <typename T>
static bool execute(Block & block, const ColumnVector<T> * col, const size_t result)
{
const auto & src_data = col->getData();
const size_t size = src_data.size();
auto dst = ColumnVector<Float64>::create();
auto & dst_data = dst->getData();
dst_data.resize(size);
executeInIterations(src_data.data(), dst_data.data(), size);
block.getByPosition(result).column = std::move(dst);
return true;
}
template <typename T>
static bool execute(Block & block, const ColumnDecimal<T> * col, const size_t result)
{
const auto & src_data = col->getData();
const size_t size = src_data.size();
UInt32 scale = src_data.getScale();
auto dst = ColumnVector<Float64>::create();
auto & dst_data = dst->getData();
dst_data.resize(size);
for (size_t i = 0; i < size; ++i)
dst_data[i] = convertFromDecimal<DataTypeDecimal<T>, DataTypeNumber<Float64>>(src_data[i], scale);
executeInIterations(dst_data.data(), dst_data.data(), size);
block.getByPosition(result).column = std::move(dst);
return true;
}
bool useDefaultImplementationForConstants() const override { return true; }
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override
{
const auto arg = block.getByPosition(arguments[0]).column.get();
const ColumnWithTypeAndName & col = block.getByPosition(arguments[0]);
if (!execute<UInt8>(block, arg, result) &&
!execute<UInt16>(block, arg, result) &&
!execute<UInt32>(block, arg, result) &&
!execute<UInt64>(block, arg, result) &&
!execute<Int8>(block, arg, result) &&
!execute<Int16>(block, arg, result) &&
!execute<Int32>(block, arg, result) &&
!execute<Int64>(block, arg, result) &&
!execute<Float32>(block, arg, result) &&
!execute<Float64>(block, arg, result))
auto call = [&](const auto & types) -> bool
{
throw Exception{"Illegal column " + arg->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_COLUMN};
}
using Types = std::decay_t<decltype(types)>;
using Type = typename Types::RightType;
using ColVecType = std::conditional_t<IsDecimalNumber<Type>, ColumnDecimal<Type>, ColumnVector<Type>>;
const auto col_vec = checkAndGetColumn<ColVecType>(col.column.get());
return execute<Type>(block, col_vec, result);
};
if (!callOnBasicType<void, true, true, true>(col.type->getTypeId(), call))
throw Exception{"Illegal column " + col.column->getName() + " of argument of function " + getName(),
ErrorCodes::ILLEGAL_COLUMN};
}
};
@ -211,8 +236,7 @@ private:
}
template <typename LeftType, typename RightType>
bool executeRight(Block & block, const size_t result, const ColumnConst * left_arg,
const IColumn * right_arg)
bool executeTyped(Block & block, const size_t result, const ColumnConst * left_arg, const IColumn * right_arg)
{
if (const auto right_arg_typed = checkAndGetColumn<ColumnVector<RightType>>(right_arg))
{
@ -251,8 +275,7 @@ private:
}
template <typename LeftType, typename RightType>
bool executeRight(Block & block, const size_t result, const ColumnVector<LeftType> * left_arg,
const IColumn * right_arg)
bool executeTyped(Block & block, const size_t result, const ColumnVector<LeftType> * left_arg, const IColumn * right_arg)
{
if (const auto right_arg_typed = checkAndGetColumn<ColumnVector<RightType>>(right_arg))
{
@ -324,80 +347,47 @@ private:
return false;
}
template <typename LeftType>
bool executeLeft(Block & block, const ColumnNumbers & arguments, const size_t result,
const IColumn * left_arg)
{
if (const auto left_arg_typed = checkAndGetColumn<ColumnVector<LeftType>>(left_arg))
{
const auto right_arg = block.getByPosition(arguments[1]).column.get();
if (executeRight<LeftType, UInt8>(block, result, left_arg_typed, right_arg) ||
executeRight<LeftType, UInt16>(block, result, left_arg_typed, right_arg) ||
executeRight<LeftType, UInt32>(block, result, left_arg_typed, right_arg) ||
executeRight<LeftType, UInt64>(block, result, left_arg_typed, right_arg) ||
executeRight<LeftType, Int8>(block, result, left_arg_typed, right_arg) ||
executeRight<LeftType, Int16>(block, result, left_arg_typed, right_arg) ||
executeRight<LeftType, Int32>(block, result, left_arg_typed, right_arg) ||
executeRight<LeftType, Int64>(block, result, left_arg_typed, right_arg) ||
executeRight<LeftType, Float32>(block, result, left_arg_typed, right_arg) ||
executeRight<LeftType, Float64>(block, result, left_arg_typed, right_arg))
{
return true;
}
else
{
throw Exception{"Illegal column " + block.getByPosition(arguments[1]).column->getName() +
" of second argument of function " + getName(),
ErrorCodes::ILLEGAL_COLUMN};
}
}
else if (const auto left_arg_typed = checkAndGetColumnConst<ColumnVector<LeftType>>(left_arg))
{
const auto right_arg = block.getByPosition(arguments[1]).column.get();
if (executeRight<LeftType, UInt8>(block, result, left_arg_typed, right_arg) ||
executeRight<LeftType, UInt16>(block, result, left_arg_typed, right_arg) ||
executeRight<LeftType, UInt32>(block, result, left_arg_typed, right_arg) ||
executeRight<LeftType, UInt64>(block, result, left_arg_typed, right_arg) ||
executeRight<LeftType, Int8>(block, result, left_arg_typed, right_arg) ||
executeRight<LeftType, Int16>(block, result, left_arg_typed, right_arg) ||
executeRight<LeftType, Int32>(block, result, left_arg_typed, right_arg) ||
executeRight<LeftType, Int64>(block, result, left_arg_typed, right_arg) ||
executeRight<LeftType, Float32>(block, result, left_arg_typed, right_arg) ||
executeRight<LeftType, Float64>(block, result, left_arg_typed, right_arg))
{
return true;
}
else
{
throw Exception{"Illegal column " + block.getByPosition(arguments[1]).column->getName() +
" of second argument of function " + getName(),
ErrorCodes::ILLEGAL_COLUMN};
}
}
return false;
}
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override
{
const auto left_arg = block.getByPosition(arguments[0]).column.get();
const ColumnWithTypeAndName & col_left = block.getByPosition(arguments[0]);
const ColumnWithTypeAndName & col_right = block.getByPosition(arguments[1]);
if (!executeLeft<UInt8>(block, arguments, result, left_arg) &&
!executeLeft<UInt16>(block, arguments, result, left_arg) &&
!executeLeft<UInt32>(block, arguments, result, left_arg) &&
!executeLeft<UInt64>(block, arguments, result, left_arg) &&
!executeLeft<Int8>(block, arguments, result, left_arg) &&
!executeLeft<Int16>(block, arguments, result, left_arg) &&
!executeLeft<Int32>(block, arguments, result, left_arg) &&
!executeLeft<Int64>(block, arguments, result, left_arg) &&
!executeLeft<Float32>(block, arguments, result, left_arg) &&
!executeLeft<Float64>(block, arguments, result, left_arg))
auto call = [&](const auto & types) -> bool
{
throw Exception{"Illegal column " + left_arg->getName() + " of argument of function " + getName(),
using Types = std::decay_t<decltype(types)>;
using LeftType = typename Types::LeftType;
using RightType = typename Types::RightType;
using ColVecLeft = ColumnVector<LeftType>;
const IColumn * left_arg = col_left.column.get();
const IColumn * right_arg = col_right.column.get();
if (const auto left_arg_typed = checkAndGetColumn<ColVecLeft>(left_arg))
{
if (executeTyped<LeftType, RightType>(block, result, left_arg_typed, right_arg))
return true;
throw Exception{"Illegal column " + right_arg->getName() + " of second argument of function " + getName(),
ErrorCodes::ILLEGAL_COLUMN};
}
else if (const auto left_arg_typed = checkAndGetColumnConst<ColVecLeft>(left_arg))
{
if (executeTyped<LeftType, RightType>(block, result, left_arg_typed, right_arg))
return true;
throw Exception{"Illegal column " + right_arg->getName() + " of second argument of function " + getName(),
ErrorCodes::ILLEGAL_COLUMN};
}
return false;
};
TypeIndex left_index = col_left.type->getTypeId();
TypeIndex right_index = col_right.type->getTypeId();
if (!callOnBasicTypes<true, true, false>(left_index, right_index, call))
throw Exception{"Illegal column " + col_left.column->getName() + " of argument of function " + getName(),
ErrorCodes::ILLEGAL_COLUMN};
}
}
};

View File

@ -26,3 +26,11 @@
[-50.0000,-40.0000,-30.0000,-20.0000,-10.0000,0.0000,10.0000,20.0000,30.0000,40.0000,50.0000]
[-16.66666666,-13.33333333,-10.00000000,-6.66666666,-3.33333333,0.00000000,3.33333333,6.66666666,10.00000000,13.33333333,16.66666666]
[-10.00000000,-8.00000000,-6.00000000,-4.00000000,-2.00000000,0.00000000,2.00000000,4.00000000,6.00000000,8.00000000,10.00000000]
850 94.44444438684269 34 Float64 Float64 Float64
850 94.4444443868427 34.00000000000001
858.5 95.38888883071111 34.34 Float64 Float64 Float64
858.5 95.38888883071112 34.34
29.154759474226502 9.718253155111915 5.830951894845301 Float64 Float64 Float64
29.154759474226502 9.718253155111915 5.830951894845301
29.300170647967224 9.766723546344041 5.860034129593445 Float64 Float64 Float64
29.300170647967224 9.766723546344041 5.860034129593445

View File

@ -54,6 +54,21 @@ SELECT quantilesExact(0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0)(a)
SELECT quantilesExact(0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0)(b) FROM test.decimal;
SELECT quantilesExact(0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0)(c) FROM test.decimal;
SELECT varPop(a) AS va, varPop(b) AS vb, varPop(c) AS vc, toTypeName(va), toTypeName(vb), toTypeName(vc) FROM test.decimal;
SELECT varPop(toFloat64(a)), varPop(toFloat64(b)), varPop(toFloat64(c)) FROM test.decimal;
SELECT varSamp(a) AS va, varSamp(b) AS vb, varSamp(c) AS vc, toTypeName(va), toTypeName(vb), toTypeName(vc) FROM test.decimal;
SELECT varSamp(toFloat64(a)), varSamp(toFloat64(b)), varSamp(toFloat64(c)) FROM test.decimal;
SELECT stddevPop(a) AS da, stddevPop(b) AS db, stddevPop(c) AS dc, toTypeName(da), toTypeName(db), toTypeName(dc) FROM test.decimal;
SELECT stddevPop(toFloat64(a)), stddevPop(toFloat64(b)), stddevPop(toFloat64(c)) FROM test.decimal;
SELECT stddevSamp(a) AS da, stddevSamp(b) AS db, stddevSamp(c) AS dc, toTypeName(da), toTypeName(db), toTypeName(dc) FROM test.decimal;
SELECT stddevSamp(toFloat64(a)), stddevSamp(toFloat64(b)), stddevSamp(toFloat64(c)) FROM test.decimal;
SELECT covarPop(a, a), covarPop(b, b), covarPop(c, c) FROM test.decimal; -- { serverError 43 }
SELECT covarSamp(a, a), covarSamp(b, b), covarSamp(c, c) FROM test.decimal; -- { serverError 43 }
SELECT corr(a, a), corr(b, b), corr(c, c) FROM test.decimal; -- { serverError 43 }
SELECT 1 LIMIT 0;
-- TODO: sumMap
-- TODO: other quantile(s)
-- TODO: groupArray, groupArrayInsertAt, groupUniqArray

View File

@ -0,0 +1,30 @@
42.4200 3.7476 42.419153766068966
42.4200 5.4066 42.41786197045111
42.4200 1.6275 42.413098391048806
42.4200 6.513 42.419169
42.4200 3.4875 42.417263671875
1.00000 0.8427007929497149 0.15729920705028513
42.4200 115.60113124678627 1.6029995567009473e50
0.00 0 1 0
3.14159265 0 -1 -0
1.00 1.5707963267948966 0 0.7853981633974483
42.4200 3.7476 42.419153766068966
42.4200 5.4066 42.41786197045111
42.4200 1.6275 42.413098391048806
42.4200 6.513 42.419169
42.4200 3.4875 42.417263671875
1.00000 0.8427007929497149 0.15729920705028513
42.4200 115.60113124678627 1.6029995567009473e50
0.00 0 1 0
3.14159265358979328 0 -1 -0
1.00 1.5707963267948966 0 0.7853981633974483
42.4200 3.7476 42.419153766068966
42.4200 5.4066 42.41786197045111
42.4200 1.6275 42.413098391048806
42.4200 6.513 42.419169
42.4200 3.4875 42.417263671875
1.00000 0.8427007929497149 0.15729920705028513
42.4200 115.60113124678627 1.6029995567009473e50
0.00 0 1 0
3.1415926535897927981986333033020522496 0 -1 -0
1.00 1.5707963267948966 0 0.7853981633974483

View File

@ -0,0 +1,48 @@
SET allow_experimental_decimal_type = 1;
SET send_logs_level = 'none';
SELECT toDecimal32('42.42', 4) AS x, toDecimal32(log(x), 4) AS y, exp(y);
SELECT toDecimal32('42.42', 4) AS x, toDecimal32(log2(x), 4) AS y, exp2(y);
SELECT toDecimal32('42.42', 4) AS x, toDecimal32(log10(x), 4) AS y, exp10(y);
SELECT toDecimal32('42.42', 4) AS x, toDecimal32(sqrt(x), 3) AS y, y * y;
SELECT toDecimal32('42.42', 4) AS x, toDecimal32(cbrt(x), 4) AS y, toDecimal64(y, 4) * y * y;
SELECT toDecimal32('1.0', 5) AS x, erf(x), erfc(x);
SELECT toDecimal32('42.42', 4) AS x, lgamma(x), tgamma(x);
SELECT toDecimal32('0.0', 2) AS x, round(sin(x), 8), round(cos(x), 8), round(tan(x), 8);
SELECT toDecimal32(pi(), 8) AS x, round(sin(x), 8), round(cos(x), 8), round(tan(x), 8);
SELECT toDecimal32('1.0', 2) AS x, asin(x), acos(x), atan(x);
SELECT toDecimal64('42.42', 4) AS x, toDecimal32(log(x), 4) AS y, exp(y);
SELECT toDecimal64('42.42', 4) AS x, toDecimal32(log2(x), 4) AS y, exp2(y);
SELECT toDecimal64('42.42', 4) AS x, toDecimal32(log10(x), 4) AS y, exp10(y);
SELECT toDecimal64('42.42', 4) AS x, toDecimal32(sqrt(x), 3) AS y, y * y;
SELECT toDecimal64('42.42', 4) AS x, toDecimal32(cbrt(x), 4) AS y, toDecimal64(y, 4) * y * y;
SELECT toDecimal64('1.0', 5) AS x, erf(x), erfc(x);
SELECT toDecimal64('42.42', 4) AS x, lgamma(x), tgamma(x);
SELECT toDecimal64('0.0', 2) AS x, round(sin(x), 8), round(cos(x), 8), round(tan(x), 8);
SELECT toDecimal64(pi(), 17) AS x, round(sin(x), 8), round(cos(x), 8), round(tan(x), 8);
SELECT toDecimal64('1.0', 2) AS x, asin(x), acos(x), atan(x);
SELECT toDecimal128('42.42', 4) AS x, toDecimal32(log(x), 4) AS y, exp(y);
SELECT toDecimal128('42.42', 4) AS x, toDecimal32(log2(x), 4) AS y, exp2(y);
SELECT toDecimal128('42.42', 4) AS x, toDecimal32(log10(x), 4) AS y, exp10(y);
SELECT toDecimal128('42.42', 4) AS x, toDecimal32(sqrt(x), 3) AS y, y * y;
SELECT toDecimal128('42.42', 4) AS x, toDecimal32(cbrt(x), 4) AS y, toDecimal64(y, 4) * y * y;
SELECT toDecimal128('1.0', 5) AS x, erf(x), erfc(x);
SELECT toDecimal128('42.42', 4) AS x, lgamma(x), tgamma(x);
SELECT toDecimal128('0.0', 2) AS x, round(sin(x), 8), round(cos(x), 8), round(tan(x), 8);
SELECT toDecimal128(pi(), 37) AS x, round(sin(x), 8), round(cos(x), 8), round(tan(x), 8);
SELECT toDecimal128('1.0', 2) AS x, asin(x), acos(x), atan(x);
SELECT toDecimal32('4.2', 1) AS x, pow(x, 2), pow(x, 0.5); -- { serverError 43 }
SELECT toDecimal64('4.2', 1) AS x, pow(x, 2), pow(x, 0.5); -- { serverError 43 }
SELECT toDecimal128('4.2', 1) AS x, pow(x, 2), pow(x, 0.5); -- { serverError 43 }