mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-09-20 08:40:50 +00:00
Implemented faster and not numerically stable versions of statistical functions [#CLICKHOUSE-2].
This commit is contained in:
parent
22ef87b763
commit
a7b8541cea
@ -45,23 +45,15 @@ AggregateFunctionPtr createAggregateFunctionStatisticsBinary(const std::string &
|
||||
|
||||
}
|
||||
|
||||
void registerAggregateFunctionsStatistics(AggregateFunctionFactory & factory)
|
||||
void registerAggregateFunctionsStatisticsStable(AggregateFunctionFactory & factory)
|
||||
{
|
||||
factory.registerFunction("varSamp", createAggregateFunctionStatisticsUnary<AggregateFunctionVarSamp>);
|
||||
factory.registerFunction("varPop", createAggregateFunctionStatisticsUnary<AggregateFunctionVarPop>);
|
||||
factory.registerFunction("stddevSamp", createAggregateFunctionStatisticsUnary<AggregateFunctionStdDevSamp>);
|
||||
factory.registerFunction("stddevPop", createAggregateFunctionStatisticsUnary<AggregateFunctionStdDevPop>);
|
||||
factory.registerFunction("covarSamp", createAggregateFunctionStatisticsBinary<AggregateFunctionCovarSamp>);
|
||||
factory.registerFunction("covarPop", createAggregateFunctionStatisticsBinary<AggregateFunctionCovarPop>);
|
||||
factory.registerFunction("corr", createAggregateFunctionStatisticsBinary<AggregateFunctionCorr>, AggregateFunctionFactory::CaseInsensitive);
|
||||
|
||||
/// Synonyms for compatibility.
|
||||
factory.registerFunction("VAR_SAMP", createAggregateFunctionStatisticsUnary<AggregateFunctionVarSamp>, AggregateFunctionFactory::CaseInsensitive);
|
||||
factory.registerFunction("VAR_POP", createAggregateFunctionStatisticsUnary<AggregateFunctionVarPop>, AggregateFunctionFactory::CaseInsensitive);
|
||||
factory.registerFunction("STDDEV_SAMP", createAggregateFunctionStatisticsUnary<AggregateFunctionStdDevSamp>, AggregateFunctionFactory::CaseInsensitive);
|
||||
factory.registerFunction("STDDEV_POP", createAggregateFunctionStatisticsUnary<AggregateFunctionStdDevPop>, AggregateFunctionFactory::CaseInsensitive);
|
||||
factory.registerFunction("COVAR_SAMP", createAggregateFunctionStatisticsBinary<AggregateFunctionCovarSamp>, AggregateFunctionFactory::CaseInsensitive);
|
||||
factory.registerFunction("COVAR_POP", createAggregateFunctionStatisticsBinary<AggregateFunctionCovarPop>, AggregateFunctionFactory::CaseInsensitive);
|
||||
factory.registerFunction("varSampStable", createAggregateFunctionStatisticsUnary<AggregateFunctionVarSampStable>);
|
||||
factory.registerFunction("varPopStable", createAggregateFunctionStatisticsUnary<AggregateFunctionVarPopStable>);
|
||||
factory.registerFunction("stddevSampStable", createAggregateFunctionStatisticsUnary<AggregateFunctionStddevSampStable>);
|
||||
factory.registerFunction("stddevPopStable", createAggregateFunctionStatisticsUnary<AggregateFunctionStddevPopStable>);
|
||||
factory.registerFunction("covarSampStable", createAggregateFunctionStatisticsBinary<AggregateFunctionCovarSampStable>);
|
||||
factory.registerFunction("covarPopStable", createAggregateFunctionStatisticsBinary<AggregateFunctionCovarPopStable>);
|
||||
factory.registerFunction("corrStable", createAggregateFunctionStatisticsBinary<AggregateFunctionCorrStable>);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -50,8 +50,6 @@ template <typename T, typename Op>
|
||||
class AggregateFunctionVarianceData
|
||||
{
|
||||
public:
|
||||
AggregateFunctionVarianceData() = default;
|
||||
|
||||
void update(const IColumn & column, size_t row_num)
|
||||
{
|
||||
T received = static_cast<const ColumnVector<T> &>(column).getData()[row_num];
|
||||
@ -446,24 +444,24 @@ struct AggregateFunctionCorrImpl
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using AggregateFunctionVarSamp = AggregateFunctionVariance<T, AggregateFunctionVarSampImpl>;
|
||||
using AggregateFunctionVarSampStable = AggregateFunctionVariance<T, AggregateFunctionVarSampImpl>;
|
||||
|
||||
template <typename T>
|
||||
using AggregateFunctionStdDevSamp = AggregateFunctionVariance<T, AggregateFunctionStdDevSampImpl>;
|
||||
using AggregateFunctionStddevSampStable = AggregateFunctionVariance<T, AggregateFunctionStdDevSampImpl>;
|
||||
|
||||
template <typename T>
|
||||
using AggregateFunctionVarPop = AggregateFunctionVariance<T, AggregateFunctionVarPopImpl>;
|
||||
using AggregateFunctionVarPopStable = AggregateFunctionVariance<T, AggregateFunctionVarPopImpl>;
|
||||
|
||||
template <typename T>
|
||||
using AggregateFunctionStdDevPop = AggregateFunctionVariance<T, AggregateFunctionStdDevPopImpl>;
|
||||
using AggregateFunctionStddevPopStable = AggregateFunctionVariance<T, AggregateFunctionStdDevPopImpl>;
|
||||
|
||||
template <typename T, typename U>
|
||||
using AggregateFunctionCovarSamp = AggregateFunctionCovariance<T, U, AggregateFunctionCovarSampImpl>;
|
||||
using AggregateFunctionCovarSampStable = AggregateFunctionCovariance<T, U, AggregateFunctionCovarSampImpl>;
|
||||
|
||||
template <typename T, typename U>
|
||||
using AggregateFunctionCovarPop = AggregateFunctionCovariance<T, U, AggregateFunctionCovarPopImpl>;
|
||||
using AggregateFunctionCovarPopStable = AggregateFunctionCovariance<T, U, AggregateFunctionCovarPopImpl>;
|
||||
|
||||
template <typename T, typename U>
|
||||
using AggregateFunctionCorr = AggregateFunctionCovariance<T, U, AggregateFunctionCorrImpl, true>;
|
||||
using AggregateFunctionCorrStable = AggregateFunctionCovariance<T, U, AggregateFunctionCorrImpl, true>;
|
||||
|
||||
}
|
||||
|
247
dbms/src/AggregateFunctions/AggregateFunctionStatisticsSimple.h
Normal file
247
dbms/src/AggregateFunctions/AggregateFunctionStatisticsSimple.h
Normal file
@ -0,0 +1,247 @@
|
||||
#pragma once
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
/// Chooses the denominator applied to the accumulated moments:
/// the number of values (Population) or the number of values minus one (Sample).
enum class VarianceMode
{
    Population = 0,
    Sample = 1
};
|
||||
|
||||
/// Chooses whether the variance is returned unchanged (Original)
/// or as its square root (Sqrt).
enum class VariancePower
{
    Original = 0,
    Sqrt = 1
};
|
||||
|
||||
|
||||
template <typename T>
|
||||
struct VarMoments
|
||||
{
|
||||
T m0{};
|
||||
T m1{};
|
||||
T m2{};
|
||||
|
||||
void add(T x)
|
||||
{
|
||||
++m0;
|
||||
m1 += x;
|
||||
m2 += x * x;
|
||||
}
|
||||
|
||||
void merge(const VarMoments & rhs)
|
||||
{
|
||||
m0 += rhs.m0;
|
||||
m1 += rhs.m1;
|
||||
m2 += rhs.m2;
|
||||
}
|
||||
|
||||
void write(WriteBuffer & buf) const
|
||||
{
|
||||
writePODBinary(*this, buf);
|
||||
}
|
||||
|
||||
void read(ReadBuffer & buf)
|
||||
{
|
||||
readPODBinary(*this, buf);
|
||||
}
|
||||
|
||||
template <VarianceMode mode, VariancePower power>
|
||||
T get() const
|
||||
{
|
||||
if (m0 == 0 && mode == VarianceMode::Sample)
|
||||
return std::numeric_limits<T>::quiet_NaN();
|
||||
|
||||
T res = (m2 - m1 * m1 / m0) / (m0 - (mode == VarianceMode::Sample));
|
||||
return power == VariancePower::Original ? res : sqrt(res);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct CovarMoments
|
||||
{
|
||||
T m0{};
|
||||
T x1{};
|
||||
T y1{};
|
||||
T xy{};
|
||||
|
||||
void add(T x, T y)
|
||||
{
|
||||
++m0;
|
||||
x1 += x;
|
||||
y1 += y;
|
||||
xy += x * y;
|
||||
}
|
||||
|
||||
void merge(const CovarMoments & rhs)
|
||||
{
|
||||
m0 += rhs.m0;
|
||||
x1 += rhs.x1;
|
||||
y1 += rhs.y1;
|
||||
xy += rhs.xy;
|
||||
}
|
||||
|
||||
void write(WriteBuffer & buf) const
|
||||
{
|
||||
writePODBinary(*this, buf);
|
||||
}
|
||||
|
||||
void read(ReadBuffer & buf)
|
||||
{
|
||||
readPODBinary(*this, buf);
|
||||
}
|
||||
|
||||
template <VarianceMode mode>
|
||||
T get() const
|
||||
{
|
||||
if (m0 == 0 && mode == VarianceMode::Sample)
|
||||
return std::numeric_limits<T>::quiet_NaN();
|
||||
|
||||
return (xy - x1 * y1 / m0) / (m0 - (mode == VarianceMode::Sample));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct CorrMoments
|
||||
{
|
||||
T m0{};
|
||||
T x1{};
|
||||
T y1{};
|
||||
T xy{};
|
||||
T x2{};
|
||||
T y2{};
|
||||
|
||||
void add(T x, T y)
|
||||
{
|
||||
++m0;
|
||||
x1 += x;
|
||||
y1 += y;
|
||||
xy += x * y;
|
||||
x2 += x * x;
|
||||
y2 += y * y;
|
||||
}
|
||||
|
||||
void merge(const CorrMoments & rhs)
|
||||
{
|
||||
m0 += rhs.m0;
|
||||
x1 += rhs.x1;
|
||||
y1 += rhs.y1;
|
||||
xy += rhs.xy;
|
||||
x2 += rhs.x2;
|
||||
y2 += rhs.y2;
|
||||
}
|
||||
|
||||
void write(WriteBuffer & buf) const
|
||||
{
|
||||
writePODBinary(*this, buf);
|
||||
}
|
||||
|
||||
void read(ReadBuffer & buf)
|
||||
{
|
||||
readPODBinary(*this, buf);
|
||||
}
|
||||
|
||||
T get() const
|
||||
{
|
||||
return (m0 * xy - x1 * y1) / sqrt((m0 * x2 - x1 * x1) * (m0 * y2 - y1 * y1));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/// Identifies which statistic an AggregateFunctionVarianceSimple instantiation computes.
enum class StatisticsFunctionKind
{
    /// Variance.
    varPop,
    varSamp,
    /// Standard deviation.
    stddevPop,
    stddevSamp,
    /// Covariance of two arguments.
    covarPop,
    covarSamp,
    /// Correlation of two arguments.
    corr
};
|
||||
|
||||
|
||||
template <typename T, typename Data, StatisticsFunctionKind Kind>
|
||||
class AggregateFunctionVarianceSimple final
|
||||
: public IAggregateFunctionDataHelper<Data, AggregateFunctionVarianceSimple<T, Data, Kind>>
|
||||
{
|
||||
public:
|
||||
String getName() const override
|
||||
{
|
||||
switch (Kind)
|
||||
{
|
||||
case StatisticsFunctionKind::varPop: return "varPop";
|
||||
case StatisticsFunctionKind::varSamp: return "varSamp";
|
||||
case StatisticsFunctionKind::stddevPop: return "stddevPop";
|
||||
case StatisticsFunctionKind::stddevSamp: return "stddevSamp";
|
||||
case StatisticsFunctionKind::covarPop: return "covarPop";
|
||||
case StatisticsFunctionKind::covarSamp: return "covarSamp";
|
||||
case StatisticsFunctionKind::corr: return "corr";
|
||||
}
|
||||
}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
return std::make_shared<DataTypeNumber<T>>();
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
if constexpr (Kind == StatisticsFunctionKind::covarPop || Kind == StatisticsFunctionKind::covarSamp || Kind == StatisticsFunctionKind::corr)
|
||||
this->data(place).add(
|
||||
static_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num],
|
||||
static_cast<const ColumnVector<T> &>(*columns[1]).getData()[row_num]);
|
||||
else
|
||||
this->data(place).add(
|
||||
static_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num]);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
this->data(place).merge(this->data(rhs));
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
|
||||
{
|
||||
this->data(place).write(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override
|
||||
{
|
||||
this->data(place).read(buf);
|
||||
}
|
||||
|
||||
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
|
||||
{
|
||||
const auto & data = this->data(place);
|
||||
auto & dst = static_cast<ColumnVector<T> &>(to).getData();
|
||||
|
||||
if constexpr (Kind == StatisticsFunctionKind::varPop) dst.push_back(data.template get<VarianceMode::Population, VariancePower::Original>());
|
||||
else if constexpr (Kind == StatisticsFunctionKind::varSamp) dst.push_back(data.template get<VarianceMode::Sample, VariancePower::Original>());
|
||||
else if constexpr (Kind == StatisticsFunctionKind::stddevPop) dst.push_back(data.template get<VarianceMode::Population, VariancePower::Sqrt>());
|
||||
else if constexpr (Kind == StatisticsFunctionKind::stddevSamp) dst.push_back(data.template get<VarianceMode::Sample, VariancePower::Sqrt>());
|
||||
else if constexpr (Kind == StatisticsFunctionKind::covarPop) dst.push_back(data.template get<VarianceMode::Population>());
|
||||
else if constexpr (Kind == StatisticsFunctionKind::covarSamp) dst.push_back(data.template get<VarianceMode::Sample>());
|
||||
else if constexpr (Kind == StatisticsFunctionKind::corr) dst.push_back(data.get());
|
||||
}
|
||||
|
||||
const char * getHeaderFilePath() const override { return __FILE__; }
|
||||
};
|
||||
|
||||
|
||||
/// Variance and standard deviation: one argument, state is VarMoments.
template <typename T>
using AggregateFunctionVarPopSimple = AggregateFunctionVarianceSimple<T, VarMoments<T>, StatisticsFunctionKind::varPop>;

template <typename T>
using AggregateFunctionVarSampSimple = AggregateFunctionVarianceSimple<T, VarMoments<T>, StatisticsFunctionKind::varSamp>;

template <typename T>
using AggregateFunctionStddevPopSimple = AggregateFunctionVarianceSimple<T, VarMoments<T>, StatisticsFunctionKind::stddevPop>;

template <typename T>
using AggregateFunctionStddevSampSimple = AggregateFunctionVarianceSimple<T, VarMoments<T>, StatisticsFunctionKind::stddevSamp>;

/// Covariance: two arguments, state is CovarMoments.
template <typename T>
using AggregateFunctionCovarPopSimple = AggregateFunctionVarianceSimple<T, CovarMoments<T>, StatisticsFunctionKind::covarPop>;

template <typename T>
using AggregateFunctionCovarSampSimple = AggregateFunctionVarianceSimple<T, CovarMoments<T>, StatisticsFunctionKind::covarSamp>;

/// Correlation: two arguments, state is CorrMoments.
template <typename T>
using AggregateFunctionCorrSimple = AggregateFunctionVarianceSimple<T, CorrMoments<T>, StatisticsFunctionKind::corr>;
|
||||
|
||||
}
|
@ -0,0 +1,67 @@
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
#include <AggregateFunctions/AggregateFunctionStatisticsSimple.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
template <template <typename> class FunctionTemplate>
|
||||
AggregateFunctionPtr createAggregateFunctionStatisticsUnary(const std::string & name, const DataTypes & argument_types, const Array & parameters)
|
||||
{
|
||||
assertNoParameters(name, parameters);
|
||||
assertUnary(name, argument_types);
|
||||
|
||||
if (typeid_cast<const DataTypeFloat32 *>(argument_types[0].get())) return std::make_shared<FunctionTemplate<Float32>>();
|
||||
if (typeid_cast<const DataTypeFloat64 *>(argument_types[0].get())) return std::make_shared<FunctionTemplate<Float64>>();
|
||||
|
||||
throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
}
|
||||
|
||||
template <template <typename> class FunctionTemplate>
|
||||
AggregateFunctionPtr createAggregateFunctionStatisticsBinary(const std::string & name, const DataTypes & argument_types, const Array & parameters)
|
||||
{
|
||||
assertNoParameters(name, parameters);
|
||||
assertBinary(name, argument_types);
|
||||
|
||||
if (!argument_types[0]->equals(*argument_types[1]))
|
||||
throw Exception("Illegal types " + argument_types[0]->getName() + " and " + argument_types[1]->getName()
|
||||
+ " of arguments for aggregate function " + name + ", must be the same", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
if (typeid_cast<const DataTypeFloat32 *>(argument_types[0].get())) return std::make_shared<FunctionTemplate<Float32>>();
|
||||
if (typeid_cast<const DataTypeFloat64 *>(argument_types[0].get())) return std::make_shared<FunctionTemplate<Float64>>();
|
||||
|
||||
throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/// Registers the fast (not numerically stable) statistical aggregate functions in the factory.
void registerAggregateFunctionsStatisticsSimple(AggregateFunctionFactory & factory)
{
    /// Functions of one argument.
    factory.registerFunction(
        "varSamp", createAggregateFunctionStatisticsUnary<AggregateFunctionVarSampSimple>);
    factory.registerFunction(
        "varPop", createAggregateFunctionStatisticsUnary<AggregateFunctionVarPopSimple>);
    factory.registerFunction(
        "stddevSamp", createAggregateFunctionStatisticsUnary<AggregateFunctionStddevSampSimple>);
    factory.registerFunction(
        "stddevPop", createAggregateFunctionStatisticsUnary<AggregateFunctionStddevPopSimple>);

    /// Functions of two arguments. "corr" is registered with the CaseInsensitive flag.
    factory.registerFunction(
        "covarSamp", createAggregateFunctionStatisticsBinary<AggregateFunctionCovarSampSimple>);
    factory.registerFunction(
        "covarPop", createAggregateFunctionStatisticsBinary<AggregateFunctionCovarPopSimple>);
    factory.registerFunction(
        "corr", createAggregateFunctionStatisticsBinary<AggregateFunctionCorrSimple>, AggregateFunctionFactory::CaseInsensitive);

    /// Synonyms for compatibility.
    factory.registerFunction(
        "VAR_SAMP", createAggregateFunctionStatisticsUnary<AggregateFunctionVarSampSimple>, AggregateFunctionFactory::CaseInsensitive);
    factory.registerFunction(
        "VAR_POP", createAggregateFunctionStatisticsUnary<AggregateFunctionVarPopSimple>, AggregateFunctionFactory::CaseInsensitive);
    factory.registerFunction(
        "STDDEV_SAMP", createAggregateFunctionStatisticsUnary<AggregateFunctionStddevSampSimple>, AggregateFunctionFactory::CaseInsensitive);
    factory.registerFunction(
        "STDDEV_POP", createAggregateFunctionStatisticsUnary<AggregateFunctionStddevPopSimple>, AggregateFunctionFactory::CaseInsensitive);
    factory.registerFunction(
        "COVAR_SAMP", createAggregateFunctionStatisticsBinary<AggregateFunctionCovarSampSimple>, AggregateFunctionFactory::CaseInsensitive);
    factory.registerFunction(
        "COVAR_POP", createAggregateFunctionStatisticsBinary<AggregateFunctionCovarPopSimple>, AggregateFunctionFactory::CaseInsensitive);
}
|
||||
|
||||
}
|
@ -13,7 +13,8 @@ void registerAggregateFunctionGroupArrayInsertAt(AggregateFunctionFactory & fact
|
||||
void registerAggregateFunctionsQuantile(AggregateFunctionFactory & factory);
|
||||
void registerAggregateFunctionsSequenceMatch(AggregateFunctionFactory & factory);
|
||||
void registerAggregateFunctionsMinMaxAny(AggregateFunctionFactory & factory);
|
||||
void registerAggregateFunctionsStatistics(AggregateFunctionFactory & factory);
|
||||
void registerAggregateFunctionsStatisticsStable(AggregateFunctionFactory & factory);
|
||||
void registerAggregateFunctionsStatisticsSimple(AggregateFunctionFactory & factory);
|
||||
void registerAggregateFunctionSum(AggregateFunctionFactory & factory);
|
||||
void registerAggregateFunctionSumMap(AggregateFunctionFactory & factory);
|
||||
void registerAggregateFunctionsUniq(AggregateFunctionFactory & factory);
|
||||
@ -34,7 +35,8 @@ void registerAggregateFunctions()
|
||||
registerAggregateFunctionsQuantile(factory);
|
||||
registerAggregateFunctionsSequenceMatch(factory);
|
||||
registerAggregateFunctionsMinMaxAny(factory);
|
||||
registerAggregateFunctionsStatistics(factory);
|
||||
registerAggregateFunctionsStatisticsStable(factory);
|
||||
registerAggregateFunctionsStatisticsSimple(factory);
|
||||
registerAggregateFunctionSum(factory);
|
||||
registerAggregateFunctionSumMap(factory);
|
||||
registerAggregateFunctionsUniq(factory);
|
||||
|
Loading…
Reference in New Issue
Block a user