test in comparison with scipy

This commit is contained in:
nikitamikhaylov 2020-10-13 21:46:15 +03:00
parent e65a2a1cbd
commit 744013d4b8
7 changed files with 137 additions and 26 deletions

View File

@ -115,18 +115,22 @@ struct AggregateFunctionStudentTTestData final
throw Exception("Division by zero encountered in Aggregate function StudentTTest", ErrorCodes::BAD_ARGUMENTS); throw Exception("Division by zero encountered in Aggregate function StudentTTest", ErrorCodes::BAD_ARGUMENTS);
} }
if (mean_x - mean_y < 1e-8)
{
return static_cast<Float64>(0.0);
}
return std::pow(mean_x - mean_y, 2) / getStandartErrorSquared(); return std::pow(mean_x - mean_y, 2) / getStandartErrorSquared();
} }
Float64 getTStatistic() const
{
if (size_x == 0 || size_y == 0)
{
throw Exception("Division by zero encountered in Aggregate function StudentTTest", ErrorCodes::BAD_ARGUMENTS);
}
return (mean_x - mean_y) / std::sqrt(getStandartErrorSquared());
}
Float64 getStandartErrorSquared() const Float64 getStandartErrorSquared() const
{ {
return getSSquared() * (1 / size_x + 1 / size_y); return getSSquared() * (1.0 / static_cast<Float64>(size_x) + 1.0 / static_cast<Float64>(size_y));
} }
Float64 getDegreesOfFreedom() const Float64 getDegreesOfFreedom() const
@ -150,20 +154,23 @@ struct AggregateFunctionStudentTTestData final
{ {
const Float64 v = getDegreesOfFreedom(); const Float64 v = getDegreesOfFreedom();
const Float64 t = getTStatisticSquared(); const Float64 t = getTStatisticSquared();
std::cout << "getDegreesOfFreedom " << v << " getTStatisticSquared " << t << std::endl; std::cout << "getDegreesOfFreedom() " << getDegreesOfFreedom() << std::endl;
std::cout << "getTStatisticSquared() " << getTStatisticSquared() << std::endl;
auto f = [&v] (double x) { return std::pow(x, v/2 - 1) / std::sqrt(1 - x); }; auto f = [&v] (double x) { return std::pow(x, v/2 - 1) / std::sqrt(1 - x); };
Float64 numenator = integrateSimpson(0, v / (t + v), f); Float64 numenator = integrateSimpson(0, v / (t + v), f);
Float64 denominator = std::exp(std::lgammal(v/2) + std::lgammal(0.5) - std::lgammal(v/2 + 0.5)); Float64 denominator = std::exp(std::lgammal(v/2) + std::lgammal(0.5) - std::lgammal(v/2 + 0.5));
std::cout << "numenator " << numenator << std::endl;
std::cout << "denominator " << denominator << std::endl;
return numenator / denominator; return numenator / denominator;
} }
Float64 getResult() const std::pair<Float64, Float64> getResult() const
{ {
return getPValue(); return std::make_pair(getTStatistic(), getPValue());
} }
}; };
/// Returns p-value /// Returns tuple of (t-statistic, p-value)
/// https://cpb-us-w2.wpmucdn.com/voices.uchicago.edu/dist/9/1193/files/2016/01/05b-TandP.pdf /// https://cpb-us-w2.wpmucdn.com/voices.uchicago.edu/dist/9/1193/files/2016/01/05b-TandP.pdf
template <typename X = Float64, typename Y = Float64> template <typename X = Float64, typename Y = Float64>
class AggregateFunctionStudentTTest : class AggregateFunctionStudentTTest :
@ -182,7 +189,22 @@ public:
DataTypePtr getReturnType() const override DataTypePtr getReturnType() const override
{ {
return std::make_shared<DataTypeNumber<Float64>>(); DataTypes types
{
std::make_shared<DataTypeNumber<Float64>>(),
std::make_shared<DataTypeNumber<Float64>>(),
};
Strings names
{
"t-statistic",
"p-value"
};
return std::make_shared<DataTypeTuple>(
std::move(types),
std::move(names)
);
} }
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
@ -221,8 +243,16 @@ public:
throw Exception("Aggregate function " + getName() + " requires samples to be of size > 1", ErrorCodes::BAD_ARGUMENTS); throw Exception("Aggregate function " + getName() + " requires samples to be of size > 1", ErrorCodes::BAD_ARGUMENTS);
} }
auto & column = static_cast<ColumnVector<Float64> &>(to); Float64 t_statistic = 0.0;
column.getData().push_back(this->data(place).getResult()); Float64 p_value = 0.0;
std::tie(t_statistic, p_value) = this->data(place).getResult();
auto & column_tuple = assert_cast<ColumnTuple &>(to);
auto & column_stat = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(0));
auto & column_value = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(1));
column_stat.getData().push_back(t_statistic);
column_value.getData().push_back(p_value);
} }
}; };

View File

@ -122,6 +122,16 @@ struct AggregateFunctionWelchTTestData final
return std::pow(mean_x - mean_y, 2) / (getSxSquared() / size_x + getSySquared() / size_y); return std::pow(mean_x - mean_y, 2) / (getSxSquared() / size_x + getSySquared() / size_y);
} }
Float64 getTStatistic() const
{
if (size_x == 0 || size_y == 0)
{
throw Exception("Division by zero encountered in Aggregate function WelchTTest", ErrorCodes::BAD_ARGUMENTS);
}
return (mean_x - mean_y) / std::sqrt(getSxSquared() / size_x + getSySquared() / size_y);
}
Float64 getDegreesOfFreedom() const Float64 getDegreesOfFreedom() const
{ {
auto sx = getSxSquared(); auto sx = getSxSquared();
@ -154,9 +164,9 @@ struct AggregateFunctionWelchTTestData final
return numenator / denominator; return numenator / denominator;
} }
Float64 getResult() const std::pair<Float64, Float64> getResult() const
{ {
return getPValue(); return std::make_pair(getTStatistic(), getPValue());
} }
}; };
@ -178,7 +188,22 @@ public:
DataTypePtr getReturnType() const override DataTypePtr getReturnType() const override
{ {
return std::make_shared<DataTypeNumber<Float64>>(); DataTypes types
{
std::make_shared<DataTypeNumber<Float64>>(),
std::make_shared<DataTypeNumber<Float64>>(),
};
Strings names
{
"t-statistic",
"p-value"
};
return std::make_shared<DataTypeTuple>(
std::move(types),
std::move(names)
);
} }
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
@ -217,8 +242,16 @@ public:
throw Exception("Aggregate function " + getName() + " requires samples to be of size > 1", ErrorCodes::BAD_ARGUMENTS); throw Exception("Aggregate function " + getName() + " requires samples to be of size > 1", ErrorCodes::BAD_ARGUMENTS);
} }
auto & column = static_cast<ColumnVector<Float64> &>(to); Float64 t_statistic = 0.0;
column.getData().push_back(this->data(place).getResult()); Float64 p_value = 0.0;
std::tie(t_statistic, p_value) = this->data(place).getResult();
auto & column_tuple = assert_cast<ColumnTuple &>(to);
auto & column_stat = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(0));
auto & column_value = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(1));
column_stat.getData().push_back(t_statistic);
column_value.getData().push_back(p_value);
} }
}; };

View File

@ -42,6 +42,7 @@ SRCS(
AggregateFunctionState.cpp AggregateFunctionState.cpp
AggregateFunctionStatistics.cpp AggregateFunctionStatistics.cpp
AggregateFunctionStatisticsSimple.cpp AggregateFunctionStatisticsSimple.cpp
AggregateFunctionStudentTTest.cpp
AggregateFunctionSum.cpp AggregateFunctionSum.cpp
AggregateFunctionSumMap.cpp AggregateFunctionSumMap.cpp
AggregateFunctionTimeSeriesGroupSum.cpp AggregateFunctionTimeSeriesGroupSum.cpp
@ -49,12 +50,13 @@ SRCS(
AggregateFunctionUniqCombined.cpp AggregateFunctionUniqCombined.cpp
AggregateFunctionUniq.cpp AggregateFunctionUniq.cpp
AggregateFunctionUniqUpTo.cpp AggregateFunctionUniqUpTo.cpp
AggregateFunctionWelchTTest.cpp
AggregateFunctionWindowFunnel.cpp AggregateFunctionWindowFunnel.cpp
parseAggregateFunctionParameters.cpp parseAggregateFunctionParameters.cpp
registerAggregateFunctions.cpp registerAggregateFunctions.cpp
UniqCombinedBiasData.cpp UniqCombinedBiasData.cpp
UniqVariadicHash.cpp UniqVariadicHash.cpp
AggregateFunctionWelchTTest.cpp
) )
END() END()

View File

@ -0,0 +1,4 @@
-2.610898982580138 0.00916587538237954
-2.610898982580134 0.0091658753823792
-28.740781574102936 7.667329672103986e-133
-28.74078157410298 0

File diff suppressed because one or more lines are too long

View File

@ -1,6 +1,10 @@
0.021378001462867 0.021378001462867
0.021378 0.0213780014628671
0.090773324285671 0.090773324285671
0.09077332 0.0907733242891952
0.00339907162713746 0.00339907162713746
0.00339907 0.0033990715715539
-0.5028215369186904 0.6152361677168877
-0.5028215369187079 0.6152361677170834
14.971190998235835 5.898143508382202e-44
14.971190998235837 0

File diff suppressed because one or more lines are too long