Implemented faster and not numerically stable versions of statistical functions [#CLICKHOUSE-2].

This commit is contained in:
Alexey Milovidov 2017-12-23 01:23:03 +03:00
parent 22ef87b763
commit a7b8541cea
5 changed files with 333 additions and 27 deletions

View File

@ -45,23 +45,15 @@ AggregateFunctionPtr createAggregateFunctionStatisticsBinary(const std::string &
}
void registerAggregateFunctionsStatistics(AggregateFunctionFactory & factory)
void registerAggregateFunctionsStatisticsStable(AggregateFunctionFactory & factory)
{
factory.registerFunction("varSamp", createAggregateFunctionStatisticsUnary<AggregateFunctionVarSamp>);
factory.registerFunction("varPop", createAggregateFunctionStatisticsUnary<AggregateFunctionVarPop>);
factory.registerFunction("stddevSamp", createAggregateFunctionStatisticsUnary<AggregateFunctionStdDevSamp>);
factory.registerFunction("stddevPop", createAggregateFunctionStatisticsUnary<AggregateFunctionStdDevPop>);
factory.registerFunction("covarSamp", createAggregateFunctionStatisticsBinary<AggregateFunctionCovarSamp>);
factory.registerFunction("covarPop", createAggregateFunctionStatisticsBinary<AggregateFunctionCovarPop>);
factory.registerFunction("corr", createAggregateFunctionStatisticsBinary<AggregateFunctionCorr>, AggregateFunctionFactory::CaseInsensitive);
/// Synonims for compatibility.
factory.registerFunction("VAR_SAMP", createAggregateFunctionStatisticsUnary<AggregateFunctionVarSamp>, AggregateFunctionFactory::CaseInsensitive);
factory.registerFunction("VAR_POP", createAggregateFunctionStatisticsUnary<AggregateFunctionVarPop>, AggregateFunctionFactory::CaseInsensitive);
factory.registerFunction("STDDEV_SAMP", createAggregateFunctionStatisticsUnary<AggregateFunctionStdDevSamp>, AggregateFunctionFactory::CaseInsensitive);
factory.registerFunction("STDDEV_POP", createAggregateFunctionStatisticsUnary<AggregateFunctionStdDevPop>, AggregateFunctionFactory::CaseInsensitive);
factory.registerFunction("COVAR_SAMP", createAggregateFunctionStatisticsBinary<AggregateFunctionCovarSamp>, AggregateFunctionFactory::CaseInsensitive);
factory.registerFunction("COVAR_POP", createAggregateFunctionStatisticsBinary<AggregateFunctionCovarPop>, AggregateFunctionFactory::CaseInsensitive);
factory.registerFunction("varSampStable", createAggregateFunctionStatisticsUnary<AggregateFunctionVarSampStable>);
factory.registerFunction("varPopStable", createAggregateFunctionStatisticsUnary<AggregateFunctionVarPopStable>);
factory.registerFunction("stddevSampStable", createAggregateFunctionStatisticsUnary<AggregateFunctionStddevSampStable>);
factory.registerFunction("stddevPopStable", createAggregateFunctionStatisticsUnary<AggregateFunctionStddevPopStable>);
factory.registerFunction("covarSampStable", createAggregateFunctionStatisticsBinary<AggregateFunctionCovarSampStable>);
factory.registerFunction("covarPopStable", createAggregateFunctionStatisticsBinary<AggregateFunctionCovarPopStable>);
factory.registerFunction("corrStable", createAggregateFunctionStatisticsBinary<AggregateFunctionCorrStable>);
}
}

View File

@ -50,8 +50,6 @@ template <typename T, typename Op>
class AggregateFunctionVarianceData
{
public:
AggregateFunctionVarianceData() = default;
void update(const IColumn & column, size_t row_num)
{
T received = static_cast<const ColumnVector<T> &>(column).getData()[row_num];
@ -446,24 +444,24 @@ struct AggregateFunctionCorrImpl
};
template <typename T>
using AggregateFunctionVarSamp = AggregateFunctionVariance<T, AggregateFunctionVarSampImpl>;
using AggregateFunctionVarSampStable = AggregateFunctionVariance<T, AggregateFunctionVarSampImpl>;
template <typename T>
using AggregateFunctionStdDevSamp = AggregateFunctionVariance<T, AggregateFunctionStdDevSampImpl>;
using AggregateFunctionStddevSampStable = AggregateFunctionVariance<T, AggregateFunctionStdDevSampImpl>;
template <typename T>
using AggregateFunctionVarPop = AggregateFunctionVariance<T, AggregateFunctionVarPopImpl>;
using AggregateFunctionVarPopStable = AggregateFunctionVariance<T, AggregateFunctionVarPopImpl>;
template <typename T>
using AggregateFunctionStdDevPop = AggregateFunctionVariance<T, AggregateFunctionStdDevPopImpl>;
using AggregateFunctionStddevPopStable = AggregateFunctionVariance<T, AggregateFunctionStdDevPopImpl>;
template <typename T, typename U>
using AggregateFunctionCovarSamp = AggregateFunctionCovariance<T, U, AggregateFunctionCovarSampImpl>;
using AggregateFunctionCovarSampStable = AggregateFunctionCovariance<T, U, AggregateFunctionCovarSampImpl>;
template <typename T, typename U>
using AggregateFunctionCovarPop = AggregateFunctionCovariance<T, U, AggregateFunctionCovarPopImpl>;
using AggregateFunctionCovarPopStable = AggregateFunctionCovariance<T, U, AggregateFunctionCovarPopImpl>;
template <typename T, typename U>
using AggregateFunctionCorr = AggregateFunctionCovariance<T, U, AggregateFunctionCorrImpl, true>;
using AggregateFunctionCorrStable = AggregateFunctionCovariance<T, U, AggregateFunctionCorrImpl, true>;
}

View File

@ -0,0 +1,247 @@
#pragma once
#include <cmath>
#include <IO/WriteHelpers.h>
#include <IO/ReadHelpers.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <DataTypes/DataTypesNumber.h>
namespace DB
{
enum class VarianceMode
{
Population,
Sample
};
enum class VariancePower
{
Original,
Sqrt
};
template <typename T>
struct VarMoments
{
T m0{};
T m1{};
T m2{};
void add(T x)
{
++m0;
m1 += x;
m2 += x * x;
}
void merge(const VarMoments & rhs)
{
m0 += rhs.m0;
m1 += rhs.m1;
m2 += rhs.m2;
}
void write(WriteBuffer & buf) const
{
writePODBinary(*this, buf);
}
void read(ReadBuffer & buf)
{
readPODBinary(*this, buf);
}
template <VarianceMode mode, VariancePower power>
T get() const
{
if (m0 == 0 && mode == VarianceMode::Sample)
return std::numeric_limits<T>::quiet_NaN();
T res = (m2 - m1 * m1 / m0) / (m0 - (mode == VarianceMode::Sample));
return power == VariancePower::Original ? res : sqrt(res);
}
};
template <typename T>
struct CovarMoments
{
T m0{};
T x1{};
T y1{};
T xy{};
void add(T x, T y)
{
++m0;
x1 += x;
y1 += y;
xy += x * y;
}
void merge(const CovarMoments & rhs)
{
m0 += rhs.m0;
x1 += rhs.x1;
y1 += rhs.y1;
xy += rhs.xy;
}
void write(WriteBuffer & buf) const
{
writePODBinary(*this, buf);
}
void read(ReadBuffer & buf)
{
readPODBinary(*this, buf);
}
template <VarianceMode mode>
T get() const
{
if (m0 == 0 && mode == VarianceMode::Sample)
return std::numeric_limits<T>::quiet_NaN();
return (xy - x1 * y1 / m0) / (m0 - (mode == VarianceMode::Sample));
}
};
template <typename T>
struct CorrMoments
{
T m0{};
T x1{};
T y1{};
T xy{};
T x2{};
T y2{};
void add(T x, T y)
{
++m0;
x1 += x;
y1 += y;
xy += x * y;
x2 += x * x;
y2 += y * y;
}
void merge(const CorrMoments & rhs)
{
m0 += rhs.m0;
x1 += rhs.x1;
y1 += rhs.y1;
xy += rhs.xy;
x2 += rhs.x2;
y2 += rhs.y2;
}
void write(WriteBuffer & buf) const
{
writePODBinary(*this, buf);
}
void read(ReadBuffer & buf)
{
readPODBinary(*this, buf);
}
T get() const
{
return (m0 * xy - x1 * y1) / sqrt((m0 * x2 - x1 * x1) * (m0 * y2 - y1 * y1));
}
};
enum class StatisticsFunctionKind
{
varPop, varSamp,
stddevPop, stddevSamp,
covarPop, covarSamp,
corr
};
template <typename T, typename Data, StatisticsFunctionKind Kind>
class AggregateFunctionVarianceSimple final
: public IAggregateFunctionDataHelper<Data, AggregateFunctionVarianceSimple<T, Data, Kind>>
{
public:
String getName() const override
{
switch (Kind)
{
case StatisticsFunctionKind::varPop: return "varPop";
case StatisticsFunctionKind::varSamp: return "varSamp";
case StatisticsFunctionKind::stddevPop: return "stddevPop";
case StatisticsFunctionKind::stddevSamp: return "stddevSamp";
case StatisticsFunctionKind::covarPop: return "covarPop";
case StatisticsFunctionKind::covarSamp: return "covarSamp";
case StatisticsFunctionKind::corr: return "corr";
}
}
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeNumber<T>>();
}
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
{
if constexpr (Kind == StatisticsFunctionKind::covarPop || Kind == StatisticsFunctionKind::covarSamp || Kind == StatisticsFunctionKind::corr)
this->data(place).add(
static_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num],
static_cast<const ColumnVector<T> &>(*columns[1]).getData()[row_num]);
else
this->data(place).add(
static_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num]);
}
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
{
this->data(place).merge(this->data(rhs));
}
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{
this->data(place).write(buf);
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override
{
this->data(place).read(buf);
}
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
const auto & data = this->data(place);
auto & dst = static_cast<ColumnVector<T> &>(to).getData();
if constexpr (Kind == StatisticsFunctionKind::varPop) dst.push_back(data.template get<VarianceMode::Population, VariancePower::Original>());
else if constexpr (Kind == StatisticsFunctionKind::varSamp) dst.push_back(data.template get<VarianceMode::Sample, VariancePower::Original>());
else if constexpr (Kind == StatisticsFunctionKind::stddevPop) dst.push_back(data.template get<VarianceMode::Population, VariancePower::Sqrt>());
else if constexpr (Kind == StatisticsFunctionKind::stddevSamp) dst.push_back(data.template get<VarianceMode::Sample, VariancePower::Sqrt>());
else if constexpr (Kind == StatisticsFunctionKind::covarPop) dst.push_back(data.template get<VarianceMode::Population>());
else if constexpr (Kind == StatisticsFunctionKind::covarSamp) dst.push_back(data.template get<VarianceMode::Sample>());
else if constexpr (Kind == StatisticsFunctionKind::corr) dst.push_back(data.get());
}
const char * getHeaderFilePath() const override { return __FILE__; }
};
template <typename T> using AggregateFunctionVarPopSimple = AggregateFunctionVarianceSimple<T, VarMoments<T>, StatisticsFunctionKind::varPop>;
template <typename T> using AggregateFunctionVarSampSimple = AggregateFunctionVarianceSimple<T, VarMoments<T>, StatisticsFunctionKind::varSamp>;
template <typename T> using AggregateFunctionStddevPopSimple = AggregateFunctionVarianceSimple<T, VarMoments<T>, StatisticsFunctionKind::stddevPop>;
template <typename T> using AggregateFunctionStddevSampSimple = AggregateFunctionVarianceSimple<T, VarMoments<T>, StatisticsFunctionKind::stddevSamp>;
template <typename T> using AggregateFunctionCovarPopSimple = AggregateFunctionVarianceSimple<T, CovarMoments<T>, StatisticsFunctionKind::covarPop>;
template <typename T> using AggregateFunctionCovarSampSimple = AggregateFunctionVarianceSimple<T, CovarMoments<T>, StatisticsFunctionKind::covarSamp>;
template <typename T> using AggregateFunctionCorrSimple = AggregateFunctionVarianceSimple<T, CorrMoments<T>, StatisticsFunctionKind::corr>;
}

View File

@ -0,0 +1,67 @@
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/Helpers.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <AggregateFunctions/AggregateFunctionStatisticsSimple.h>
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
namespace
{
template <template <typename> class FunctionTemplate>
AggregateFunctionPtr createAggregateFunctionStatisticsUnary(const std::string & name, const DataTypes & argument_types, const Array & parameters)
{
assertNoParameters(name, parameters);
assertUnary(name, argument_types);
if (typeid_cast<const DataTypeFloat32 *>(argument_types[0].get())) return std::make_shared<FunctionTemplate<Float32>>();
if (typeid_cast<const DataTypeFloat64 *>(argument_types[0].get())) return std::make_shared<FunctionTemplate<Float64>>();
throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
template <template <typename> class FunctionTemplate>
AggregateFunctionPtr createAggregateFunctionStatisticsBinary(const std::string & name, const DataTypes & argument_types, const Array & parameters)
{
assertNoParameters(name, parameters);
assertBinary(name, argument_types);
if (!argument_types[0]->equals(*argument_types[1]))
throw Exception("Illegal types " + argument_types[0]->getName() + " and " + argument_types[1]->getName()
+ " of arguments for aggregate function " + name + ", must be the same", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (typeid_cast<const DataTypeFloat32 *>(argument_types[0].get())) return std::make_shared<FunctionTemplate<Float32>>();
if (typeid_cast<const DataTypeFloat64 *>(argument_types[0].get())) return std::make_shared<FunctionTemplate<Float64>>();
throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
}
void registerAggregateFunctionsStatisticsSimple(AggregateFunctionFactory & factory)
{
factory.registerFunction("varSamp", createAggregateFunctionStatisticsUnary<AggregateFunctionVarSampSimple>);
factory.registerFunction("varPop", createAggregateFunctionStatisticsUnary<AggregateFunctionVarPopSimple>);
factory.registerFunction("stddevSamp", createAggregateFunctionStatisticsUnary<AggregateFunctionStddevSampSimple>);
factory.registerFunction("stddevPop", createAggregateFunctionStatisticsUnary<AggregateFunctionStddevPopSimple>);
factory.registerFunction("covarSamp", createAggregateFunctionStatisticsBinary<AggregateFunctionCovarSampSimple>);
factory.registerFunction("covarPop", createAggregateFunctionStatisticsBinary<AggregateFunctionCovarPopSimple>);
factory.registerFunction("corr", createAggregateFunctionStatisticsBinary<AggregateFunctionCorrSimple>, AggregateFunctionFactory::CaseInsensitive);
/// Synonims for compatibility.
factory.registerFunction("VAR_SAMP", createAggregateFunctionStatisticsUnary<AggregateFunctionVarSampSimple>, AggregateFunctionFactory::CaseInsensitive);
factory.registerFunction("VAR_POP", createAggregateFunctionStatisticsUnary<AggregateFunctionVarPopSimple>, AggregateFunctionFactory::CaseInsensitive);
factory.registerFunction("STDDEV_SAMP", createAggregateFunctionStatisticsUnary<AggregateFunctionStddevSampSimple>, AggregateFunctionFactory::CaseInsensitive);
factory.registerFunction("STDDEV_POP", createAggregateFunctionStatisticsUnary<AggregateFunctionStddevPopSimple>, AggregateFunctionFactory::CaseInsensitive);
factory.registerFunction("COVAR_SAMP", createAggregateFunctionStatisticsBinary<AggregateFunctionCovarSampSimple>, AggregateFunctionFactory::CaseInsensitive);
factory.registerFunction("COVAR_POP", createAggregateFunctionStatisticsBinary<AggregateFunctionCovarPopSimple>, AggregateFunctionFactory::CaseInsensitive);
}
}

View File

@ -13,7 +13,8 @@ void registerAggregateFunctionGroupArrayInsertAt(AggregateFunctionFactory & fact
void registerAggregateFunctionsQuantile(AggregateFunctionFactory & factory);
void registerAggregateFunctionsSequenceMatch(AggregateFunctionFactory & factory);
void registerAggregateFunctionsMinMaxAny(AggregateFunctionFactory & factory);
void registerAggregateFunctionsStatistics(AggregateFunctionFactory & factory);
void registerAggregateFunctionsStatisticsStable(AggregateFunctionFactory & factory);
void registerAggregateFunctionsStatisticsSimple(AggregateFunctionFactory & factory);
void registerAggregateFunctionSum(AggregateFunctionFactory & factory);
void registerAggregateFunctionSumMap(AggregateFunctionFactory & factory);
void registerAggregateFunctionsUniq(AggregateFunctionFactory & factory);
@ -34,7 +35,8 @@ void registerAggregateFunctions()
registerAggregateFunctionsQuantile(factory);
registerAggregateFunctionsSequenceMatch(factory);
registerAggregateFunctionsMinMaxAny(factory);
registerAggregateFunctionsStatistics(factory);
registerAggregateFunctionsStatisticsStable(factory);
registerAggregateFunctionsStatisticsSimple(factory);
registerAggregateFunctionSum(factory);
registerAggregateFunctionSumMap(factory);
registerAggregateFunctionsUniq(factory);