diff --git a/src/AggregateFunctions/AggregateFunctionAnalysisOfVariance.cpp b/src/AggregateFunctions/AggregateFunctionAnalysisOfVariance.cpp index ffb651b3288..9ef2d295828 100644 --- a/src/AggregateFunctions/AggregateFunctionAnalysisOfVariance.cpp +++ b/src/AggregateFunctions/AggregateFunctionAnalysisOfVariance.cpp @@ -18,8 +18,10 @@ AggregateFunctionPtr createAggregateFunctionAnalysisOfVariance(const std::string assertNoParameters(name, parameters); assertBinary(name, arguments); - if (!isNumber(arguments[0]) || !isNumber(arguments[1])) - throw Exception(ErrorCodes::BAD_ARGUMENTS, "Aggregate function {} only supports numerical types", name); + if (!isNumber(arguments[0])) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Aggregate function {} only supports numerical argument types", name); + if (!WhichDataType(arguments[1]).isNativeUInt()) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Second argument of aggregate function {} should be a native unsigned integer", name); return std::make_shared(arguments, parameters); } diff --git a/src/AggregateFunctions/AggregateFunctionAnalysisOfVariance.h b/src/AggregateFunctions/AggregateFunctionAnalysisOfVariance.h index efb6426a96c..e891fb191f6 100644 --- a/src/AggregateFunctions/AggregateFunctionAnalysisOfVariance.h +++ b/src/AggregateFunctions/AggregateFunctionAnalysisOfVariance.h @@ -77,7 +77,7 @@ public: void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override { auto f_stat = data(place).getFStatistic(); - if (std::isinf(f_stat) || isNaN(f_stat)) + if (std::isinf(f_stat) || isNaN(f_stat) || f_stat < 0) throw Exception("F statistic is not defined or infinite for these arguments", ErrorCodes::BAD_ARGUMENTS); auto p_value = data(place).getPValue(f_stat); diff --git a/src/AggregateFunctions/Moments.h b/src/AggregateFunctions/Moments.h index 16279cb93a4..2dfd5bc46d6 100644 --- a/src/AggregateFunctions/Moments.h +++ b/src/AggregateFunctions/Moments.h @@ -482,6 +482,8 @@ struct ZTestMoments template struct AnalysisOfVarianceMoments { + constexpr static size_t MAX_GROUPS_NUMBER = 1024 * 1024; + /// Sums of values within a group std::vector xs1{}; /// Sums of squared values within a group @@ -494,6 +496,10 @@ struct AnalysisOfVarianceMoments if (xs1.size() >= possible_size) return; + if (possible_size > MAX_GROUPS_NUMBER) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Too many groups for analysis of variance (should be no more than {}, got {})", + MAX_GROUPS_NUMBER, possible_size); + xs1.resize(possible_size, 0.0); xs2.resize(possible_size, 0.0); ns.resize(possible_size, 0); diff --git a/tests/queries/0_stateless/02475_analysis_of_variance.reference b/tests/queries/0_stateless/02475_analysis_of_variance.reference new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/queries/0_stateless/02475_analysis_of_variance.sql b/tests/queries/0_stateless/02475_analysis_of_variance.sql new file mode 100644 index 00000000000..86996f784ea --- /dev/null +++ b/tests/queries/0_stateless/02475_analysis_of_variance.sql @@ -0,0 +1,10 @@ + +SELECT analysisOfVariance(number, number % 2) FROM numbers(10) FORMAT Null; +SELECT analysisOfVariance(number :: Decimal32(5), number % 2) FROM numbers(10) FORMAT Null; +SELECT analysisOfVariance(number :: Decimal256(5), number % 2) FROM numbers(10) FORMAT Null; + +SELECT analysisOfVariance(1.11, -20); -- { serverError BAD_ARGUMENTS } +SELECT analysisOfVariance(1.11, 20 :: UInt128); -- { serverError BAD_ARGUMENTS } +SELECT analysisOfVariance(1.11, 9000000000000000); -- { serverError BAD_ARGUMENTS } + +SELECT analysisOfVariance(number, number % 2), analysisOfVariance(100000000000000000000., number % 65535) FROM numbers(1048575); -- { serverError BAD_ARGUMENTS }