Added simple, non-numerically stable variants of statistical functions; use it by default; old functions are accessible under -Stable suffix, like varPopStable [#CLICKHOUSE-2].

This commit is contained in:
Alexey Milovidov 2017-12-23 01:59:45 +03:00
parent a7b8541cea
commit 69a7761812
3 changed files with 34 additions and 28 deletions

View File

@ -168,10 +168,16 @@ enum class StatisticsFunctionKind
};
template <typename T, typename Data, StatisticsFunctionKind Kind>
template <typename T1, typename T2>
using VarianceCalcType = std::conditional_t<std::is_same_v<T1, Float32> && std::is_same_v<T2, Float32>, Float32, Float64>;
template <typename T1, typename T2, typename Data, StatisticsFunctionKind Kind>
class AggregateFunctionVarianceSimple final
: public IAggregateFunctionDataHelper<Data, AggregateFunctionVarianceSimple<T, Data, Kind>>
: public IAggregateFunctionDataHelper<Data, AggregateFunctionVarianceSimple<T1, T2, Data, Kind>>
{
using ResultType = VarianceCalcType<T1, T2>;
public:
String getName() const override
{
@ -189,18 +195,18 @@ public:
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeNumber<T>>();
return std::make_shared<DataTypeNumber<ResultType>>();
}
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
{
if constexpr (Kind == StatisticsFunctionKind::covarPop || Kind == StatisticsFunctionKind::covarSamp || Kind == StatisticsFunctionKind::corr)
this->data(place).add(
static_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num],
static_cast<const ColumnVector<T> &>(*columns[1]).getData()[row_num]);
static_cast<const ColumnVector<T1> &>(*columns[0]).getData()[row_num],
static_cast<const ColumnVector<T2> &>(*columns[1]).getData()[row_num]);
else
this->data(place).add(
static_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num]);
static_cast<const ColumnVector<T1> &>(*columns[0]).getData()[row_num]);
}
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
@ -221,7 +227,7 @@ public:
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
const auto & data = this->data(place);
auto & dst = static_cast<ColumnVector<T> &>(to).getData();
auto & dst = static_cast<ColumnVector<ResultType> &>(to).getData();
if constexpr (Kind == StatisticsFunctionKind::varPop) dst.push_back(data.template get<VarianceMode::Population, VariancePower::Original>());
else if constexpr (Kind == StatisticsFunctionKind::varSamp) dst.push_back(data.template get<VarianceMode::Sample, VariancePower::Original>());
@ -236,12 +242,12 @@ public:
};
template <typename T> using AggregateFunctionVarPopSimple = AggregateFunctionVarianceSimple<T, VarMoments<T>, StatisticsFunctionKind::varPop>;
template <typename T> using AggregateFunctionVarSampSimple = AggregateFunctionVarianceSimple<T, VarMoments<T>, StatisticsFunctionKind::varSamp>;
template <typename T> using AggregateFunctionStddevPopSimple = AggregateFunctionVarianceSimple<T, VarMoments<T>, StatisticsFunctionKind::stddevPop>;
template <typename T> using AggregateFunctionStddevSampSimple = AggregateFunctionVarianceSimple<T, VarMoments<T>, StatisticsFunctionKind::stddevSamp>;
template <typename T> using AggregateFunctionCovarPopSimple = AggregateFunctionVarianceSimple<T, CovarMoments<T>, StatisticsFunctionKind::covarPop>;
template <typename T> using AggregateFunctionCovarSampSimple = AggregateFunctionVarianceSimple<T, CovarMoments<T>, StatisticsFunctionKind::covarSamp>;
template <typename T> using AggregateFunctionCorrSimple = AggregateFunctionVarianceSimple<T, CorrMoments<T>, StatisticsFunctionKind::corr>;
template <typename T> using AggregateFunctionVarPopSimple = AggregateFunctionVarianceSimple<T, T, VarMoments<VarianceCalcType<T, T>>, StatisticsFunctionKind::varPop>;
template <typename T> using AggregateFunctionVarSampSimple = AggregateFunctionVarianceSimple<T, T, VarMoments<VarianceCalcType<T, T>>, StatisticsFunctionKind::varSamp>;
template <typename T> using AggregateFunctionStddevPopSimple = AggregateFunctionVarianceSimple<T, T, VarMoments<VarianceCalcType<T, T>>, StatisticsFunctionKind::stddevPop>;
template <typename T> using AggregateFunctionStddevSampSimple = AggregateFunctionVarianceSimple<T, T, VarMoments<VarianceCalcType<T, T>>, StatisticsFunctionKind::stddevSamp>;
template <typename T1, typename T2> using AggregateFunctionCovarPopSimple = AggregateFunctionVarianceSimple<T1, T2, CovarMoments<VarianceCalcType<T1, T2>>, StatisticsFunctionKind::covarPop>;
template <typename T1, typename T2> using AggregateFunctionCovarSampSimple = AggregateFunctionVarianceSimple<T1, T2, CovarMoments<VarianceCalcType<T1, T2>>, StatisticsFunctionKind::covarSamp>;
template <typename T1, typename T2> using AggregateFunctionCorrSimple = AggregateFunctionVarianceSimple<T1, T2, CorrMoments<VarianceCalcType<T1, T2>>, StatisticsFunctionKind::corr>;
}

View File

@ -21,26 +21,26 @@ AggregateFunctionPtr createAggregateFunctionStatisticsUnary(const std::string &
assertNoParameters(name, parameters);
assertUnary(name, argument_types);
if (typeid_cast<const DataTypeFloat32 *>(argument_types[0].get())) return std::make_shared<FunctionTemplate<Float32>>();
if (typeid_cast<const DataTypeFloat64 *>(argument_types[0].get())) return std::make_shared<FunctionTemplate<Float64>>();
AggregateFunctionPtr res(createWithNumericType<FunctionTemplate>(*argument_types[0]));
throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (!res)
throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return res;
}
template <template <typename> class FunctionTemplate>
template <template <typename, typename> class FunctionTemplate>
AggregateFunctionPtr createAggregateFunctionStatisticsBinary(const std::string & name, const DataTypes & argument_types, const Array & parameters)
{
assertNoParameters(name, parameters);
assertBinary(name, argument_types);
if (!argument_types[0]->equals(*argument_types[1]))
AggregateFunctionPtr res(createWithTwoNumericTypes<FunctionTemplate>(*argument_types[0], *argument_types[1]));
if (!res)
throw Exception("Illegal types " + argument_types[0]->getName() + " and " + argument_types[1]->getName()
+ " of arguments for aggregate function " + name + ", must be the same", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
+ " of arguments for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (typeid_cast<const DataTypeFloat32 *>(argument_types[0].get())) return std::make_shared<FunctionTemplate<Float32>>();
if (typeid_cast<const DataTypeFloat64 *>(argument_types[0].get())) return std::make_shared<FunctionTemplate<Float64>>();
throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return res;
}
}

View File

@ -1,14 +1,14 @@
inf
nan
0
inf
nan
0
0
0
0
0
inf
nan
0
0
0
inf
nan
0