Add checks in AggregateFunctionAnalysisOfVariance

This commit is contained in:
vdimir 2022-10-31 11:56:52 +00:00
parent fe48a1ce7e
commit 6d798cbc9d
No known key found for this signature in database
GPG Key ID: 6EE4CE2BEDC51862
5 changed files with 21 additions and 3 deletions

View File

@ -18,8 +18,10 @@ AggregateFunctionPtr createAggregateFunctionAnalysisOfVariance(const std::string
assertNoParameters(name, parameters);
assertBinary(name, arguments);
if (!isNumber(arguments[0]) || !isNumber(arguments[1]))
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Aggregate function {} only supports numerical types", name);
if (!isNumber(arguments[0]))
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Aggregate function {} only supports numerical argument types", name);
if (!WhichDataType(arguments[1]).isNativeUInt())
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Second argument of aggregate function {} should be a native unsigned integer", name);
return std::make_shared<AggregateFunctionAnalysisOfVariance>(arguments, parameters);
}

View File

@ -77,7 +77,7 @@ public:
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
{
auto f_stat = data(place).getFStatistic();
if (std::isinf(f_stat) || isNaN(f_stat))
if (std::isinf(f_stat) || isNaN(f_stat) || f_stat < 0)
throw Exception("F statistic is not defined or infinite for these arguments", ErrorCodes::BAD_ARGUMENTS);
auto p_value = data(place).getPValue(f_stat);

View File

@ -482,6 +482,8 @@ struct ZTestMoments
template <typename T>
struct AnalysisOfVarianceMoments
{
constexpr static size_t MAX_GROUPS_NUMBER = 1024 * 1024;
/// Sums of values within a group
std::vector<T> xs1{};
/// Sums of squared values within a group
@ -494,6 +496,10 @@ struct AnalysisOfVarianceMoments
if (xs1.size() >= possible_size)
return;
if (possible_size > MAX_GROUPS_NUMBER)
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Too many groups for analysis of variance (should be no more than {}, got {})",
MAX_GROUPS_NUMBER, possible_size);
xs1.resize(possible_size, 0.0);
xs2.resize(possible_size, 0.0);
ns.resize(possible_size, 0);

View File

@ -0,0 +1,10 @@
-- Argument validation tests for the analysisOfVariance aggregate function.
-- The first three queries carry no error annotation, so they are expected to succeed;
-- FORMAT Null discards the result, only successful execution is checked.
SELECT analysisOfVariance(number, number % 2) FROM numbers(10) FORMAT Null;
-- Same call with the first argument cast to Decimal32 and Decimal256.
SELECT analysisOfVariance(number :: Decimal32(5), number % 2) FROM numbers(10) FORMAT Null;
SELECT analysisOfVariance(number :: Decimal256(5), number % 2) FROM numbers(10) FORMAT Null;
-- The remaining queries are each annotated to fail with BAD_ARGUMENTS.
-- Negative literal as the second (group) argument.
SELECT analysisOfVariance(1.11, -20); -- { serverError BAD_ARGUMENTS }
-- Second argument cast to UInt128 (presumably rejected as a non-native unsigned integer -- confirm).
SELECT analysisOfVariance(1.11, 20 :: UInt128); -- { serverError BAD_ARGUMENTS }
-- Very large group index (presumably exceeds the limit on the number of groups -- confirm).
SELECT analysisOfVariance(1.11, 9000000000000000); -- { serverError BAD_ARGUMENTS }
-- Two aggregates over numbers(1048575); the second uses a constant 1e20 value with groups number % 65535.
-- NOTE(review): the query alone does not show which check raises the error here -- confirm.
SELECT analysisOfVariance(number, number % 2), analysisOfVariance(100000000000000000000., number % 65535) FROM numbers(1048575); -- { serverError BAD_ARGUMENTS }