Allowed non-const args for randDistrib

This commit is contained in:
Michal Tabaszewski 2024-08-27 13:14:04 +02:00
parent 0c14ac782e
commit 20f8176a2d

View File

@ -215,9 +215,8 @@ template <typename Distribution>
class FunctionRandomDistribution : public IFunction
{
private:
template <typename ResultType>
ResultType getParameterFromConstColumn(size_t parameter_number, const ColumnsWithTypeAndName & arguments) const
ResultType getParameterFromColumn(size_t parameter_number, const ColumnsWithTypeAndName & arguments) const
{
if (parameter_number >= arguments.size())
throw Exception(
@ -226,11 +225,20 @@ private:
parameter_number, arguments.size());
const IColumn * col = arguments[parameter_number].column.get();
ResultType parameter;
if (!isColumnConst(*col))
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Parameter number {} of function {} must be constant.", parameter_number, getName());
auto parameter = applyVisitor(FieldVisitorConvertToNumber<ResultType>(), assert_cast<const ColumnConst &>(*col).getField());
if (const ColumnVector<ResultType> * col_in = checkAndGetColumn<ColumnVector<ResultType>>(col))
{
parameter = *col_in->getData().data();
}
else if (isColumnConst(*col))
{
parameter = applyVisitor(FieldVisitorConvertToNumber<ResultType>(), assert_cast<const ColumnConst &>(*col).getField());
}
else
{
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Parameter number {} of function {} is expected to be constant or ColumnVector", parameter_number, getName());
}
if (isNaN(parameter) || !std::isfinite(parameter))
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Parameter number {} of function {} cannot be NaN of infinite", parameter_number, getName());
@ -278,21 +286,21 @@ public:
{
auto res_column = ColumnUInt8::create(input_rows_count);
auto & res_data = res_column->getData();
Distribution::generate(getParameterFromConstColumn<Float64>(0, arguments), res_data);
Distribution::generate(getParameterFromColumn<Float64>(0, arguments), res_data);
return res_column;
}
else if constexpr (std::is_same_v<Distribution, BinomialDistribution> || std::is_same_v<Distribution, NegativeBinomialDistribution>)
{
auto res_column = ColumnUInt64::create(input_rows_count);
auto & res_data = res_column->getData();
Distribution::generate(getParameterFromConstColumn<UInt64>(0, arguments), getParameterFromConstColumn<Float64>(1, arguments), res_data);
Distribution::generate(getParameterFromColumn<UInt64>(0, arguments), getParameterFromColumn<Float64>(1, arguments), res_data);
return res_column;
}
else if constexpr (std::is_same_v<Distribution, PoissonDistribution>)
{
auto res_column = ColumnUInt64::create(input_rows_count);
auto & res_data = res_column->getData();
Distribution::generate(getParameterFromConstColumn<UInt64>(0, arguments), res_data);
Distribution::generate(getParameterFromColumn<UInt64>(0, arguments), res_data);
return res_column;
}
else
@ -301,11 +309,11 @@ public:
auto & res_data = res_column->getData();
if constexpr (Distribution::getNumberOfArguments() == 1)
{
Distribution::generate(getParameterFromConstColumn<Float64>(0, arguments), res_data);
Distribution::generate(getParameterFromColumn<Float64>(0, arguments), res_data);
}
else if constexpr (Distribution::getNumberOfArguments() == 2)
{
Distribution::generate(getParameterFromConstColumn<Float64>(0, arguments), getParameterFromConstColumn<Float64>(1, arguments), res_data);
Distribution::generate(getParameterFromColumn<Float64>(0, arguments), getParameterFromColumn<Float64>(1, arguments), res_data);
}
else
{