mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-26 09:32:01 +00:00
This commit is contained in:
parent
c1e958a7d2
commit
8755f94548
@ -0,0 +1,38 @@
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/AggregateFunctionAnalysisOfVariance.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int BAD_ARGUMENTS;
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionAnalysisOfVariance(const std::string & name, const DataTypes & arguments, const Array & parameters, const Settings *)
|
||||
{
|
||||
assertNoParameters(name, parameters);
|
||||
assertBinary(name, arguments);
|
||||
|
||||
if (!isNumber(arguments[0]) || !isNumber(arguments[1]))
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Aggregate function {} only supports numerical types", name);
|
||||
|
||||
return std::make_shared<AggregateFunctionAnalysisOfVariance>(arguments, parameters);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void registerAggregateFunctionAnalysisOfVariance(AggregateFunctionFactory & factory)
|
||||
{
|
||||
AggregateFunctionProperties properties = { .is_order_dependent = false };
|
||||
factory.registerFunction("analysisOfVariance", {createAggregateFunctionAnalysisOfVariance, properties}, AggregateFunctionFactory::CaseInsensitive);
|
||||
|
||||
/// This is widely used term
|
||||
factory.registerAlias("anova", "analysisOfVariance", AggregateFunctionFactory::CaseInsensitive);
|
||||
}
|
||||
|
||||
}
|
98
src/AggregateFunctions/AggregateFunctionAnalysisOfVariance.h
Normal file
98
src/AggregateFunctions/AggregateFunctionAnalysisOfVariance.h
Normal file
@ -0,0 +1,98 @@
|
||||
#pragma once
|
||||
|
||||
#include <IO/VarInt.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
|
||||
#include <array>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <DataTypes/DataTypeTuple.h>
|
||||
#include <Columns/ColumnNullable.h>
|
||||
#include <Columns/ColumnsCommon.h>
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <AggregateFunctions/Moments.h>
|
||||
#include "Common/NaNUtils.h"
|
||||
#include <Common/assert_cast.h>
|
||||
#include <Core/Types.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int BAD_ARGUMENTS;
|
||||
}
|
||||
|
||||
class AggregateFunctionAnalysisOfVarianceData final : public AnalysisOfVarianceMoments<Float64>
|
||||
{
|
||||
};
|
||||
|
||||
|
||||
/// One way analysis of variance
|
||||
/// Provides a statistical test of whether two or more population means are equal (null hypothesis)
|
||||
/// Has an assumption that subjects from group i have normal distribution.
|
||||
/// Accepts two arguments - a value and a group number which this value belongs to.
|
||||
/// Groups are enumerated starting from 0 and there should be at least two groups to perform a test
|
||||
/// Moreover there should be at least one group with the number of observations greater than one.
|
||||
class AggregateFunctionAnalysisOfVariance final : public IAggregateFunctionDataHelper<AggregateFunctionAnalysisOfVarianceData, AggregateFunctionAnalysisOfVariance>
|
||||
{
|
||||
public:
|
||||
explicit AggregateFunctionAnalysisOfVariance(const DataTypes & arguments, const Array & params)
|
||||
: IAggregateFunctionDataHelper(arguments, params)
|
||||
{}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
DataTypes types {std::make_shared<DataTypeNumber<Float64>>(), std::make_shared<DataTypeNumber<Float64>>() };
|
||||
Strings names {"f_statistic", "p_value"};
|
||||
return std::make_shared<DataTypeTuple>(
|
||||
std::move(types),
|
||||
std::move(names)
|
||||
);
|
||||
}
|
||||
|
||||
String getName() const override { return "analysisOfVariance"; }
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
data(place).add(columns[0]->getFloat64(row_num), columns[1]->getUInt(row_num));
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
data(place).merge(data(rhs));
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
data(place).write(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
|
||||
{
|
||||
data(place).read(buf);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
auto f_stat = data(place).getFStatistic();
|
||||
if (std::isinf(f_stat) || isNaN(f_stat))
|
||||
throw Exception("F statistic is not defined or infinite for these arguments", ErrorCodes::BAD_ARGUMENTS);
|
||||
|
||||
auto p_value = data(place).getPValue(f_stat);
|
||||
|
||||
/// Because p-value is a probability.
|
||||
p_value = std::min(1.0, std::max(0.0, p_value));
|
||||
|
||||
auto & column_tuple = assert_cast<ColumnTuple &>(to);
|
||||
auto & column_stat = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(0));
|
||||
auto & column_value = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(1));
|
||||
|
||||
column_stat.getData().push_back(f_stat);
|
||||
column_value.getData().push_back(p_value);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
}
|
@ -4,7 +4,9 @@
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <boost/math/distributions/students_t.hpp>
|
||||
#include <boost/math/distributions/normal.hpp>
|
||||
#include <boost/math/distributions/fisher_f.hpp>
|
||||
#include <cfloat>
|
||||
#include <numeric>
|
||||
|
||||
|
||||
namespace DB
|
||||
@ -13,6 +15,7 @@ struct Settings;
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int BAD_ARGUMENTS;
|
||||
extern const int DECIMAL_OVERFLOW;
|
||||
}
|
||||
|
||||
@ -476,4 +479,127 @@ struct ZTestMoments
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct AnalysisOfVarianceMoments
|
||||
{
|
||||
/// Sums of values within a group
|
||||
std::vector<T> xs1{};
|
||||
/// Sums of squared values within a group
|
||||
std::vector<T> xs2{};
|
||||
/// Sizes of each group. Total number of observations is just a sum of all these values
|
||||
std::vector<size_t> ns{};
|
||||
|
||||
void resizeIfNeeded(size_t possible_size)
|
||||
{
|
||||
if (xs1.size() >= possible_size)
|
||||
return;
|
||||
|
||||
xs1.resize(possible_size, 0.0);
|
||||
xs2.resize(possible_size, 0.0);
|
||||
ns.resize(possible_size, 0);
|
||||
}
|
||||
|
||||
void add(T value, size_t group)
|
||||
{
|
||||
resizeIfNeeded(group + 1);
|
||||
xs1[group] += value;
|
||||
xs2[group] += value * value;
|
||||
ns[group] += 1;
|
||||
}
|
||||
|
||||
void merge(const AnalysisOfVarianceMoments & rhs)
|
||||
{
|
||||
resizeIfNeeded(rhs.xs1.size());
|
||||
for (size_t i = 0; i < rhs.xs1.size(); ++i)
|
||||
{
|
||||
xs1[i] += rhs.xs1[i];
|
||||
xs2[i] += rhs.xs2[i];
|
||||
ns[i] += rhs.ns[i];
|
||||
}
|
||||
}
|
||||
|
||||
void write(WriteBuffer & buf) const
|
||||
{
|
||||
writeVectorBinary(xs1, buf);
|
||||
writeVectorBinary(xs2, buf);
|
||||
writeVectorBinary(ns, buf);
|
||||
}
|
||||
|
||||
void read(ReadBuffer & buf)
|
||||
{
|
||||
readVectorBinary(xs1, buf);
|
||||
readVectorBinary(xs2, buf);
|
||||
readVectorBinary(ns, buf);
|
||||
}
|
||||
|
||||
Float64 getMeanAll() const
|
||||
{
|
||||
const auto n = std::accumulate(ns.begin(), ns.end(), 0UL);
|
||||
if (n == 0)
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "There are no observations to calculate mean value");
|
||||
|
||||
return std::accumulate(xs1.begin(), xs1.end(), 0.0) / n;
|
||||
}
|
||||
|
||||
Float64 getMeanGroup(size_t group) const
|
||||
{
|
||||
if (ns[group] == 0)
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "There is no observations for group {}", group);
|
||||
|
||||
return xs1[group] / ns[group];
|
||||
}
|
||||
|
||||
Float64 getBetweenGroupsVariation() const
|
||||
{
|
||||
Float64 res = 0;
|
||||
auto mean = getMeanAll();
|
||||
|
||||
for (size_t i = 0; i < xs1.size(); ++i)
|
||||
{
|
||||
auto group_mean = getMeanGroup(i);
|
||||
res += ns[i] * (group_mean - mean) * (group_mean - mean);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
Float64 getWithinGroupsVariation() const
|
||||
{
|
||||
Float64 res = 0;
|
||||
for (size_t i = 0; i < xs1.size(); ++i)
|
||||
{
|
||||
auto group_mean = getMeanGroup(i);
|
||||
res += xs2[i] + ns[i] * group_mean * group_mean - 2 * group_mean * xs1[i];
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
Float64 getFStatistic() const
|
||||
{
|
||||
const auto k = xs1.size();
|
||||
const auto n = std::accumulate(ns.begin(), ns.end(), 0UL);
|
||||
|
||||
if (k == 1)
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "There should be more than one group to calculate f-statistics");
|
||||
|
||||
if (k == n)
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "There is only one observation in each group");
|
||||
|
||||
return (getBetweenGroupsVariation() * (n - k)) / (getWithinGroupsVariation() * (k - 1));
|
||||
}
|
||||
|
||||
Float64 getPValue(Float64 f_statistic) const
|
||||
{
|
||||
const auto k = xs1.size();
|
||||
const auto n = std::accumulate(ns.begin(), ns.end(), 0UL);
|
||||
|
||||
if (k == 1)
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "There should be more than one group to calculate f-statistics");
|
||||
|
||||
if (k == n)
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "There is only one observation in each group");
|
||||
|
||||
return 1.0f - boost::math::cdf(boost::math::fisher_f(k - 1, n - k), f_statistic);
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
|
@ -72,6 +72,7 @@ void registerAggregateFunctionNothing(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionExponentialMovingAverage(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionSparkbar(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionIntervalLengthSum(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionAnalysisOfVariance(AggregateFunctionFactory &);
|
||||
|
||||
class AggregateFunctionCombinatorFactory;
|
||||
void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory &);
|
||||
@ -156,6 +157,7 @@ void registerAggregateFunctions()
|
||||
registerAggregateFunctionIntervalLengthSum(factory);
|
||||
registerAggregateFunctionExponentialMovingAverage(factory);
|
||||
registerAggregateFunctionSparkbar(factory);
|
||||
registerAggregateFunctionAnalysisOfVariance(factory);
|
||||
|
||||
registerWindowFunctions(factory);
|
||||
}
|
||||
|
86
tests/queries/0_stateless/02294_anova_cmp.python
Normal file
86
tests/queries/0_stateless/02294_anova_cmp.python
Normal file
@ -0,0 +1,86 @@
|
||||
#!/usr/bin/env python3
|
||||
import os
|
||||
import sys
|
||||
from statistics import variance
|
||||
from scipy import stats
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
CURDIR = os.path.dirname(os.path.realpath(__file__))
|
||||
sys.path.insert(0, os.path.join(CURDIR, 'helpers'))
|
||||
|
||||
from pure_http_client import ClickHouseClient
|
||||
|
||||
|
||||
# unpooled variance z-test for means of two samples
|
||||
def scipy_anova(rvs):
|
||||
return stats.f_oneway(*rvs)
|
||||
|
||||
|
||||
def test_and_check(rvs, n_groups, f_stat, p_value, precision=1e-2):
|
||||
client = ClickHouseClient()
|
||||
client.query("DROP TABLE IF EXISTS anova;")
|
||||
client.query("CREATE TABLE anova (left Float64, right UInt64) ENGINE = Memory;")
|
||||
for group in range(n_groups):
|
||||
client.query(f'''INSERT INTO anova VALUES {", ".join([f'({i},{group})' for i in rvs[group]])};''')
|
||||
|
||||
real = client.query_return_df(
|
||||
'''SELECT roundBankers(a.1, 16) as f_stat, roundBankers(a.2, 16) as p_value FROM (SELECT anova(left, right) as a FROM anova) FORMAT TabSeparatedWithNames;''')
|
||||
|
||||
real_f_stat = real['f_stat'][0]
|
||||
real_p_value = real['p_value'][0]
|
||||
assert(abs(real_f_stat - np.float64(f_stat)) < precision), f"clickhouse_f_stat {real_f_stat}, py_f_stat {f_stat}"
|
||||
assert(abs(real_p_value - np.float64(p_value)) < precision), f"clickhouse_p_value {real_p_value}, py_p_value {p_value}"
|
||||
client.query("DROP TABLE IF EXISTS anova;")
|
||||
|
||||
|
||||
def test_anova():
|
||||
n_groups = 3
|
||||
rvs = []
|
||||
loc = 0
|
||||
scale = 5
|
||||
size = 500
|
||||
for _ in range(n_groups):
|
||||
rvs.append(np.round(stats.norm.rvs(loc=loc, scale=scale, size=size), 2))
|
||||
loc += 5
|
||||
f_stat, p_value = scipy_anova(rvs)
|
||||
test_and_check(rvs, n_groups, f_stat, p_value)
|
||||
|
||||
n_groups = 6
|
||||
rvs = []
|
||||
loc = 0
|
||||
scale = 5
|
||||
size = 500
|
||||
for _ in range(n_groups):
|
||||
rvs.append(np.round(stats.norm.rvs(loc=loc, scale=scale, size=size), 2))
|
||||
f_stat, p_value = scipy_anova(rvs)
|
||||
test_and_check(rvs, n_groups, f_stat, p_value)
|
||||
|
||||
n_groups = 10
|
||||
rvs = []
|
||||
loc = 1
|
||||
scale = 2
|
||||
size = 100
|
||||
for _ in range(n_groups):
|
||||
rvs.append(np.round(stats.norm.rvs(loc=loc, scale=scale, size=size), 2))
|
||||
loc += 1
|
||||
scale += 2
|
||||
size += 100
|
||||
f_stat, p_value = scipy_anova(rvs)
|
||||
test_and_check(rvs, n_groups, f_stat, p_value)
|
||||
|
||||
n_groups = 20
|
||||
rvs = []
|
||||
loc = 0
|
||||
scale = 10
|
||||
size = 1100
|
||||
for _ in range(n_groups):
|
||||
rvs.append(np.round(stats.norm.rvs(loc=loc, scale=scale, size=size), 2))
|
||||
size -= 50
|
||||
f_stat, p_value = scipy_anova(rvs)
|
||||
test_and_check(rvs, n_groups, f_stat, p_value)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_anova()
|
||||
print("Ok.")
|
1
tests/queries/0_stateless/02294_anova_cmp.reference
Normal file
1
tests/queries/0_stateless/02294_anova_cmp.reference
Normal file
@ -0,0 +1 @@
|
||||
Ok.
|
9
tests/queries/0_stateless/02294_anova_cmp.sh
Executable file
9
tests/queries/0_stateless/02294_anova_cmp.sh
Executable file
@ -0,0 +1,9 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
|
||||
# shellcheck source=../shell_config.sh
|
||||
. "$CURDIR"/../shell_config.sh
|
||||
|
||||
# We should have correct env vars from shell_config.sh to run this test
|
||||
|
||||
python3 "$CURDIR"/02294_anova_cmp.python
|
Loading…
Reference in New Issue
Block a user