From 20f8176a2dbee5f43d131dffe3727e2918ee462d Mon Sep 17 00:00:00 2001 From: Michal Tabaszewski Date: Tue, 27 Aug 2024 13:14:04 +0200 Subject: [PATCH 1/7] Allowed non-const args for randDistrib --- src/Functions/randDistribution.cpp | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/src/Functions/randDistribution.cpp b/src/Functions/randDistribution.cpp index dc2ddf2ef24..c14aba26d63 100644 --- a/src/Functions/randDistribution.cpp +++ b/src/Functions/randDistribution.cpp @@ -215,9 +215,8 @@ template class FunctionRandomDistribution : public IFunction { private: - template - ResultType getParameterFromConstColumn(size_t parameter_number, const ColumnsWithTypeAndName & arguments) const + ResultType getParameterFromColumn(size_t parameter_number, const ColumnsWithTypeAndName & arguments) const { if (parameter_number >= arguments.size()) throw Exception( @@ -226,11 +225,20 @@ private: parameter_number, arguments.size()); const IColumn * col = arguments[parameter_number].column.get(); + ResultType parameter; - if (!isColumnConst(*col)) - throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Parameter number {} of function {} must be constant.", parameter_number, getName()); - - auto parameter = applyVisitor(FieldVisitorConvertToNumber(), assert_cast(*col).getField()); + if (const ColumnVector * col_in = checkAndGetColumn>(col)) + { + parameter = *col_in->getData().data(); + } + else if (isColumnConst(*col)) + { + parameter = applyVisitor(FieldVisitorConvertToNumber(), assert_cast(*col).getField()); + } + else + { + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Parameter number {} of function {} is expected to be constant or ColumnVector", parameter_number, getName()); + } if (isNaN(parameter) || !std::isfinite(parameter)) throw Exception(ErrorCodes::BAD_ARGUMENTS, "Parameter number {} of function {} cannot be NaN of infinite", parameter_number, getName()); @@ -278,21 +286,21 @@ public: { auto res_column = ColumnUInt8::create(input_rows_count); auto & res_data = res_column->getData(); - Distribution::generate(getParameterFromConstColumn(0, arguments), res_data); + Distribution::generate(getParameterFromColumn(0, arguments), res_data); return res_column; } else if constexpr (std::is_same_v || std::is_same_v) { auto res_column = ColumnUInt64::create(input_rows_count); auto & res_data = res_column->getData(); - Distribution::generate(getParameterFromConstColumn(0, arguments), getParameterFromConstColumn(1, arguments), res_data); + Distribution::generate(getParameterFromColumn(0, arguments), getParameterFromColumn(1, arguments), res_data); return res_column; } else if constexpr (std::is_same_v) { auto res_column = ColumnUInt64::create(input_rows_count); auto & res_data = res_column->getData(); - Distribution::generate(getParameterFromConstColumn(0, arguments), res_data); + Distribution::generate(getParameterFromColumn(0, arguments), res_data); return res_column; } else @@ -301,11 +309,11 @@ public: auto & res_data = res_column->getData(); if constexpr (Distribution::getNumberOfArguments() == 1) { - Distribution::generate(getParameterFromConstColumn(0, arguments), res_data); + Distribution::generate(getParameterFromColumn(0, arguments), res_data); } else if constexpr (Distribution::getNumberOfArguments() == 2) { - Distribution::generate(getParameterFromConstColumn(0, arguments), getParameterFromConstColumn(1, arguments), res_data); + Distribution::generate(getParameterFromColumn(0, arguments), getParameterFromColumn(1, arguments), res_data); } else { From f25fa9fa58053845940aca3f37a66a8e0313bcde Mon Sep 17 00:00:00 2001 From: Michal Tabaszewski Date: Wed, 28 Aug 2024 20:08:12 +0200 Subject: [PATCH 2/7] Added better logging for invalid parameters, added test case --- src/Functions/randDistribution.cpp | 9 ++++++++- tests/queries/0_stateless/02462_distributions.reference | 2 ++ tests/queries/0_stateless/02462_distributions.sql | 4 ++++ 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/src/Functions/randDistribution.cpp b/src/Functions/randDistribution.cpp index c14aba26d63..479fab5b297 100644 --- a/src/Functions/randDistribution.cpp +++ b/src/Functions/randDistribution.cpp @@ -237,7 +237,14 @@ private: } else { - throw Exception(ErrorCodes::BAD_ARGUMENTS, "Parameter number {} of function {} is expected to be constant or ColumnVector", parameter_number, getName()); + std::string expectedType; + if(std::is_same_v) + expectedType = "UInt64"; + else if(std::is_same_v) + expectedType = "Float64"; + + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Parameter number {} of function {} is expected to be {} but is {}", + parameter_number, getName(), expectedType, col->getName()); } if (isNaN(parameter) || !std::isfinite(parameter)) diff --git a/tests/queries/0_stateless/02462_distributions.reference b/tests/queries/0_stateless/02462_distributions.reference index 56b04bcb856..8856a96d68e 100644 --- a/tests/queries/0_stateless/02462_distributions.reference +++ b/tests/queries/0_stateless/02462_distributions.reference @@ -10,3 +10,5 @@ Ok Ok Ok Ok +Ok +Ok \ No newline at end of file diff --git a/tests/queries/0_stateless/02462_distributions.sql b/tests/queries/0_stateless/02462_distributions.sql index b45dc897f2a..5353dd7abfd 100644 --- a/tests/queries/0_stateless/02462_distributions.sql +++ b/tests/queries/0_stateless/02462_distributions.sql @@ -22,3 +22,7 @@ SELECT DISTINCT if (a >= toFloat64(0), 'Ok', 'Fail') FROM (SELECT randNegativeBi SELECT DISTINCT if (a >= toFloat64(0), 'Ok', 'Fail') FROM (SELECT randPoisson(44) AS a FROM numbers(100000)); # No errors SELECT randUniform(1, 2, 1), randNormal(0, 1, 'abacaba'), randLogNormal(0, 10, 'b'), randChiSquared(1, 1), randStudentT(7, '8'), randFisherF(23, 42, 100), randBernoulli(0.5, 2), randBinomial(3, 0.5, 1), randNegativeBinomial(3, 0.5, 2), randPoisson(44, 44) FORMAT Null; +# Values should be >= 0 +SELECT DISTINCT if (a >= toFloat64(-10), 'Ok', 'Fail') FROM (SELECT randNormal(randNormal(0, 1), randUniform(1, 2)) AS a FROM numbers(10000)); +# Values should be >= 0 +SELECT DISTINCT if (a >= toFloat64(0), 'Ok', 'Fail') FROM (SELECT randUniform(randUniform(0, 1), randUniform(1, 2)) AS a FROM numbers(10000)); \ No newline at end of file From e81516bdc68f29fc214ae51973d43b299637b7b8 Mon Sep 17 00:00:00 2001 From: Michal Tabaszewski Date: Wed, 28 Aug 2024 21:49:01 +0200 Subject: [PATCH 3/7] Fixed issues with degrees of freedom function. --- src/Functions/randDistribution.cpp | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/Functions/randDistribution.cpp b/src/Functions/randDistribution.cpp index 479fab5b297..592602dcc63 100644 --- a/src/Functions/randDistribution.cpp +++ b/src/Functions/randDistribution.cpp @@ -93,6 +93,9 @@ struct ChiSquaredDistribution static void generate(Float64 degree_of_freedom, ColumnFloat64::Container & container) { + if(container.empty()) + return; + if (degree_of_freedom <= 0) throw Exception(ErrorCodes::BAD_ARGUMENTS, "Argument (degrees of freedom) of function {} should be greater than zero", getName()); @@ -110,6 +113,9 @@ struct StudentTDistribution static void generate(Float64 degree_of_freedom, ColumnFloat64::Container & container) { + if(container.empty()) + return; + if (degree_of_freedom <= 0) throw Exception(ErrorCodes::BAD_ARGUMENTS, "Argument (degrees of freedom) of function {} should be greater than zero", getName()); @@ -127,6 +133,9 @@ struct FisherFDistribution static void generate(Float64 d1, Float64 d2, ColumnFloat64::Container & container) { + if(container.empty()) + return; + if (d1 <= 0 || d2 <= 0) throw Exception(ErrorCodes::BAD_ARGUMENTS, "Argument (degrees of freedom) of function {} should be greater than zero", getName()); From 6bc820d3ee5dab8e910a33d9b3892279378c0cd1 Mon Sep 17 00:00:00 2001 From: Michal Tabaszewski Date: Thu, 29 Aug 2024 18:54:43 +0200 Subject: [PATCH 4/7] Fixed style check --- src/Functions/randDistribution.cpp | 20 +++++--------------- 1 file changed, 5 insertions(+), 15 deletions(-) diff --git a/src/Functions/randDistribution.cpp b/src/Functions/randDistribution.cpp index 592602dcc63..602f4fd18e1 100644 --- a/src/Functions/randDistribution.cpp +++ b/src/Functions/randDistribution.cpp @@ -21,7 +21,6 @@ namespace DB namespace ErrorCodes { extern const int ILLEGAL_TYPE_OF_ARGUMENT; - extern const int ILLEGAL_COLUMN; extern const int BAD_ARGUMENTS; extern const int LOGICAL_ERROR; extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; @@ -93,10 +92,7 @@ struct ChiSquaredDistribution static void generate(Float64 degree_of_freedom, ColumnFloat64::Container & container) { - if(container.empty()) - return; - - if (degree_of_freedom <= 0) + if (!container.empty() && degree_of_freedom <= 0) throw Exception(ErrorCodes::BAD_ARGUMENTS, "Argument (degrees of freedom) of function {} should be greater than zero", getName()); auto distribution = std::chi_squared_distribution<>(degree_of_freedom); @@ -113,10 +109,7 @@ struct StudentTDistribution static void generate(Float64 degree_of_freedom, ColumnFloat64::Container & container) { - if(container.empty()) - return; - - if (degree_of_freedom <= 0) + if (!container.empty() && degree_of_freedom <= 0) throw Exception(ErrorCodes::BAD_ARGUMENTS, "Argument (degrees of freedom) of function {} should be greater than zero", getName()); auto distribution = std::student_t_distribution<>(degree_of_freedom); @@ -133,10 +126,7 @@ struct FisherFDistribution static void generate(Float64 d1, Float64 d2, ColumnFloat64::Container & container) { - if(container.empty()) - return; - - if (d1 <= 0 || d2 <= 0) + if (!container.empty() && (d1 <= 0 || d2 <= 0)) throw Exception(ErrorCodes::BAD_ARGUMENTS, "Argument (degrees of freedom) of function {} should be greater than zero", getName()); auto distribution = std::fisher_f_distribution<>(d1, d2); @@ -247,9 +237,9 @@ private: else { std::string expectedType; - if(std::is_same_v) + if (std::is_same_v) expectedType = "UInt64"; - else if(std::is_same_v) + else if (std::is_same_v) expectedType = "Float64"; throw Exception(ErrorCodes::BAD_ARGUMENTS, "Parameter number {} of function {} is expected to be {} but is {}", From 1b8ac3d440a8704d99e0bcf2cf4244139d8f3c29 Mon Sep 17 00:00:00 2001 From: Michal Tabaszewski Date: Thu, 29 Aug 2024 19:32:10 +0200 Subject: [PATCH 5/7] Added missing newline in reference 02462_distributions.reference --- tests/queries/0_stateless/02462_distributions.reference | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/queries/0_stateless/02462_distributions.reference b/tests/queries/0_stateless/02462_distributions.reference index 8856a96d68e..2341fe20bcc 100644 --- a/tests/queries/0_stateless/02462_distributions.reference +++ b/tests/queries/0_stateless/02462_distributions.reference @@ -11,4 +11,4 @@ Ok Ok Ok Ok -Ok \ No newline at end of file +Ok From 906c99f0ea2fec044079a1e90fa1fa891d79cd05 Mon Sep 17 00:00:00 2001 From: Michal Tabaszewski Date: Fri, 30 Aug 2024 00:23:24 +0200 Subject: [PATCH 6/7] Updated minimal value for distributions tests --- tests/queries/0_stateless/02462_distributions.sql | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/queries/0_stateless/02462_distributions.sql b/tests/queries/0_stateless/02462_distributions.sql index 5353dd7abfd..62a3ec3f6fa 100644 --- a/tests/queries/0_stateless/02462_distributions.sql +++ b/tests/queries/0_stateless/02462_distributions.sql @@ -23,6 +23,6 @@ SELECT DISTINCT if (a >= toFloat64(0), 'Ok', 'Fail') FROM (SELECT randPoisson(44 # No errors SELECT randUniform(1, 2, 1), randNormal(0, 1, 'abacaba'), randLogNormal(0, 10, 'b'), randChiSquared(1, 1), randStudentT(7, '8'), randFisherF(23, 42, 100), randBernoulli(0.5, 2), randBinomial(3, 0.5, 1), randNegativeBinomial(3, 0.5, 2), randPoisson(44, 44) FORMAT Null; # Values should be >= 0 -SELECT DISTINCT if (a >= toFloat64(-10), 'Ok', 'Fail') FROM (SELECT randNormal(randNormal(0, 1), randUniform(1, 2)) AS a FROM numbers(10000)); +SELECT DISTINCT if (a >= toFloat64(-15), 'Ok', 'Fail') FROM (SELECT randNormal(randNormal(0, 1), randUniform(1, 2)) AS a FROM numbers(10000)); # Values should be >= 0 SELECT DISTINCT if (a >= toFloat64(0), 'Ok', 'Fail') FROM (SELECT randUniform(randUniform(0, 1), randUniform(1, 2)) AS a FROM numbers(10000)); \ No newline at end of file From 652067781479fa724a6b75c4ee4fbdd3e3bf62f0 Mon Sep 17 00:00:00 2001 From: Nikita Mikhaylov Date: Fri, 30 Aug 2024 20:47:38 +0000 Subject: [PATCH 7/7] Beautify --- src/Functions/randDistribution.cpp | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/Functions/randDistribution.cpp b/src/Functions/randDistribution.cpp index 602f4fd18e1..10a3280010c 100644 --- a/src/Functions/randDistribution.cpp +++ b/src/Functions/randDistribution.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include @@ -236,14 +237,9 @@ private: } else { - std::string expectedType; - if (std::is_same_v) - expectedType = "UInt64"; - else if (std::is_same_v) - expectedType = "Float64"; - + auto expected_type = Field(Field::Types::Which(Field::TypeToEnum>::value)).getTypeName(); throw Exception(ErrorCodes::BAD_ARGUMENTS, "Parameter number {} of function {} is expected to be {} but is {}", - parameter_number, getName(), expectedType, col->getName()); + parameter_number, getName(), expected_type, col->getName()); } if (isNaN(parameter) || !std::isfinite(parameter))