Added aggregate function analysisOfVariance (anova). Merging #37872 (#42131)

This commit is contained in:
Nikita Mikhaylov 2022-10-18 14:57:56 +02:00 committed by GitHub
parent c1e958a7d2
commit 8755f94548
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 360 additions and 0 deletions

View File

@ -0,0 +1,38 @@
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/AggregateFunctionAnalysisOfVariance.h>
#include <AggregateFunctions/FactoryHelpers.h>
namespace DB
{
namespace ErrorCodes
{
extern const int BAD_ARGUMENTS;
}
namespace
{
AggregateFunctionPtr createAggregateFunctionAnalysisOfVariance(const std::string & name, const DataTypes & arguments, const Array & parameters, const Settings *)
{
assertNoParameters(name, parameters);
assertBinary(name, arguments);
if (!isNumber(arguments[0]) || !isNumber(arguments[1]))
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Aggregate function {} only supports numerical types", name);
return std::make_shared<AggregateFunctionAnalysisOfVariance>(arguments, parameters);
}
}
void registerAggregateFunctionAnalysisOfVariance(AggregateFunctionFactory & factory)
{
AggregateFunctionProperties properties = { .is_order_dependent = false };
factory.registerFunction("analysisOfVariance", {createAggregateFunctionAnalysisOfVariance, properties}, AggregateFunctionFactory::CaseInsensitive);
/// This is widely used term
factory.registerAlias("anova", "analysisOfVariance", AggregateFunctionFactory::CaseInsensitive);
}
}

View File

@ -0,0 +1,98 @@
#pragma once
#include <IO/VarInt.h>
#include <IO/WriteHelpers.h>
#include <array>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypeTuple.h>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnsCommon.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <AggregateFunctions/Moments.h>
#include "Common/NaNUtils.h"
#include <Common/assert_cast.h>
#include <Core/Types.h>
namespace DB
{
namespace ErrorCodes
{
extern const int BAD_ARGUMENTS;
}
class AggregateFunctionAnalysisOfVarianceData final : public AnalysisOfVarianceMoments<Float64>
{
};
/// One way analysis of variance
/// Provides a statistical test of whether two or more population means are equal (null hypothesis)
/// Has an assumption that subjects from group i have normal distribution.
/// Accepts two arguments - a value and a group number which this value belongs to.
/// Groups are enumerated starting from 0 and there should be at least two groups to perform a test
/// Moreover there should be at least one group with the number of observations greater than one.
class AggregateFunctionAnalysisOfVariance final : public IAggregateFunctionDataHelper<AggregateFunctionAnalysisOfVarianceData, AggregateFunctionAnalysisOfVariance>
{
public:
explicit AggregateFunctionAnalysisOfVariance(const DataTypes & arguments, const Array & params)
: IAggregateFunctionDataHelper(arguments, params)
{}
DataTypePtr getReturnType() const override
{
DataTypes types {std::make_shared<DataTypeNumber<Float64>>(), std::make_shared<DataTypeNumber<Float64>>() };
Strings names {"f_statistic", "p_value"};
return std::make_shared<DataTypeTuple>(
std::move(types),
std::move(names)
);
}
String getName() const override { return "analysisOfVariance"; }
bool allocatesMemoryInArena() const override { return false; }
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
{
data(place).add(columns[0]->getFloat64(row_num), columns[1]->getUInt(row_num));
}
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
{
data(place).merge(data(rhs));
}
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
{
data(place).write(buf);
}
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
{
data(place).read(buf);
}
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
{
auto f_stat = data(place).getFStatistic();
if (std::isinf(f_stat) || isNaN(f_stat))
throw Exception("F statistic is not defined or infinite for these arguments", ErrorCodes::BAD_ARGUMENTS);
auto p_value = data(place).getPValue(f_stat);
/// Because p-value is a probability.
p_value = std::min(1.0, std::max(0.0, p_value));
auto & column_tuple = assert_cast<ColumnTuple &>(to);
auto & column_stat = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(0));
auto & column_value = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(1));
column_stat.getData().push_back(f_stat);
column_value.getData().push_back(p_value);
}
};
}

View File

@ -4,7 +4,9 @@
#include <IO/ReadHelpers.h> #include <IO/ReadHelpers.h>
#include <boost/math/distributions/students_t.hpp> #include <boost/math/distributions/students_t.hpp>
#include <boost/math/distributions/normal.hpp> #include <boost/math/distributions/normal.hpp>
#include <boost/math/distributions/fisher_f.hpp>
#include <cfloat> #include <cfloat>
#include <numeric>
namespace DB namespace DB
@ -13,6 +15,7 @@ struct Settings;
namespace ErrorCodes namespace ErrorCodes
{ {
extern const int BAD_ARGUMENTS;
extern const int DECIMAL_OVERFLOW; extern const int DECIMAL_OVERFLOW;
} }
@ -476,4 +479,127 @@ struct ZTestMoments
} }
}; };
template <typename T>
struct AnalysisOfVarianceMoments
{
/// Sums of values within a group
std::vector<T> xs1{};
/// Sums of squared values within a group
std::vector<T> xs2{};
/// Sizes of each group. Total number of observations is just a sum of all these values
std::vector<size_t> ns{};
void resizeIfNeeded(size_t possible_size)
{
if (xs1.size() >= possible_size)
return;
xs1.resize(possible_size, 0.0);
xs2.resize(possible_size, 0.0);
ns.resize(possible_size, 0);
}
void add(T value, size_t group)
{
resizeIfNeeded(group + 1);
xs1[group] += value;
xs2[group] += value * value;
ns[group] += 1;
}
void merge(const AnalysisOfVarianceMoments & rhs)
{
resizeIfNeeded(rhs.xs1.size());
for (size_t i = 0; i < rhs.xs1.size(); ++i)
{
xs1[i] += rhs.xs1[i];
xs2[i] += rhs.xs2[i];
ns[i] += rhs.ns[i];
}
}
void write(WriteBuffer & buf) const
{
writeVectorBinary(xs1, buf);
writeVectorBinary(xs2, buf);
writeVectorBinary(ns, buf);
}
void read(ReadBuffer & buf)
{
readVectorBinary(xs1, buf);
readVectorBinary(xs2, buf);
readVectorBinary(ns, buf);
}
Float64 getMeanAll() const
{
const auto n = std::accumulate(ns.begin(), ns.end(), 0UL);
if (n == 0)
throw Exception(ErrorCodes::BAD_ARGUMENTS, "There are no observations to calculate mean value");
return std::accumulate(xs1.begin(), xs1.end(), 0.0) / n;
}
Float64 getMeanGroup(size_t group) const
{
if (ns[group] == 0)
throw Exception(ErrorCodes::BAD_ARGUMENTS, "There is no observations for group {}", group);
return xs1[group] / ns[group];
}
Float64 getBetweenGroupsVariation() const
{
Float64 res = 0;
auto mean = getMeanAll();
for (size_t i = 0; i < xs1.size(); ++i)
{
auto group_mean = getMeanGroup(i);
res += ns[i] * (group_mean - mean) * (group_mean - mean);
}
return res;
}
Float64 getWithinGroupsVariation() const
{
Float64 res = 0;
for (size_t i = 0; i < xs1.size(); ++i)
{
auto group_mean = getMeanGroup(i);
res += xs2[i] + ns[i] * group_mean * group_mean - 2 * group_mean * xs1[i];
}
return res;
}
Float64 getFStatistic() const
{
const auto k = xs1.size();
const auto n = std::accumulate(ns.begin(), ns.end(), 0UL);
if (k == 1)
throw Exception(ErrorCodes::BAD_ARGUMENTS, "There should be more than one group to calculate f-statistics");
if (k == n)
throw Exception(ErrorCodes::BAD_ARGUMENTS, "There is only one observation in each group");
return (getBetweenGroupsVariation() * (n - k)) / (getWithinGroupsVariation() * (k - 1));
}
Float64 getPValue(Float64 f_statistic) const
{
const auto k = xs1.size();
const auto n = std::accumulate(ns.begin(), ns.end(), 0UL);
if (k == 1)
throw Exception(ErrorCodes::BAD_ARGUMENTS, "There should be more than one group to calculate f-statistics");
if (k == n)
throw Exception(ErrorCodes::BAD_ARGUMENTS, "There is only one observation in each group");
return 1.0f - boost::math::cdf(boost::math::fisher_f(k - 1, n - k), f_statistic);
}
};
} }

View File

@ -72,6 +72,7 @@ void registerAggregateFunctionNothing(AggregateFunctionFactory &);
void registerAggregateFunctionExponentialMovingAverage(AggregateFunctionFactory &); void registerAggregateFunctionExponentialMovingAverage(AggregateFunctionFactory &);
void registerAggregateFunctionSparkbar(AggregateFunctionFactory &); void registerAggregateFunctionSparkbar(AggregateFunctionFactory &);
void registerAggregateFunctionIntervalLengthSum(AggregateFunctionFactory &); void registerAggregateFunctionIntervalLengthSum(AggregateFunctionFactory &);
void registerAggregateFunctionAnalysisOfVariance(AggregateFunctionFactory &);
class AggregateFunctionCombinatorFactory; class AggregateFunctionCombinatorFactory;
void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory &); void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory &);
@ -156,6 +157,7 @@ void registerAggregateFunctions()
registerAggregateFunctionIntervalLengthSum(factory); registerAggregateFunctionIntervalLengthSum(factory);
registerAggregateFunctionExponentialMovingAverage(factory); registerAggregateFunctionExponentialMovingAverage(factory);
registerAggregateFunctionSparkbar(factory); registerAggregateFunctionSparkbar(factory);
registerAggregateFunctionAnalysisOfVariance(factory);
registerWindowFunctions(factory); registerWindowFunctions(factory);
} }

View File

@ -0,0 +1,86 @@
#!/usr/bin/env python3
import os
import sys
from statistics import variance
from scipy import stats
import pandas as pd
import numpy as np
CURDIR = os.path.dirname(os.path.realpath(__file__))
sys.path.insert(0, os.path.join(CURDIR, 'helpers'))
from pure_http_client import ClickHouseClient
# unpooled variance z-test for means of two samples
def scipy_anova(rvs):
return stats.f_oneway(*rvs)
def test_and_check(rvs, n_groups, f_stat, p_value, precision=1e-2):
client = ClickHouseClient()
client.query("DROP TABLE IF EXISTS anova;")
client.query("CREATE TABLE anova (left Float64, right UInt64) ENGINE = Memory;")
for group in range(n_groups):
client.query(f'''INSERT INTO anova VALUES {", ".join([f'({i},{group})' for i in rvs[group]])};''')
real = client.query_return_df(
'''SELECT roundBankers(a.1, 16) as f_stat, roundBankers(a.2, 16) as p_value FROM (SELECT anova(left, right) as a FROM anova) FORMAT TabSeparatedWithNames;''')
real_f_stat = real['f_stat'][0]
real_p_value = real['p_value'][0]
assert(abs(real_f_stat - np.float64(f_stat)) < precision), f"clickhouse_f_stat {real_f_stat}, py_f_stat {f_stat}"
assert(abs(real_p_value - np.float64(p_value)) < precision), f"clickhouse_p_value {real_p_value}, py_p_value {p_value}"
client.query("DROP TABLE IF EXISTS anova;")
def test_anova():
n_groups = 3
rvs = []
loc = 0
scale = 5
size = 500
for _ in range(n_groups):
rvs.append(np.round(stats.norm.rvs(loc=loc, scale=scale, size=size), 2))
loc += 5
f_stat, p_value = scipy_anova(rvs)
test_and_check(rvs, n_groups, f_stat, p_value)
n_groups = 6
rvs = []
loc = 0
scale = 5
size = 500
for _ in range(n_groups):
rvs.append(np.round(stats.norm.rvs(loc=loc, scale=scale, size=size), 2))
f_stat, p_value = scipy_anova(rvs)
test_and_check(rvs, n_groups, f_stat, p_value)
n_groups = 10
rvs = []
loc = 1
scale = 2
size = 100
for _ in range(n_groups):
rvs.append(np.round(stats.norm.rvs(loc=loc, scale=scale, size=size), 2))
loc += 1
scale += 2
size += 100
f_stat, p_value = scipy_anova(rvs)
test_and_check(rvs, n_groups, f_stat, p_value)
n_groups = 20
rvs = []
loc = 0
scale = 10
size = 1100
for _ in range(n_groups):
rvs.append(np.round(stats.norm.rvs(loc=loc, scale=scale, size=size), 2))
size -= 50
f_stat, p_value = scipy_anova(rvs)
test_and_check(rvs, n_groups, f_stat, p_value)
if __name__ == "__main__":
test_anova()
print("Ok.")

View File

@ -0,0 +1 @@
Ok.

View File

@ -0,0 +1,9 @@
#!/usr/bin/env bash
CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
# shellcheck source=../shell_config.sh
. "$CURDIR"/../shell_config.sh
# We should have correct env vars from shell_config.sh to run this test
python3 "$CURDIR"/02294_anova_cmp.python