mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-30 03:22:14 +00:00
This commit is contained in:
parent
c1e958a7d2
commit
8755f94548
@ -0,0 +1,38 @@
|
|||||||
|
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||||
|
#include <AggregateFunctions/AggregateFunctionAnalysisOfVariance.h>
|
||||||
|
#include <AggregateFunctions/FactoryHelpers.h>
|
||||||
|
|
||||||
|
namespace DB
|
||||||
|
{
|
||||||
|
|
||||||
|
namespace ErrorCodes
|
||||||
|
{
|
||||||
|
extern const int BAD_ARGUMENTS;
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace
|
||||||
|
{
|
||||||
|
|
||||||
|
/// Factory callback for analysisOfVariance(value, group).
/// Runs the parameter and arity assertions first, then requires both argument types to be numbers
/// and throws BAD_ARGUMENTS otherwise.
AggregateFunctionPtr createAggregateFunctionAnalysisOfVariance(const std::string & name, const DataTypes & arguments, const Array & parameters, const Settings *)
{
    assertNoParameters(name, parameters);
    assertBinary(name, arguments);

    /// Same short-circuit order as checking each argument in turn: the second type is only inspected if the first is a number.
    const bool both_numeric = isNumber(arguments[0]) && isNumber(arguments[1]);
    if (!both_numeric)
        throw Exception(ErrorCodes::BAD_ARGUMENTS, "Aggregate function {} only supports numerical types", name);

    return std::make_shared<AggregateFunctionAnalysisOfVariance>(arguments, parameters);
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Registers the aggregate function under its canonical name and under the short alias "anova";
/// both registrations are case insensitive.
void registerAggregateFunctionAnalysisOfVariance(AggregateFunctionFactory & factory)
{
    /// The function is registered as not order dependent.
    const AggregateFunctionProperties function_properties = { .is_order_dependent = false };

    factory.registerFunction(
        "analysisOfVariance",
        {createAggregateFunctionAnalysisOfVariance, function_properties},
        AggregateFunctionFactory::CaseInsensitive);

    /// "ANOVA" is the widely used name for this test.
    factory.registerAlias("anova", "analysisOfVariance", AggregateFunctionFactory::CaseInsensitive);
}
|
||||||
|
|
||||||
|
}
|
98
src/AggregateFunctions/AggregateFunctionAnalysisOfVariance.h
Normal file
98
src/AggregateFunctions/AggregateFunctionAnalysisOfVariance.h
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <IO/VarInt.h>
|
||||||
|
#include <IO/WriteHelpers.h>
|
||||||
|
|
||||||
|
#include <array>
|
||||||
|
#include <DataTypes/DataTypesNumber.h>
|
||||||
|
#include <DataTypes/DataTypeTuple.h>
|
||||||
|
#include <Columns/ColumnNullable.h>
|
||||||
|
#include <Columns/ColumnsCommon.h>
|
||||||
|
#include <AggregateFunctions/IAggregateFunction.h>
|
||||||
|
#include <AggregateFunctions/Moments.h>
|
||||||
|
#include "Common/NaNUtils.h"
|
||||||
|
#include <Common/assert_cast.h>
|
||||||
|
#include <Core/Types.h>
|
||||||
|
|
||||||
|
namespace DB
|
||||||
|
{
|
||||||
|
|
||||||
|
namespace ErrorCodes
|
||||||
|
{
|
||||||
|
extern const int BAD_ARGUMENTS;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Aggregation state of analysisOfVariance: the Float64 instantiation of AnalysisOfVarianceMoments,
/// with no members of its own.
class AggregateFunctionAnalysisOfVarianceData final : public AnalysisOfVarianceMoments<Float64> {};
|
||||||
|
|
||||||
|
|
||||||
|
/// One way analysis of variance
|
||||||
|
/// Provides a statistical test of whether two or more population means are equal (null hypothesis)
|
||||||
|
/// Assumes that the observations within each group are normally distributed.
|
||||||
|
/// Accepts two arguments - a value and a group number which this value belongs to.
|
||||||
|
/// Groups are enumerated starting from 0 and there should be at least two groups to perform a test
|
||||||
|
/// Moreover there should be at least one group with the number of observations greater than one.
|
||||||
|
class AggregateFunctionAnalysisOfVariance final : public IAggregateFunctionDataHelper<AggregateFunctionAnalysisOfVarianceData, AggregateFunctionAnalysisOfVariance>
{
public:
    explicit AggregateFunctionAnalysisOfVariance(const DataTypes & arguments, const Array & params)
        : IAggregateFunctionDataHelper(arguments, params)
    {}

    /// The result is a named tuple: (f_statistic Float64, p_value Float64).
    DataTypePtr getReturnType() const override
    {
        DataTypes types {std::make_shared<DataTypeNumber<Float64>>(), std::make_shared<DataTypeNumber<Float64>>() };
        Strings names {"f_statistic", "p_value"};
        return std::make_shared<DataTypeTuple>(
            std::move(types),
            std::move(names)
        );
    }

    String getName() const override { return "analysisOfVariance"; }

    bool allocatesMemoryInArena() const override { return false; }

    /// columns[0] holds the observed value (read as Float64), columns[1] holds the group number (read as an unsigned integer).
    void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
    {
        data(place).add(columns[0]->getFloat64(row_num), columns[1]->getUInt(row_num));
    }

    /// Folds the state at `rhs` into the state at `place`.
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
    {
        data(place).merge(data(rhs));
    }

    void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
    {
        data(place).write(buf);
    }

    void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
    {
        data(place).read(buf);
    }

    /// Appends (f_statistic, p_value) to the result tuple column.
    /// Throws BAD_ARGUMENTS when the F statistic is infinite or NaN.
    void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
    {
        auto f_stat = data(place).getFStatistic();
        if (std::isinf(f_stat) || isNaN(f_stat))
            /// Code-first constructor, the same form as every other throw related to this function.
            throw Exception(ErrorCodes::BAD_ARGUMENTS, "F statistic is not defined or infinite for these arguments");

        auto p_value = data(place).getPValue(f_stat);

        /// Because p-value is a probability.
        p_value = std::min(1.0, std::max(0.0, p_value));

        auto & column_tuple = assert_cast<ColumnTuple &>(to);
        auto & column_stat = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(0));
        auto & column_value = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(1));

        column_stat.getData().push_back(f_stat);
        column_value.getData().push_back(p_value);
    }

};
|
||||||
|
|
||||||
|
}
|
@ -4,7 +4,9 @@
|
|||||||
#include <IO/ReadHelpers.h>
|
#include <IO/ReadHelpers.h>
|
||||||
#include <boost/math/distributions/students_t.hpp>
|
#include <boost/math/distributions/students_t.hpp>
|
||||||
#include <boost/math/distributions/normal.hpp>
|
#include <boost/math/distributions/normal.hpp>
|
||||||
|
#include <boost/math/distributions/fisher_f.hpp>
|
||||||
#include <cfloat>
|
#include <cfloat>
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
|
|
||||||
namespace DB
|
namespace DB
|
||||||
@ -13,6 +15,7 @@ struct Settings;
|
|||||||
|
|
||||||
namespace ErrorCodes
|
namespace ErrorCodes
|
||||||
{
|
{
|
||||||
|
extern const int BAD_ARGUMENTS;
|
||||||
extern const int DECIMAL_OVERFLOW;
|
extern const int DECIMAL_OVERFLOW;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -476,4 +479,127 @@ struct ZTestMoments
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct AnalysisOfVarianceMoments
|
||||||
|
{
|
||||||
|
/// Sums of values within a group
|
||||||
|
std::vector<T> xs1{};
|
||||||
|
/// Sums of squared values within a group
|
||||||
|
std::vector<T> xs2{};
|
||||||
|
/// Sizes of each group. Total number of observations is just a sum of all these values
|
||||||
|
std::vector<size_t> ns{};
|
||||||
|
|
||||||
|
void resizeIfNeeded(size_t possible_size)
|
||||||
|
{
|
||||||
|
if (xs1.size() >= possible_size)
|
||||||
|
return;
|
||||||
|
|
||||||
|
xs1.resize(possible_size, 0.0);
|
||||||
|
xs2.resize(possible_size, 0.0);
|
||||||
|
ns.resize(possible_size, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
void add(T value, size_t group)
|
||||||
|
{
|
||||||
|
resizeIfNeeded(group + 1);
|
||||||
|
xs1[group] += value;
|
||||||
|
xs2[group] += value * value;
|
||||||
|
ns[group] += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
void merge(const AnalysisOfVarianceMoments & rhs)
|
||||||
|
{
|
||||||
|
resizeIfNeeded(rhs.xs1.size());
|
||||||
|
for (size_t i = 0; i < rhs.xs1.size(); ++i)
|
||||||
|
{
|
||||||
|
xs1[i] += rhs.xs1[i];
|
||||||
|
xs2[i] += rhs.xs2[i];
|
||||||
|
ns[i] += rhs.ns[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void write(WriteBuffer & buf) const
|
||||||
|
{
|
||||||
|
writeVectorBinary(xs1, buf);
|
||||||
|
writeVectorBinary(xs2, buf);
|
||||||
|
writeVectorBinary(ns, buf);
|
||||||
|
}
|
||||||
|
|
||||||
|
void read(ReadBuffer & buf)
|
||||||
|
{
|
||||||
|
readVectorBinary(xs1, buf);
|
||||||
|
readVectorBinary(xs2, buf);
|
||||||
|
readVectorBinary(ns, buf);
|
||||||
|
}
|
||||||
|
|
||||||
|
Float64 getMeanAll() const
|
||||||
|
{
|
||||||
|
const auto n = std::accumulate(ns.begin(), ns.end(), 0UL);
|
||||||
|
if (n == 0)
|
||||||
|
throw Exception(ErrorCodes::BAD_ARGUMENTS, "There are no observations to calculate mean value");
|
||||||
|
|
||||||
|
return std::accumulate(xs1.begin(), xs1.end(), 0.0) / n;
|
||||||
|
}
|
||||||
|
|
||||||
|
Float64 getMeanGroup(size_t group) const
|
||||||
|
{
|
||||||
|
if (ns[group] == 0)
|
||||||
|
throw Exception(ErrorCodes::BAD_ARGUMENTS, "There is no observations for group {}", group);
|
||||||
|
|
||||||
|
return xs1[group] / ns[group];
|
||||||
|
}
|
||||||
|
|
||||||
|
Float64 getBetweenGroupsVariation() const
|
||||||
|
{
|
||||||
|
Float64 res = 0;
|
||||||
|
auto mean = getMeanAll();
|
||||||
|
|
||||||
|
for (size_t i = 0; i < xs1.size(); ++i)
|
||||||
|
{
|
||||||
|
auto group_mean = getMeanGroup(i);
|
||||||
|
res += ns[i] * (group_mean - mean) * (group_mean - mean);
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
Float64 getWithinGroupsVariation() const
|
||||||
|
{
|
||||||
|
Float64 res = 0;
|
||||||
|
for (size_t i = 0; i < xs1.size(); ++i)
|
||||||
|
{
|
||||||
|
auto group_mean = getMeanGroup(i);
|
||||||
|
res += xs2[i] + ns[i] * group_mean * group_mean - 2 * group_mean * xs1[i];
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
Float64 getFStatistic() const
|
||||||
|
{
|
||||||
|
const auto k = xs1.size();
|
||||||
|
const auto n = std::accumulate(ns.begin(), ns.end(), 0UL);
|
||||||
|
|
||||||
|
if (k == 1)
|
||||||
|
throw Exception(ErrorCodes::BAD_ARGUMENTS, "There should be more than one group to calculate f-statistics");
|
||||||
|
|
||||||
|
if (k == n)
|
||||||
|
throw Exception(ErrorCodes::BAD_ARGUMENTS, "There is only one observation in each group");
|
||||||
|
|
||||||
|
return (getBetweenGroupsVariation() * (n - k)) / (getWithinGroupsVariation() * (k - 1));
|
||||||
|
}
|
||||||
|
|
||||||
|
Float64 getPValue(Float64 f_statistic) const
|
||||||
|
{
|
||||||
|
const auto k = xs1.size();
|
||||||
|
const auto n = std::accumulate(ns.begin(), ns.end(), 0UL);
|
||||||
|
|
||||||
|
if (k == 1)
|
||||||
|
throw Exception(ErrorCodes::BAD_ARGUMENTS, "There should be more than one group to calculate f-statistics");
|
||||||
|
|
||||||
|
if (k == n)
|
||||||
|
throw Exception(ErrorCodes::BAD_ARGUMENTS, "There is only one observation in each group");
|
||||||
|
|
||||||
|
return 1.0f - boost::math::cdf(boost::math::fisher_f(k - 1, n - k), f_statistic);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -72,6 +72,7 @@ void registerAggregateFunctionNothing(AggregateFunctionFactory &);
|
|||||||
void registerAggregateFunctionExponentialMovingAverage(AggregateFunctionFactory &);
|
void registerAggregateFunctionExponentialMovingAverage(AggregateFunctionFactory &);
|
||||||
void registerAggregateFunctionSparkbar(AggregateFunctionFactory &);
|
void registerAggregateFunctionSparkbar(AggregateFunctionFactory &);
|
||||||
void registerAggregateFunctionIntervalLengthSum(AggregateFunctionFactory &);
|
void registerAggregateFunctionIntervalLengthSum(AggregateFunctionFactory &);
|
||||||
|
void registerAggregateFunctionAnalysisOfVariance(AggregateFunctionFactory &);
|
||||||
|
|
||||||
class AggregateFunctionCombinatorFactory;
|
class AggregateFunctionCombinatorFactory;
|
||||||
void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory &);
|
void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory &);
|
||||||
@ -156,6 +157,7 @@ void registerAggregateFunctions()
|
|||||||
registerAggregateFunctionIntervalLengthSum(factory);
|
registerAggregateFunctionIntervalLengthSum(factory);
|
||||||
registerAggregateFunctionExponentialMovingAverage(factory);
|
registerAggregateFunctionExponentialMovingAverage(factory);
|
||||||
registerAggregateFunctionSparkbar(factory);
|
registerAggregateFunctionSparkbar(factory);
|
||||||
|
registerAggregateFunctionAnalysisOfVariance(factory);
|
||||||
|
|
||||||
registerWindowFunctions(factory);
|
registerWindowFunctions(factory);
|
||||||
}
|
}
|
||||||
|
86
tests/queries/0_stateless/02294_anova_cmp.python
Normal file
86
tests/queries/0_stateless/02294_anova_cmp.python
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from statistics import variance
|
||||||
|
from scipy import stats
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
CURDIR = os.path.dirname(os.path.realpath(__file__))
|
||||||
|
sys.path.insert(0, os.path.join(CURDIR, 'helpers'))
|
||||||
|
|
||||||
|
from pure_http_client import ClickHouseClient
|
||||||
|
|
||||||
|
|
||||||
|
# Reference implementation: one-way ANOVA (F-test) over all groups, computed by SciPy
|
||||||
|
def scipy_anova(rvs):
    """Run scipy.stats.f_oneway with every element of `rvs` passed as one sample group."""
    samples = tuple(rvs)
    return stats.f_oneway(*samples)
|
||||||
|
|
||||||
|
|
||||||
|
def test_and_check(rvs, n_groups, f_stat, p_value, precision=1e-2):
    """Load the samples into a temporary `anova` table and compare ClickHouse's result with the expected one.

    rvs[group] holds the observations of that group; the first n_groups groups are inserted.
    f_stat and p_value are the expected values; each must match within `precision`.
    """
    client = ClickHouseClient()
    client.query("DROP TABLE IF EXISTS anova;")
    client.query("CREATE TABLE anova (left Float64, right UInt64) ENGINE = Memory;")

    # One INSERT per group: every observation becomes a (value, group) row.
    for group in range(n_groups):
        rows = ", ".join(f"({value},{group})" for value in rvs[group])
        client.query(f"INSERT INTO anova VALUES {rows};")

    real = client.query_return_df(
        "SELECT roundBankers(a.1, 16) as f_stat, roundBankers(a.2, 16) as p_value "
        "FROM (SELECT anova(left, right) as a FROM anova) FORMAT TabSeparatedWithNames;"
    )

    real_f_stat = real["f_stat"][0]
    real_p_value = real["p_value"][0]

    f_stat_error = abs(real_f_stat - np.float64(f_stat))
    assert f_stat_error < precision, f"clickhouse_f_stat {real_f_stat}, py_f_stat {f_stat}"

    p_value_error = abs(real_p_value - np.float64(p_value))
    assert p_value_error < precision, f"clickhouse_p_value {real_p_value}, py_p_value {p_value}"

    client.query("DROP TABLE IF EXISTS anova;")
|
||||||
|
|
||||||
|
|
||||||
|
def test_anova():
    """Check ClickHouse against SciPy on four sets of randomly drawn normal samples."""

    def run_case(n_groups, loc, scale, size, loc_step=0, scale_step=0, size_step=0):
        # Draw one sample per group (rounded to 2 decimals), shifting the
        # distribution parameters by the given steps after each group.
        rvs = []
        for _ in range(n_groups):
            rvs.append(np.round(stats.norm.rvs(loc=loc, scale=scale, size=size), 2))
            loc += loc_step
            scale += scale_step
            size += size_step
        f_stat, p_value = scipy_anova(rvs)
        test_and_check(rvs, n_groups, f_stat, p_value)

    # Means differ between groups.
    run_case(3, loc=0, scale=5, size=500, loc_step=5)
    # All groups share the same distribution.
    run_case(6, loc=0, scale=5, size=500)
    # Mean, spread and sample size all grow from group to group.
    run_case(10, loc=1, scale=2, size=100, loc_step=1, scale_step=2, size_step=100)
    # Same distribution, shrinking sample sizes.
    run_case(20, loc=0, scale=10, size=1100, size_step=-50)
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: run the comparison and print the line expected by the .reference file.
if __name__ == '__main__':
    test_anova()
    print('Ok.')
|
1
tests/queries/0_stateless/02294_anova_cmp.reference
Normal file
1
tests/queries/0_stateless/02294_anova_cmp.reference
Normal file
@ -0,0 +1 @@
|
|||||||
|
Ok.
|
9
tests/queries/0_stateless/02294_anova_cmp.sh
Executable file
9
tests/queries/0_stateless/02294_anova_cmp.sh
Executable file
@ -0,0 +1,9 @@
|
|||||||
|
#!/usr/bin/env bash

# Absolute path of the directory that contains this script.
CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
# shellcheck source=../shell_config.sh
source "${CURDIR}/../shell_config.sh"

# The Python checker needs the environment variables set up by shell_config.sh.
python3 "${CURDIR}/02294_anova_cmp.python"
|
Loading…
Reference in New Issue
Block a user