diff --git a/src/Functions/randDistribution.cpp b/src/Functions/randDistribution.cpp index dc2ddf2ef24..c14aba26d63 100644 --- a/src/Functions/randDistribution.cpp +++ b/src/Functions/randDistribution.cpp @@ -215,9 +215,8 @@ template class FunctionRandomDistribution : public IFunction { private: - template - ResultType getParameterFromConstColumn(size_t parameter_number, const ColumnsWithTypeAndName & arguments) const + ResultType getParameterFromColumn(size_t parameter_number, const ColumnsWithTypeAndName & arguments) const { if (parameter_number >= arguments.size()) throw Exception( @@ -226,11 +225,20 @@ private: parameter_number, arguments.size()); const IColumn * col = arguments[parameter_number].column.get(); + ResultType parameter; - if (!isColumnConst(*col)) - throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Parameter number {} of function {} must be constant.", parameter_number, getName()); - - auto parameter = applyVisitor(FieldVisitorConvertToNumber(), assert_cast(*col).getField()); + if (const ColumnVector * col_in = checkAndGetColumn>(col)) + { + parameter = *col_in->getData().data(); + } + else if (isColumnConst(*col)) + { + parameter = applyVisitor(FieldVisitorConvertToNumber(), assert_cast(*col).getField()); + } + else + { + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Parameter number {} of function {} is expected to be constant or ColumnVector", parameter_number, getName()); + } if (isNaN(parameter) || !std::isfinite(parameter)) throw Exception(ErrorCodes::BAD_ARGUMENTS, "Parameter number {} of function {} cannot be NaN of infinite", parameter_number, getName()); @@ -278,21 +286,21 @@ public: { auto res_column = ColumnUInt8::create(input_rows_count); auto & res_data = res_column->getData(); - Distribution::generate(getParameterFromConstColumn(0, arguments), res_data); + Distribution::generate(getParameterFromColumn(0, arguments), res_data); return res_column; } else if constexpr (std::is_same_v || std::is_same_v) { auto res_column = ColumnUInt64::create(input_rows_count); auto & res_data = res_column->getData(); - Distribution::generate(getParameterFromConstColumn(0, arguments), getParameterFromConstColumn(1, arguments), res_data); + Distribution::generate(getParameterFromColumn(0, arguments), getParameterFromColumn(1, arguments), res_data); return res_column; } else if constexpr (std::is_same_v) { auto res_column = ColumnUInt64::create(input_rows_count); auto & res_data = res_column->getData(); - Distribution::generate(getParameterFromConstColumn(0, arguments), res_data); + Distribution::generate(getParameterFromColumn(0, arguments), res_data); return res_column; } else @@ -301,11 +309,11 @@ public: auto & res_data = res_column->getData(); if constexpr (Distribution::getNumberOfArguments() == 1) { - Distribution::generate(getParameterFromConstColumn(0, arguments), res_data); + Distribution::generate(getParameterFromColumn(0, arguments), res_data); } else if constexpr (Distribution::getNumberOfArguments() == 2) { - Distribution::generate(getParameterFromConstColumn(0, arguments), getParameterFromConstColumn(1, arguments), res_data); + Distribution::generate(getParameterFromColumn(0, arguments), getParameterFromColumn(1, arguments), res_data); } else {