This commit is contained in:
Michał Tabaszewski 2024-09-18 23:23:59 +03:00 committed by GitHub
commit 08dc241b06
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 31 additions and 15 deletions

View File

@ -1,6 +1,7 @@
#include <Functions/IFunction.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/FunctionFactory.h>
#include <Core/Field.h>
#include <Common/Exception.h>
#include <Common/thread_local_rng.h>
#include <Common/NaNUtils.h>
@ -21,7 +22,6 @@ namespace DB
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int ILLEGAL_COLUMN;
extern const int BAD_ARGUMENTS;
extern const int LOGICAL_ERROR;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
@ -93,7 +93,7 @@ struct ChiSquaredDistribution
static void generate(Float64 degree_of_freedom, ColumnFloat64::Container & container)
{
if (degree_of_freedom <= 0)
if (!container.empty() && degree_of_freedom <= 0)
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Argument (degrees of freedom) of function {} should be greater than zero", getName());
auto distribution = std::chi_squared_distribution<>(degree_of_freedom);
@ -110,7 +110,7 @@ struct StudentTDistribution
static void generate(Float64 degree_of_freedom, ColumnFloat64::Container & container)
{
if (degree_of_freedom <= 0)
if (!container.empty() && degree_of_freedom <= 0)
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Argument (degrees of freedom) of function {} should be greater than zero", getName());
auto distribution = std::student_t_distribution<>(degree_of_freedom);
@ -127,7 +127,7 @@ struct FisherFDistribution
static void generate(Float64 d1, Float64 d2, ColumnFloat64::Container & container)
{
if (d1 <= 0 || d2 <= 0)
if (!container.empty() && (d1 <= 0 || d2 <= 0))
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Argument (degrees of freedom) of function {} should be greater than zero", getName());
auto distribution = std::fisher_f_distribution<>(d1, d2);
@ -215,9 +215,8 @@ template <typename Distribution>
class FunctionRandomDistribution : public IFunction
{
private:
template <typename ResultType>
ResultType getParameterFromConstColumn(size_t parameter_number, const ColumnsWithTypeAndName & arguments) const
ResultType getParameterFromColumn(size_t parameter_number, const ColumnsWithTypeAndName & arguments) const
{
if (parameter_number >= arguments.size())
throw Exception(
@ -226,11 +225,22 @@ private:
parameter_number, arguments.size());
const IColumn * col = arguments[parameter_number].column.get();
ResultType parameter;
if (!isColumnConst(*col))
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Parameter number {} of function {} must be constant.", parameter_number, getName());
auto parameter = applyVisitor(FieldVisitorConvertToNumber<ResultType>(), assert_cast<const ColumnConst &>(*col).getField());
if (const ColumnVector<ResultType> * col_in = checkAndGetColumn<ColumnVector<ResultType>>(col))
{
parameter = *col_in->getData().data();
}
else if (isColumnConst(*col))
{
parameter = applyVisitor(FieldVisitorConvertToNumber<ResultType>(), assert_cast<const ColumnConst &>(*col).getField());
}
else
{
auto expected_type = Field(Field::Types::Which(Field::TypeToEnum<std::decay_t<ResultType>>::value)).getTypeName();
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Parameter number {} of function {} is expected to be {} but is {}",
parameter_number, getName(), expected_type, col->getName());
}
if (isNaN(parameter) || !std::isfinite(parameter))
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Parameter number {} of function {} cannot be NaN of infinite", parameter_number, getName());
@ -278,21 +288,21 @@ public:
{
auto res_column = ColumnUInt8::create(input_rows_count);
auto & res_data = res_column->getData();
Distribution::generate(getParameterFromConstColumn<Float64>(0, arguments), res_data);
Distribution::generate(getParameterFromColumn<Float64>(0, arguments), res_data);
return res_column;
}
else if constexpr (std::is_same_v<Distribution, BinomialDistribution> || std::is_same_v<Distribution, NegativeBinomialDistribution>)
{
auto res_column = ColumnUInt64::create(input_rows_count);
auto & res_data = res_column->getData();
Distribution::generate(getParameterFromConstColumn<UInt64>(0, arguments), getParameterFromConstColumn<Float64>(1, arguments), res_data);
Distribution::generate(getParameterFromColumn<UInt64>(0, arguments), getParameterFromColumn<Float64>(1, arguments), res_data);
return res_column;
}
else if constexpr (std::is_same_v<Distribution, PoissonDistribution>)
{
auto res_column = ColumnUInt64::create(input_rows_count);
auto & res_data = res_column->getData();
Distribution::generate(getParameterFromConstColumn<UInt64>(0, arguments), res_data);
Distribution::generate(getParameterFromColumn<UInt64>(0, arguments), res_data);
return res_column;
}
else
@ -301,11 +311,11 @@ public:
auto & res_data = res_column->getData();
if constexpr (Distribution::getNumberOfArguments() == 1)
{
Distribution::generate(getParameterFromConstColumn<Float64>(0, arguments), res_data);
Distribution::generate(getParameterFromColumn<Float64>(0, arguments), res_data);
}
else if constexpr (Distribution::getNumberOfArguments() == 2)
{
Distribution::generate(getParameterFromConstColumn<Float64>(0, arguments), getParameterFromConstColumn<Float64>(1, arguments), res_data);
Distribution::generate(getParameterFromColumn<Float64>(0, arguments), getParameterFromColumn<Float64>(1, arguments), res_data);
}
else
{

View File

@ -10,3 +10,5 @@ Ok
Ok
Ok
Ok
Ok
Ok

View File

@ -22,3 +22,7 @@ SELECT DISTINCT if (a >= toFloat64(0), 'Ok', 'Fail') FROM (SELECT randNegativeBi
SELECT DISTINCT if (a >= toFloat64(0), 'Ok', 'Fail') FROM (SELECT randPoisson(44) AS a FROM numbers(100000));
# No errors
SELECT randUniform(1, 2, 1), randNormal(0, 1, 'abacaba'), randLogNormal(0, 10, 'b'), randChiSquared(1, 1), randStudentT(7, '8'), randFisherF(23, 42, 100), randBernoulli(0.5, 2), randBinomial(3, 0.5, 1), randNegativeBinomial(3, 0.5, 2), randPoisson(44, 44) FORMAT Null;
# Values should be >= -15
SELECT DISTINCT if (a >= toFloat64(-15), 'Ok', 'Fail') FROM (SELECT randNormal(randNormal(0, 1), randUniform(1, 2)) AS a FROM numbers(10000));
# Values should be >= 0
SELECT DISTINCT if (a >= toFloat64(0), 'Ok', 'Fail') FROM (SELECT randUniform(randUniform(0, 1), randUniform(1, 2)) AS a FROM numbers(10000));