mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-26 09:32:01 +00:00
dbms: Server: Feature implementation. [#METR-16188]
This commit is contained in:
parent
bf6aecc826
commit
5f0a1cab74
@ -9,13 +9,9 @@
|
|||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
|
||||||
struct C;
|
|
||||||
struct C;
|
|
||||||
namespace DB
|
namespace DB
|
||||||
{
|
{
|
||||||
|
|
||||||
/// XXX Реализовать корреляцию (corr).
|
|
||||||
|
|
||||||
/** Статистические агрегатные функции:
|
/** Статистические агрегатные функции:
|
||||||
* varSamp - выборочная дисперсия
|
* varSamp - выборочная дисперсия
|
||||||
* stddevSamp - среднее выборочное квадратичное отклонение
|
* stddevSamp - среднее выборочное квадратичное отклонение
|
||||||
@ -23,6 +19,7 @@ namespace DB
|
|||||||
* stddevPop - среднее квадратичное отклонение
|
* stddevPop - среднее квадратичное отклонение
|
||||||
* covarSamp - выборочная ковариация
|
* covarSamp - выборочная ковариация
|
||||||
* covarPop - ковариация
|
* covarPop - ковариация
|
||||||
|
* corr - корреляция
|
||||||
*/
|
*/
|
||||||
|
|
||||||
/** Параллельный и инкрементальный алгоритм для вычисления дисперсии.
|
/** Параллельный и инкрементальный алгоритм для вычисления дисперсии.
|
||||||
@ -62,7 +59,7 @@ public:
|
|||||||
mean = (source.count * source.mean + count * mean) / total_count;
|
mean = (source.count * source.mean + count * mean) / total_count;
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
mean = source.mean + delta * (count / total_count);
|
mean = source.mean + delta * (static_cast<Float64>(count) / total_count);
|
||||||
|
|
||||||
m2 += source.m2 + delta * delta * factor;
|
m2 += source.m2 + delta * delta * factor;
|
||||||
count = total_count;
|
count = total_count;
|
||||||
@ -206,7 +203,7 @@ struct StdDevPopImpl
|
|||||||
* (J. Bennett et al., Sandia National Laboratories,
|
* (J. Bennett et al., Sandia National Laboratories,
|
||||||
* 2009 IEEE International Conference on Cluster Computing)
|
* 2009 IEEE International Conference on Cluster Computing)
|
||||||
*/
|
*/
|
||||||
template<typename T, typename U, typename Op>
|
template<typename T, typename U, typename Op, bool compute_marginal_moments>
|
||||||
class CovarianceData
|
class CovarianceData
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
@ -216,16 +213,25 @@ public:
|
|||||||
{
|
{
|
||||||
T left_received = static_cast<const ColumnVector<T> &>(column_left).getData()[row_num];
|
T left_received = static_cast<const ColumnVector<T> &>(column_left).getData()[row_num];
|
||||||
Float64 val_left = static_cast<Float64>(left_received);
|
Float64 val_left = static_cast<Float64>(left_received);
|
||||||
|
Float64 left_delta = val_left - left_mean;
|
||||||
|
|
||||||
U right_received = static_cast<const ColumnVector<U> &>(column_right).getData()[row_num];
|
U right_received = static_cast<const ColumnVector<U> &>(column_right).getData()[row_num];
|
||||||
Float64 val_right = static_cast<Float64>(right_received);
|
Float64 val_right = static_cast<Float64>(right_received);
|
||||||
|
Float64 right_delta = val_right - right_mean;
|
||||||
|
|
||||||
Float64 old_right_mean = right_mean;
|
Float64 old_right_mean = right_mean;
|
||||||
|
|
||||||
++count;
|
++count;
|
||||||
left_mean += (val_left - left_mean) / count;
|
|
||||||
right_mean += (val_right - right_mean) / count;
|
left_mean += left_delta / count;
|
||||||
|
right_mean += right_delta / count;
|
||||||
co_moment += (val_left - left_mean) * (val_right - old_right_mean);
|
co_moment += (val_left - left_mean) * (val_right - old_right_mean);
|
||||||
|
|
||||||
|
if (compute_marginal_moments)
|
||||||
|
{
|
||||||
|
left_m2 += left_delta * (val_left - left_mean);
|
||||||
|
right_m2 += right_delta * (val_right - right_mean);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void mergeWith(const CovarianceData & source)
|
void mergeWith(const CovarianceData & source)
|
||||||
@ -238,10 +244,16 @@ public:
|
|||||||
Float64 left_delta = left_mean - source.left_mean;
|
Float64 left_delta = left_mean - source.left_mean;
|
||||||
Float64 right_delta = right_mean - source.right_mean;
|
Float64 right_delta = right_mean - source.right_mean;
|
||||||
|
|
||||||
left_mean += left_delta * (source.count / total_count);
|
left_mean += left_delta * (static_cast<Float64>(source.count) / total_count);
|
||||||
right_mean += right_delta * (source.count / total_count);
|
right_mean += right_delta * (static_cast<Float64>(source.count) / total_count);
|
||||||
co_moment += source.co_moment + left_delta * right_delta * factor;
|
co_moment += source.co_moment + left_delta * right_delta * factor;
|
||||||
count = total_count;
|
count = total_count;
|
||||||
|
|
||||||
|
if (compute_marginal_moments)
|
||||||
|
{
|
||||||
|
left_m2 += source.left_m2 + left_delta * left_delta * factor;
|
||||||
|
right_m2 += source.right_m2 + right_delta * right_delta * factor;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void serialize(WriteBuffer & buf) const
|
void serialize(WriteBuffer & buf) const
|
||||||
@ -250,6 +262,12 @@ public:
|
|||||||
writeBinary(left_mean, buf);
|
writeBinary(left_mean, buf);
|
||||||
writeBinary(right_mean, buf);
|
writeBinary(right_mean, buf);
|
||||||
writeBinary(co_moment, buf);
|
writeBinary(co_moment, buf);
|
||||||
|
|
||||||
|
if (compute_marginal_moments)
|
||||||
|
{
|
||||||
|
writeBinary(left_m2, buf);
|
||||||
|
writeBinary(right_m2, buf);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void deserialize(ReadBuffer & buf)
|
void deserialize(ReadBuffer & buf)
|
||||||
@ -258,11 +276,17 @@ public:
|
|||||||
readBinary(left_mean, buf);
|
readBinary(left_mean, buf);
|
||||||
readBinary(right_mean, buf);
|
readBinary(right_mean, buf);
|
||||||
readBinary(co_moment, buf);
|
readBinary(co_moment, buf);
|
||||||
|
|
||||||
|
if (compute_marginal_moments)
|
||||||
|
{
|
||||||
|
readBinary(left_m2, buf);
|
||||||
|
readBinary(right_m2, buf);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void publish(IColumn & to) const
|
void publish(IColumn & to) const
|
||||||
{
|
{
|
||||||
static_cast<ColumnFloat64 &>(to).getData().push_back(Op::apply(co_moment, count));
|
static_cast<ColumnFloat64 &>(to).getData().push_back(Op::apply(co_moment, left_m2, right_m2, count));
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -270,10 +294,15 @@ private:
|
|||||||
Float64 left_mean = 0.0;
|
Float64 left_mean = 0.0;
|
||||||
Float64 right_mean = 0.0;
|
Float64 right_mean = 0.0;
|
||||||
Float64 co_moment = 0.0;
|
Float64 co_moment = 0.0;
|
||||||
|
Float64 left_m2 = 0.0;
|
||||||
|
Float64 right_m2 = 0.0;
|
||||||
};
|
};
|
||||||
|
|
||||||
template<typename T, typename U, typename Op>
|
template<typename T, typename U, typename Op, bool compute_marginal_moments = false>
|
||||||
class AggregateFunctionCovariance final : public IBinaryAggregateFunction<CovarianceData<T, U, Op>, AggregateFunctionCovariance<T, U, Op> >
|
class AggregateFunctionCovariance final
|
||||||
|
: public IBinaryAggregateFunction<
|
||||||
|
CovarianceData<T, U, Op, compute_marginal_moments>,
|
||||||
|
AggregateFunctionCovariance<T, U, Op, compute_marginal_moments> >
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
String getName() const override { return Op::name; }
|
String getName() const override { return Op::name; }
|
||||||
@ -311,7 +340,7 @@ public:
|
|||||||
|
|
||||||
void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override
|
void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override
|
||||||
{
|
{
|
||||||
CovarianceData<T, U, Op> source;
|
CovarianceData<T, U, Op, compute_marginal_moments> source;
|
||||||
source.deserialize(buf);
|
source.deserialize(buf);
|
||||||
|
|
||||||
this->data(place).mergeWith(source);
|
this->data(place).mergeWith(source);
|
||||||
@ -332,7 +361,7 @@ struct CovarSampImpl
|
|||||||
{
|
{
|
||||||
static constexpr auto name = "covarSamp";
|
static constexpr auto name = "covarSamp";
|
||||||
|
|
||||||
static inline Float64 apply(Float64 co_moment, UInt64 count)
|
static inline Float64 apply(Float64 co_moment, Float64 left_m2, Float64 right_m2, UInt64 count)
|
||||||
{
|
{
|
||||||
if (count < 2)
|
if (count < 2)
|
||||||
return 0.0;
|
return 0.0;
|
||||||
@ -347,7 +376,7 @@ struct CovarPopImpl
|
|||||||
{
|
{
|
||||||
static constexpr auto name = "covarPop";
|
static constexpr auto name = "covarPop";
|
||||||
|
|
||||||
static inline Float64 apply(Float64 co_moment, UInt64 count)
|
static inline Float64 apply(Float64 co_moment, Float64 left_m2, Float64 right_m2, UInt64 count)
|
||||||
{
|
{
|
||||||
if (count < 2)
|
if (count < 2)
|
||||||
return 0.0;
|
return 0.0;
|
||||||
@ -356,6 +385,21 @@ struct CovarPopImpl
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/** Реализация функции corr.
|
||||||
|
*/
|
||||||
|
struct CorrImpl
|
||||||
|
{
|
||||||
|
static constexpr auto name = "corr";
|
||||||
|
|
||||||
|
static inline Float64 apply(Float64 co_moment, Float64 left_m2, Float64 right_m2, UInt64 count)
|
||||||
|
{
|
||||||
|
if (count < 2)
|
||||||
|
return 0.0;
|
||||||
|
else
|
||||||
|
return co_moment / sqrt(left_m2 * right_m2);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
@ -376,4 +420,7 @@ using AggregateFunctionCovarSamp = AggregateFunctionCovariance<T, U, CovarSampIm
|
|||||||
template<typename T, typename U>
|
template<typename T, typename U>
|
||||||
using AggregateFunctionCovarPop = AggregateFunctionCovariance<T, U, CovarPopImpl>;
|
using AggregateFunctionCovarPop = AggregateFunctionCovariance<T, U, CovarPopImpl>;
|
||||||
|
|
||||||
|
/// corr aggregate function: AggregateFunctionCovariance instantiated with CorrImpl and compute_marginal_moments = true.
template<typename T, typename U>
|
||||||
|
using AggregateFunctionCorr = AggregateFunctionCovariance<T, U, CorrImpl, true>;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -617,6 +617,18 @@ AggregateFunctionPtr AggregateFunctionFactory::get(const String & name, const Da
|
|||||||
|
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
else if (name == "corr")
|
||||||
|
{
|
||||||
|
if (argument_types.size() != 2)
|
||||||
|
throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||||
|
|
||||||
|
AggregateFunctionPtr res = createWithTwoNumericTypes<AggregateFunctionCorr>(*argument_types[0], *argument_types[1]);
|
||||||
|
if (!res)
|
||||||
|
throw Exception("Illegal types " + argument_types[0]->getName() + " and " + argument_types[1]->getName()
|
||||||
|
+ " of arguments for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
else if (recursion_level == 0 && name.size() > strlen("State") && !(strcmp(name.data() + name.size() - strlen("State"), "State")))
|
else if (recursion_level == 0 && name.size() > strlen("State") && !(strcmp(name.data() + name.size() - strlen("State"), "State")))
|
||||||
{
|
{
|
||||||
/// Для агрегатных функций вида aggState, где agg - имя другой агрегатной функции.
|
/// Для агрегатных функций вида aggState, где agg - имя другой агрегатной функции.
|
||||||
@ -718,7 +730,8 @@ const AggregateFunctionFactory::FunctionNames & AggregateFunctionFactory::getFun
|
|||||||
"stddevSamp",
|
"stddevSamp",
|
||||||
"stddevPop",
|
"stddevPop",
|
||||||
"covarSamp",
|
"covarSamp",
|
||||||
"covarPop"
|
"covarPop",
|
||||||
|
"corr"
|
||||||
};
|
};
|
||||||
|
|
||||||
return names;
|
return names;
|
||||||
|
Loading…
Reference in New Issue
Block a user