diff --git a/src/Functions/randDistribution.cpp b/src/Functions/randDistribution.cpp index dc2ddf2ef24..10a3280010c 100644 --- a/src/Functions/randDistribution.cpp +++ b/src/Functions/randDistribution.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include @@ -21,7 +22,6 @@ namespace DB namespace ErrorCodes { extern const int ILLEGAL_TYPE_OF_ARGUMENT; - extern const int ILLEGAL_COLUMN; extern const int BAD_ARGUMENTS; extern const int LOGICAL_ERROR; extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; @@ -93,7 +93,7 @@ struct ChiSquaredDistribution static void generate(Float64 degree_of_freedom, ColumnFloat64::Container & container) { - if (degree_of_freedom <= 0) + if (!container.empty() && degree_of_freedom <= 0) throw Exception(ErrorCodes::BAD_ARGUMENTS, "Argument (degrees of freedom) of function {} should be greater than zero", getName()); auto distribution = std::chi_squared_distribution<>(degree_of_freedom); @@ -110,7 +110,7 @@ struct StudentTDistribution static void generate(Float64 degree_of_freedom, ColumnFloat64::Container & container) { - if (degree_of_freedom <= 0) + if (!container.empty() && degree_of_freedom <= 0) throw Exception(ErrorCodes::BAD_ARGUMENTS, "Argument (degrees of freedom) of function {} should be greater than zero", getName()); auto distribution = std::student_t_distribution<>(degree_of_freedom); @@ -127,7 +127,7 @@ struct FisherFDistribution static void generate(Float64 d1, Float64 d2, ColumnFloat64::Container & container) { - if (d1 <= 0 || d2 <= 0) + if (!container.empty() && (d1 <= 0 || d2 <= 0)) throw Exception(ErrorCodes::BAD_ARGUMENTS, "Argument (degrees of freedom) of function {} should be greater than zero", getName()); auto distribution = std::fisher_f_distribution<>(d1, d2); @@ -215,9 +215,8 @@ template class FunctionRandomDistribution : public IFunction { private: - template - ResultType getParameterFromConstColumn(size_t parameter_number, const ColumnsWithTypeAndName & arguments) const + ResultType 
getParameterFromColumn(size_t parameter_number, const ColumnsWithTypeAndName & arguments) const { if (parameter_number >= arguments.size()) throw Exception( @@ -226,11 +225,22 @@ private: parameter_number, arguments.size()); const IColumn * col = arguments[parameter_number].column.get(); + ResultType parameter; - if (!isColumnConst(*col)) - throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Parameter number {} of function {} must be constant.", parameter_number, getName()); - - auto parameter = applyVisitor(FieldVisitorConvertToNumber(), assert_cast(*col).getField()); + if (const ColumnVector * col_in = checkAndGetColumn>(col)) + { + parameter = *col_in->getData().data(); + } + else if (isColumnConst(*col)) + { + parameter = applyVisitor(FieldVisitorConvertToNumber(), assert_cast(*col).getField()); + } + else + { + auto expected_type = Field(Field::Types::Which(Field::TypeToEnum>::value)).getTypeName(); + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Parameter number {} of function {} is expected to be {} but is {}", + parameter_number, getName(), expected_type, col->getName()); + } if (isNaN(parameter) || !std::isfinite(parameter)) throw Exception(ErrorCodes::BAD_ARGUMENTS, "Parameter number {} of function {} cannot be NaN of infinite", parameter_number, getName()); @@ -278,21 +288,21 @@ public: { auto res_column = ColumnUInt8::create(input_rows_count); auto & res_data = res_column->getData(); - Distribution::generate(getParameterFromConstColumn(0, arguments), res_data); + Distribution::generate(getParameterFromColumn(0, arguments), res_data); return res_column; } else if constexpr (std::is_same_v || std::is_same_v) { auto res_column = ColumnUInt64::create(input_rows_count); auto & res_data = res_column->getData(); - Distribution::generate(getParameterFromConstColumn(0, arguments), getParameterFromConstColumn(1, arguments), res_data); + Distribution::generate(getParameterFromColumn(0, arguments), getParameterFromColumn(1, arguments), res_data); return res_column; } else 
if constexpr (std::is_same_v) { auto res_column = ColumnUInt64::create(input_rows_count); auto & res_data = res_column->getData(); - Distribution::generate(getParameterFromConstColumn(0, arguments), res_data); + Distribution::generate(getParameterFromColumn(0, arguments), res_data); return res_column; } else @@ -301,11 +311,11 @@ public: auto & res_data = res_column->getData(); if constexpr (Distribution::getNumberOfArguments() == 1) { - Distribution::generate(getParameterFromConstColumn(0, arguments), res_data); + Distribution::generate(getParameterFromColumn(0, arguments), res_data); } else if constexpr (Distribution::getNumberOfArguments() == 2) { - Distribution::generate(getParameterFromConstColumn(0, arguments), getParameterFromConstColumn(1, arguments), res_data); + Distribution::generate(getParameterFromColumn(0, arguments), getParameterFromColumn(1, arguments), res_data); } else { diff --git a/tests/queries/0_stateless/02462_distributions.reference b/tests/queries/0_stateless/02462_distributions.reference index 56b04bcb856..2341fe20bcc 100644 --- a/tests/queries/0_stateless/02462_distributions.reference +++ b/tests/queries/0_stateless/02462_distributions.reference @@ -10,3 +10,5 @@ Ok Ok Ok Ok +Ok +Ok diff --git a/tests/queries/0_stateless/02462_distributions.sql b/tests/queries/0_stateless/02462_distributions.sql index b45dc897f2a..62a3ec3f6fa 100644 --- a/tests/queries/0_stateless/02462_distributions.sql +++ b/tests/queries/0_stateless/02462_distributions.sql @@ -22,3 +22,7 @@ SELECT DISTINCT if (a >= toFloat64(0), 'Ok', 'Fail') FROM (SELECT randNegativeBi SELECT DISTINCT if (a >= toFloat64(0), 'Ok', 'Fail') FROM (SELECT randPoisson(44) AS a FROM numbers(100000)); # No errors SELECT randUniform(1, 2, 1), randNormal(0, 1, 'abacaba'), randLogNormal(0, 10, 'b'), randChiSquared(1, 1), randStudentT(7, '8'), randFisherF(23, 42, 100), randBernoulli(0.5, 2), randBinomial(3, 0.5, 1), randNegativeBinomial(3, 0.5, 2), randPoisson(44, 44) FORMAT Null; +# Values 
should be >= -15 +SELECT DISTINCT if (a >= toFloat64(-15), 'Ok', 'Fail') FROM (SELECT randNormal(randNormal(0, 1), randUniform(1, 2)) AS a FROM numbers(10000)); +# Values should be >= 0 +SELECT DISTINCT if (a >= toFloat64(0), 'Ok', 'Fail') FROM (SELECT randUniform(randUniform(0, 1), randUniform(1, 2)) AS a FROM numbers(10000)); \ No newline at end of file