mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-21 23:21:59 +00:00
test in comparison with scipy
This commit is contained in:
parent
e65a2a1cbd
commit
744013d4b8
@ -115,18 +115,22 @@ struct AggregateFunctionStudentTTestData final
|
|||||||
throw Exception("Division by zero encountered in Aggregate function StudentTTest", ErrorCodes::BAD_ARGUMENTS);
|
throw Exception("Division by zero encountered in Aggregate function StudentTTest", ErrorCodes::BAD_ARGUMENTS);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (mean_x - mean_y < 1e-8)
|
|
||||||
{
|
|
||||||
return static_cast<Float64>(0.0);
|
|
||||||
}
|
|
||||||
|
|
||||||
return std::pow(mean_x - mean_y, 2) / getStandartErrorSquared();
|
return std::pow(mean_x - mean_y, 2) / getStandartErrorSquared();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Float64 getTStatistic() const
|
||||||
|
{
|
||||||
|
if (size_x == 0 || size_y == 0)
|
||||||
|
{
|
||||||
|
throw Exception("Division by zero encountered in Aggregate function StudentTTest", ErrorCodes::BAD_ARGUMENTS);
|
||||||
|
}
|
||||||
|
|
||||||
|
return (mean_x - mean_y) / std::sqrt(getStandartErrorSquared());
|
||||||
|
}
|
||||||
|
|
||||||
Float64 getStandartErrorSquared() const
|
Float64 getStandartErrorSquared() const
|
||||||
{
|
{
|
||||||
return getSSquared() * (1 / size_x + 1 / size_y);
|
return getSSquared() * (1.0 / static_cast<Float64>(size_x) + 1.0 / static_cast<Float64>(size_y));
|
||||||
}
|
}
|
||||||
|
|
||||||
Float64 getDegreesOfFreedom() const
|
Float64 getDegreesOfFreedom() const
|
||||||
@ -150,20 +154,23 @@ struct AggregateFunctionStudentTTestData final
|
|||||||
{
|
{
|
||||||
const Float64 v = getDegreesOfFreedom();
|
const Float64 v = getDegreesOfFreedom();
|
||||||
const Float64 t = getTStatisticSquared();
|
const Float64 t = getTStatisticSquared();
|
||||||
std::cout << "getDegreesOfFreedom " << v << " getTStatisticSquared " << t << std::endl;
|
std::cout << "getDegreesOfFreedom() " << getDegreesOfFreedom() << std::endl;
|
||||||
|
std::cout << "getTStatisticSquared() " << getTStatisticSquared() << std::endl;
|
||||||
auto f = [&v] (double x) { return std::pow(x, v/2 - 1) / std::sqrt(1 - x); };
|
auto f = [&v] (double x) { return std::pow(x, v/2 - 1) / std::sqrt(1 - x); };
|
||||||
Float64 numenator = integrateSimpson(0, v / (t + v), f);
|
Float64 numenator = integrateSimpson(0, v / (t + v), f);
|
||||||
Float64 denominator = std::exp(std::lgammal(v/2) + std::lgammal(0.5) - std::lgammal(v/2 + 0.5));
|
Float64 denominator = std::exp(std::lgammal(v/2) + std::lgammal(0.5) - std::lgammal(v/2 + 0.5));
|
||||||
|
std::cout << "numenator " << numenator << std::endl;
|
||||||
|
std::cout << "denominator " << denominator << std::endl;
|
||||||
return numenator / denominator;
|
return numenator / denominator;
|
||||||
}
|
}
|
||||||
|
|
||||||
Float64 getResult() const
|
std::pair<Float64, Float64> getResult() const
|
||||||
{
|
{
|
||||||
return getPValue();
|
return std::make_pair(getTStatistic(), getPValue());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Returns p-value
|
/// Returns tuple of (t-statistic, p-value)
|
||||||
/// https://cpb-us-w2.wpmucdn.com/voices.uchicago.edu/dist/9/1193/files/2016/01/05b-TandP.pdf
|
/// https://cpb-us-w2.wpmucdn.com/voices.uchicago.edu/dist/9/1193/files/2016/01/05b-TandP.pdf
|
||||||
template <typename X = Float64, typename Y = Float64>
|
template <typename X = Float64, typename Y = Float64>
|
||||||
class AggregateFunctionStudentTTest :
|
class AggregateFunctionStudentTTest :
|
||||||
@ -182,7 +189,22 @@ public:
|
|||||||
|
|
||||||
DataTypePtr getReturnType() const override
|
DataTypePtr getReturnType() const override
|
||||||
{
|
{
|
||||||
return std::make_shared<DataTypeNumber<Float64>>();
|
DataTypes types
|
||||||
|
{
|
||||||
|
std::make_shared<DataTypeNumber<Float64>>(),
|
||||||
|
std::make_shared<DataTypeNumber<Float64>>(),
|
||||||
|
};
|
||||||
|
|
||||||
|
Strings names
|
||||||
|
{
|
||||||
|
"t-statistic",
|
||||||
|
"p-value"
|
||||||
|
};
|
||||||
|
|
||||||
|
return std::make_shared<DataTypeTuple>(
|
||||||
|
std::move(types),
|
||||||
|
std::move(names)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
|
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||||
@ -221,8 +243,16 @@ public:
|
|||||||
throw Exception("Aggregate function " + getName() + " requires samples to be of size > 1", ErrorCodes::BAD_ARGUMENTS);
|
throw Exception("Aggregate function " + getName() + " requires samples to be of size > 1", ErrorCodes::BAD_ARGUMENTS);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto & column = static_cast<ColumnVector<Float64> &>(to);
|
Float64 t_statistic = 0.0;
|
||||||
column.getData().push_back(this->data(place).getResult());
|
Float64 p_value = 0.0;
|
||||||
|
std::tie(t_statistic, p_value) = this->data(place).getResult();
|
||||||
|
|
||||||
|
auto & column_tuple = assert_cast<ColumnTuple &>(to);
|
||||||
|
auto & column_stat = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(0));
|
||||||
|
auto & column_value = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(1));
|
||||||
|
|
||||||
|
column_stat.getData().push_back(t_statistic);
|
||||||
|
column_value.getData().push_back(p_value);
|
||||||
}
|
}
|
||||||
|
|
||||||
};
|
};
|
||||||
|
@ -122,6 +122,16 @@ struct AggregateFunctionWelchTTestData final
|
|||||||
return std::pow(mean_x - mean_y, 2) / (getSxSquared() / size_x + getSySquared() / size_y);
|
return std::pow(mean_x - mean_y, 2) / (getSxSquared() / size_x + getSySquared() / size_y);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Float64 getTStatistic() const
|
||||||
|
{
|
||||||
|
if (size_x == 0 || size_y == 0)
|
||||||
|
{
|
||||||
|
throw Exception("Division by zero encountered in Aggregate function WelchTTest", ErrorCodes::BAD_ARGUMENTS);
|
||||||
|
}
|
||||||
|
|
||||||
|
return (mean_x - mean_y) / std::sqrt(getSxSquared() / size_x + getSySquared() / size_y);
|
||||||
|
}
|
||||||
|
|
||||||
Float64 getDegreesOfFreedom() const
|
Float64 getDegreesOfFreedom() const
|
||||||
{
|
{
|
||||||
auto sx = getSxSquared();
|
auto sx = getSxSquared();
|
||||||
@ -154,9 +164,9 @@ struct AggregateFunctionWelchTTestData final
|
|||||||
return numenator / denominator;
|
return numenator / denominator;
|
||||||
}
|
}
|
||||||
|
|
||||||
Float64 getResult() const
|
std::pair<Float64, Float64> getResult() const
|
||||||
{
|
{
|
||||||
return getPValue();
|
return std::make_pair(getTStatistic(), getPValue());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -178,7 +188,22 @@ public:
|
|||||||
|
|
||||||
DataTypePtr getReturnType() const override
|
DataTypePtr getReturnType() const override
|
||||||
{
|
{
|
||||||
return std::make_shared<DataTypeNumber<Float64>>();
|
DataTypes types
|
||||||
|
{
|
||||||
|
std::make_shared<DataTypeNumber<Float64>>(),
|
||||||
|
std::make_shared<DataTypeNumber<Float64>>(),
|
||||||
|
};
|
||||||
|
|
||||||
|
Strings names
|
||||||
|
{
|
||||||
|
"t-statistic",
|
||||||
|
"p-value"
|
||||||
|
};
|
||||||
|
|
||||||
|
return std::make_shared<DataTypeTuple>(
|
||||||
|
std::move(types),
|
||||||
|
std::move(names)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
|
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||||
@ -217,8 +242,16 @@ public:
|
|||||||
throw Exception("Aggregate function " + getName() + " requires samples to be of size > 1", ErrorCodes::BAD_ARGUMENTS);
|
throw Exception("Aggregate function " + getName() + " requires samples to be of size > 1", ErrorCodes::BAD_ARGUMENTS);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto & column = static_cast<ColumnVector<Float64> &>(to);
|
Float64 t_statistic = 0.0;
|
||||||
column.getData().push_back(this->data(place).getResult());
|
Float64 p_value = 0.0;
|
||||||
|
std::tie(t_statistic, p_value) = this->data(place).getResult();
|
||||||
|
|
||||||
|
auto & column_tuple = assert_cast<ColumnTuple &>(to);
|
||||||
|
auto & column_stat = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(0));
|
||||||
|
auto & column_value = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(1));
|
||||||
|
|
||||||
|
column_stat.getData().push_back(t_statistic);
|
||||||
|
column_value.getData().push_back(p_value);
|
||||||
}
|
}
|
||||||
|
|
||||||
};
|
};
|
||||||
|
@ -42,6 +42,7 @@ SRCS(
|
|||||||
AggregateFunctionState.cpp
|
AggregateFunctionState.cpp
|
||||||
AggregateFunctionStatistics.cpp
|
AggregateFunctionStatistics.cpp
|
||||||
AggregateFunctionStatisticsSimple.cpp
|
AggregateFunctionStatisticsSimple.cpp
|
||||||
|
AggregateFunctionStudentTTest.cpp
|
||||||
AggregateFunctionSum.cpp
|
AggregateFunctionSum.cpp
|
||||||
AggregateFunctionSumMap.cpp
|
AggregateFunctionSumMap.cpp
|
||||||
AggregateFunctionTimeSeriesGroupSum.cpp
|
AggregateFunctionTimeSeriesGroupSum.cpp
|
||||||
@ -49,12 +50,13 @@ SRCS(
|
|||||||
AggregateFunctionUniqCombined.cpp
|
AggregateFunctionUniqCombined.cpp
|
||||||
AggregateFunctionUniq.cpp
|
AggregateFunctionUniq.cpp
|
||||||
AggregateFunctionUniqUpTo.cpp
|
AggregateFunctionUniqUpTo.cpp
|
||||||
|
AggregateFunctionWelchTTest.cpp
|
||||||
AggregateFunctionWindowFunnel.cpp
|
AggregateFunctionWindowFunnel.cpp
|
||||||
parseAggregateFunctionParameters.cpp
|
parseAggregateFunctionParameters.cpp
|
||||||
registerAggregateFunctions.cpp
|
registerAggregateFunctions.cpp
|
||||||
UniqCombinedBiasData.cpp
|
UniqCombinedBiasData.cpp
|
||||||
UniqVariadicHash.cpp
|
UniqVariadicHash.cpp
|
||||||
AggregateFunctionWelchTTest.cpp
|
|
||||||
)
|
)
|
||||||
|
|
||||||
END()
|
END()
|
||||||
|
4
tests/queries/0_stateless/01322_student_ttest.reference
Normal file
4
tests/queries/0_stateless/01322_student_ttest.reference
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
-2.610898982580138 0.00916587538237954
|
||||||
|
-2.610898982580134 0.0091658753823792
|
||||||
|
-28.740781574102936 7.667329672103986e-133
|
||||||
|
-28.74078157410298 0
|
19
tests/queries/0_stateless/01322_student_ttest.sql
Normal file
19
tests/queries/0_stateless/01322_student_ttest.sql
Normal file
File diff suppressed because one or more lines are too long
@ -1,6 +1,10 @@
|
|||||||
0.021378001462867
|
0.021378001462867
|
||||||
0.021378
|
0.0213780014628671
|
||||||
0.090773324285671
|
0.090773324285671
|
||||||
0.09077332
|
0.0907733242891952
|
||||||
0.00339907162713746
|
0.00339907162713746
|
||||||
0.00339907
|
0.0033990715715539
|
||||||
|
-0.5028215369186904 0.6152361677168877
|
||||||
|
-0.5028215369187079 0.6152361677170834
|
||||||
|
14.971190998235835 5.898143508382202e-44
|
||||||
|
14.971190998235837 0
|
||||||
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user