test in comparison with scipy

This commit is contained in:
nikitamikhaylov 2020-10-13 21:46:15 +03:00
parent e65a2a1cbd
commit 744013d4b8
7 changed files with 137 additions and 26 deletions

View File

@ -115,18 +115,22 @@ struct AggregateFunctionStudentTTestData final
throw Exception("Division by zero encountered in Aggregate function StudentTTest", ErrorCodes::BAD_ARGUMENTS);
}
if (mean_x - mean_y < 1e-8)
{
return static_cast<Float64>(0.0);
}
return std::pow(mean_x - mean_y, 2) / getStandartErrorSquared();
}
/// Signed t-statistic: (mean_x - mean_y) divided by the square root of
/// getStandartErrorSquared().
/// Throws BAD_ARGUMENTS when either sample is empty.
Float64 getTStatistic() const
{
    const bool has_empty_sample = (size_x == 0) || (size_y == 0);
    if (has_empty_sample)
        throw Exception("Division by zero encountered in Aggregate function StudentTTest", ErrorCodes::BAD_ARGUMENTS);

    const auto standard_error = std::sqrt(getStandartErrorSquared());
    return (mean_x - mean_y) / standard_error;
}
/// Squared standard error of the difference of means:
/// getSSquared() * (1 / size_x + 1 / size_y), evaluated in floating point.
Float64 getStandartErrorSquared() const
{
    /// Convert the sample sizes to Float64 before taking reciprocals: if the sizes
    /// are integral, `1 / size_x` is an integer division and yields 0 for any size above 1.
    const Float64 inverse_size_sum = 1.0 / static_cast<Float64>(size_x) + 1.0 / static_cast<Float64>(size_y);
    return getSSquared() * inverse_size_sum;
}
Float64 getDegreesOfFreedom() const
@ -150,20 +154,23 @@ struct AggregateFunctionStudentTTestData final
{
const Float64 v = getDegreesOfFreedom();
const Float64 t = getTStatisticSquared();
std::cout << "getDegreesOfFreedom " << v << " getTStatisticSquared " << t << std::endl;
std::cout << "getDegreesOfFreedom() " << getDegreesOfFreedom() << std::endl;
std::cout << "getTStatisticSquared() " << getTStatisticSquared() << std::endl;
auto f = [&v] (double x) { return std::pow(x, v/2 - 1) / std::sqrt(1 - x); };
Float64 numenator = integrateSimpson(0, v / (t + v), f);
Float64 denominator = std::exp(std::lgammal(v/2) + std::lgammal(0.5) - std::lgammal(v/2 + 0.5));
std::cout << "numenator " << numenator << std::endl;
std::cout << "denominator " << denominator << std::endl;
return numenator / denominator;
}
Float64 getResult() const
std::pair<Float64, Float64> getResult() const
{
return getPValue();
return std::make_pair(getTStatistic(), getPValue());
}
};
/// Returns a tuple of (t-statistic, p-value).
/// https://cpb-us-w2.wpmucdn.com/voices.uchicago.edu/dist/9/1193/files/2016/01/05b-TandP.pdf
template <typename X = Float64, typename Y = Float64>
class AggregateFunctionStudentTTest :
@ -182,7 +189,22 @@ public:
/// Result type: a named tuple of two Float64 elements, "t-statistic" and "p-value".
/// It has to be a tuple because the code that writes the result casts the
/// destination column to ColumnTuple and fills two Float64 sub-columns.
DataTypePtr getReturnType() const override
{
    DataTypes types
    {
        std::make_shared<DataTypeNumber<Float64>>(),
        std::make_shared<DataTypeNumber<Float64>>(),
    };

    Strings names
    {
        "t-statistic",
        "p-value"
    };

    return std::make_shared<DataTypeTuple>(
        std::move(types),
        std::move(names)
    );
}
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
@ -221,8 +243,16 @@ public:
throw Exception("Aggregate function " + getName() + " requires samples to be of size > 1", ErrorCodes::BAD_ARGUMENTS);
}
auto & column = static_cast<ColumnVector<Float64> &>(to);
column.getData().push_back(this->data(place).getResult());
Float64 t_statistic = 0.0;
Float64 p_value = 0.0;
std::tie(t_statistic, p_value) = this->data(place).getResult();
auto & column_tuple = assert_cast<ColumnTuple &>(to);
auto & column_stat = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(0));
auto & column_value = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(1));
column_stat.getData().push_back(t_statistic);
column_value.getData().push_back(p_value);
}
};

View File

@ -122,6 +122,16 @@ struct AggregateFunctionWelchTTestData final
return std::pow(mean_x - mean_y, 2) / (getSxSquared() / size_x + getSySquared() / size_y);
}
/// Signed t-statistic: (mean_x - mean_y) divided by
/// sqrt(getSxSquared() / size_x + getSySquared() / size_y).
/// Throws BAD_ARGUMENTS when either sample is empty.
Float64 getTStatistic() const
{
    const bool has_empty_sample = (size_x == 0) || (size_y == 0);
    if (has_empty_sample)
        throw Exception("Division by zero encountered in Aggregate function WelchTTest", ErrorCodes::BAD_ARGUMENTS);

    const auto under_root = getSxSquared() / size_x + getSySquared() / size_y;
    return (mean_x - mean_y) / std::sqrt(under_root);
}
Float64 getDegreesOfFreedom() const
{
auto sx = getSxSquared();
@ -154,9 +164,9 @@ struct AggregateFunctionWelchTTestData final
return numenator / denominator;
}
Float64 getResult() const
std::pair<Float64, Float64> getResult() const
{
return getPValue();
return std::make_pair(getTStatistic(), getPValue());
}
};
@ -178,7 +188,22 @@ public:
/// Result type: a named tuple of two Float64 elements, "t-statistic" and "p-value".
/// It has to be a tuple because the code that writes the result casts the
/// destination column to ColumnTuple and fills two Float64 sub-columns.
DataTypePtr getReturnType() const override
{
    DataTypes types
    {
        std::make_shared<DataTypeNumber<Float64>>(),
        std::make_shared<DataTypeNumber<Float64>>(),
    };

    Strings names
    {
        "t-statistic",
        "p-value"
    };

    return std::make_shared<DataTypeTuple>(
        std::move(types),
        std::move(names)
    );
}
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
@ -217,8 +242,16 @@ public:
throw Exception("Aggregate function " + getName() + " requires samples to be of size > 1", ErrorCodes::BAD_ARGUMENTS);
}
auto & column = static_cast<ColumnVector<Float64> &>(to);
column.getData().push_back(this->data(place).getResult());
Float64 t_statistic = 0.0;
Float64 p_value = 0.0;
std::tie(t_statistic, p_value) = this->data(place).getResult();
auto & column_tuple = assert_cast<ColumnTuple &>(to);
auto & column_stat = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(0));
auto & column_value = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(1));
column_stat.getData().push_back(t_statistic);
column_value.getData().push_back(p_value);
}
};

View File

@ -42,6 +42,7 @@ SRCS(
AggregateFunctionState.cpp
AggregateFunctionStatistics.cpp
AggregateFunctionStatisticsSimple.cpp
AggregateFunctionStudentTTest.cpp
AggregateFunctionSum.cpp
AggregateFunctionSumMap.cpp
AggregateFunctionTimeSeriesGroupSum.cpp
@ -49,12 +50,13 @@ SRCS(
AggregateFunctionUniqCombined.cpp
AggregateFunctionUniq.cpp
AggregateFunctionUniqUpTo.cpp
AggregateFunctionWelchTTest.cpp
AggregateFunctionWindowFunnel.cpp
parseAggregateFunctionParameters.cpp
registerAggregateFunctions.cpp
UniqCombinedBiasData.cpp
UniqVariadicHash.cpp
AggregateFunctionWelchTTest.cpp
)
END()

View File

@ -0,0 +1,4 @@
-2.610898982580138 0.00916587538237954
-2.610898982580134 0.0091658753823792
-28.740781574102936 7.667329672103986e-133
-28.74078157410298 0

File diff suppressed because one or more lines are too long

View File

@ -1,6 +1,10 @@
0.021378001462867
0.021378
0.0213780014628671
0.090773324285671
0.09077332
0.0907733242891952
0.00339907162713746
0.00339907
0.0033990715715539
-0.5028215369186904 0.6152361677168877
-0.5028215369187079 0.6152361677170834
14.971190998235835 5.898143508382202e-44
14.971190998235837 0

File diff suppressed because one or more lines are too long