Add functions to generate random values according to the distribution (#42411)

This commit is contained in:
Nikita Mikhaylov 2022-10-20 17:25:28 +02:00 committed by GitHub
parent 0d8a814d80
commit 9a73eb2fbb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 508 additions and 0 deletions

View File

@ -0,0 +1,472 @@
#include <Functions/IFunction.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/FunctionFactory.h>
#include "Common/Exception.h"
#include <Common/NaNUtils.h>
#include <Columns/ColumnConst.h>
#include <Columns/ColumnsNumber.h>
#include <DataTypes/DataTypesNumber.h>
#include <Common/FieldVisitorConvertToNumber.h>
#include <Common/ProfileEvents.h>
#include <Common/assert_cast.h>
#include <IO/WriteHelpers.h>
#include <Interpreters/Context_fwd.h>
#include <random>
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int ILLEGAL_COLUMN;
extern const int BAD_ARGUMENTS;
extern const int LOGICAL_ERROR;
}
namespace
{
struct UniformDistribution
{
using ReturnType = DataTypeFloat64;
static constexpr const char * getName() { return "randUniform"; }
static constexpr size_t getNumberOfArguments() { return 2; }
static void generate(Float64 min, Float64 max, ColumnFloat64::Container & container)
{
auto distribution = std::uniform_real_distribution<>(min, max);
for (auto & elem : container)
elem = distribution(thread_local_rng);
}
};
struct NormalDistribution
{
using ReturnType = DataTypeFloat64;
static constexpr const char * getName() { return "randNormal"; }
static constexpr size_t getNumberOfArguments() { return 2; }
static void generate(Float64 mean, Float64 variance, ColumnFloat64::Container & container)
{
auto distribution = std::normal_distribution<>(mean, variance);
for (auto & elem : container)
elem = distribution(thread_local_rng);
}
};
struct LogNormalDistribution
{
using ReturnType = DataTypeFloat64;
static constexpr const char * getName() { return "randLogNormal"; }
static constexpr size_t getNumberOfArguments() { return 2; }
static void generate(Float64 mean, Float64 variance, ColumnFloat64::Container & container)
{
auto distribution = std::lognormal_distribution<>(mean, variance);
for (auto & elem : container)
elem = distribution(thread_local_rng);
}
};
struct ExponentialDistribution
{
using ReturnType = DataTypeFloat64;
static constexpr const char * getName() { return "randExponential"; }
static constexpr size_t getNumberOfArguments() { return 1; }
static void generate(Float64 lambda, ColumnFloat64::Container & container)
{
auto distribution = std::exponential_distribution<>(lambda);
for (auto & elem : container)
elem = distribution(thread_local_rng);
}
};
struct ChiSquaredDistribution
{
using ReturnType = DataTypeFloat64;
static constexpr const char * getName() { return "randChiSquared"; }
static constexpr size_t getNumberOfArguments() { return 1; }
static void generate(Float64 degree_of_freedom, ColumnFloat64::Container & container)
{
auto distribution = std::chi_squared_distribution<>(degree_of_freedom);
for (auto & elem : container)
elem = distribution(thread_local_rng);
}
};
struct StudentTDistribution
{
using ReturnType = DataTypeFloat64;
static constexpr const char * getName() { return "randStudentT"; }
static constexpr size_t getNumberOfArguments() { return 1; }
static void generate(Float64 degree_of_freedom, ColumnFloat64::Container & container)
{
auto distribution = std::student_t_distribution<>(degree_of_freedom);
for (auto & elem : container)
elem = distribution(thread_local_rng);
}
};
struct FisherFDistribution
{
using ReturnType = DataTypeFloat64;
static constexpr const char * getName() { return "randFisherF"; }
static constexpr size_t getNumberOfArguments() { return 2; }
static void generate(Float64 d1, Float64 d2, ColumnFloat64::Container & container)
{
auto distribution = std::fisher_f_distribution<>(d1, d2);
for (auto & elem : container)
elem = distribution(thread_local_rng);
}
};
struct BernoulliDistribution
{
using ReturnType = DataTypeUInt8;
static constexpr const char * getName() { return "randBernoulli"; }
static constexpr size_t getNumberOfArguments() { return 1; }
static void generate(Float64 p, ColumnUInt8::Container & container)
{
if (p < 0.0f || p > 1.0f)
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Argument of function {} should be inside [0, 1] because it is a probability", getName());
auto distribution = std::bernoulli_distribution(p);
for (auto & elem : container)
elem = static_cast<UInt8>(distribution(thread_local_rng));
}
};
struct BinomialDistribution
{
using ReturnType = DataTypeUInt64;
static constexpr const char * getName() { return "randBinomial"; }
static constexpr size_t getNumberOfArguments() { return 2; }
static void generate(UInt64 t, Float64 p, ColumnUInt64::Container & container)
{
if (p < 0.0f || p > 1.0f)
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Argument of function {} should be inside [0, 1] because it is a probability", getName());
auto distribution = std::binomial_distribution(t, p);
for (auto & elem : container)
elem = static_cast<UInt64>(distribution(thread_local_rng));
}
};
struct NegativeBinomialDistribution
{
using ReturnType = DataTypeUInt64;
static constexpr const char * getName() { return "randNegativeBinomial"; }
static constexpr size_t getNumberOfArguments() { return 2; }
static void generate(UInt64 t, Float64 p, ColumnUInt64::Container & container)
{
if (p < 0.0f || p > 1.0f)
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Argument of function {} should be inside [0, 1] because it is a probability", getName());
auto distribution = std::negative_binomial_distribution(t, p);
for (auto & elem : container)
elem = static_cast<UInt64>(distribution(thread_local_rng));
}
};
struct PoissonDistribution
{
using ReturnType = DataTypeUInt64;
static constexpr const char * getName() { return "randPoisson"; }
static constexpr size_t getNumberOfArguments() { return 1; }
static void generate(UInt64 n, ColumnUInt64::Container & container)
{
auto distribution = std::poisson_distribution(n);
for (auto & elem : container)
elem = static_cast<UInt64>(distribution(thread_local_rng));
}
};
}
/** Function which will generate values according to the specified distribution
* Accepts only constant arguments
* Similar to the functions rand and rand64 an additional 'tag' argument could be added to the
* end of arguments list (this argument will be ignored) which will guarantee that functions are not sticked together
* during optimisations.
* Example: SELECT randNormal(0, 1, 1), randNormal(0, 1, 2) FROM numbers(10)
* This query will return two different columns
*/
template <typename Distribution>
class FunctionRandomDistribution : public IFunction
{
private:
template <typename ResultType>
ResultType getParameterFromConstColumn(size_t parameter_number, const ColumnsWithTypeAndName & arguments) const
{
if (parameter_number >= arguments.size())
throw Exception(
ErrorCodes::LOGICAL_ERROR, "Parameter number ({}) is greater than the size of arguments ({}). This is a bug", parameter_number, arguments.size());
const IColumn * col = arguments[parameter_number].column.get();
if (!isColumnConst(*col))
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Parameter number {} of function must be constant.", parameter_number, getName());
auto parameter = applyVisitor(FieldVisitorConvertToNumber<ResultType>(), assert_cast<const ColumnConst &>(*col).getField());
if (isNaN(parameter) || !std::isfinite(parameter))
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Parameter number {} of function {} cannot be NaN of infinite", parameter_number, getName());
return parameter;
}
public:
static FunctionPtr create(ContextPtr)
{
return std::make_shared<FunctionRandomDistribution<Distribution>>();
}
static constexpr auto name = Distribution::getName();
String getName() const override { return name; }
size_t getNumberOfArguments() const override { return Distribution::getNumberOfArguments(); }
bool isVariadic() const override { return true; }
bool isDeterministic() const override { return false; }
bool isDeterministicInScopeOfQuery() const override { return false; }
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; }
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
auto desired = Distribution::getNumberOfArguments();
if (arguments.size() != desired && arguments.size() != desired + 1)
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Wrong number of arguments for function {}. Should be {} or {}", getName(), desired, desired + 1);
for (size_t i = 0; i < Distribution::getNumberOfArguments(); ++i)
{
const auto & type = arguments[i];
WhichDataType which(type);
if (!which.isFloat() && !which.isNativeUInt())
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Illegal type {} of argument of function {}, expected Float64 or integer", type->getName(), getName());
}
return std::make_shared<typename Distribution::ReturnType>();
}
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & /*result_type*/, size_t input_rows_count) const override
{
if constexpr (std::is_same_v<Distribution, BernoulliDistribution>)
{
auto res_column = ColumnUInt8::create(input_rows_count);
auto & res_data = res_column->getData();
Distribution::generate(getParameterFromConstColumn<Float64>(0, arguments), res_data);
return res_column;
}
else if constexpr (std::is_same_v<Distribution, BinomialDistribution> || std::is_same_v<Distribution, NegativeBinomialDistribution>)
{
auto res_column = ColumnUInt64::create(input_rows_count);
auto & res_data = res_column->getData();
Distribution::generate(getParameterFromConstColumn<UInt64>(0, arguments), getParameterFromConstColumn<Float64>(1, arguments), res_data);
return res_column;
}
else if constexpr (std::is_same_v<Distribution, PoissonDistribution>)
{
auto res_column = ColumnUInt64::create(input_rows_count);
auto & res_data = res_column->getData();
Distribution::generate(getParameterFromConstColumn<UInt64>(0, arguments), res_data);
return res_column;
}
else
{
auto res_column = ColumnFloat64::create(input_rows_count);
auto & res_data = res_column->getData();
if constexpr (Distribution::getNumberOfArguments() == 1)
{
Distribution::generate(getParameterFromConstColumn<Float64>(0, arguments), res_data);
}
else if constexpr (Distribution::getNumberOfArguments() == 2)
{
Distribution::generate(getParameterFromConstColumn<Float64>(0, arguments), getParameterFromConstColumn<Float64>(1, arguments), res_data);
}
else
{
throw Exception(ErrorCodes::BAD_ARGUMENTS, "More than two argument specified for function {}", getName());
}
return res_column;
}
}
};
REGISTER_FUNCTION(Distribution)
{
factory.registerFunction<FunctionRandomDistribution<UniformDistribution>>(
{
R"(
Returns a random number from the uniform distribution in the specified range.
Accepts two parameters - minimum bound and maximum bound.
Typical usage:
[example:typical]
)",
Documentation::Examples{
{"typical", "SELECT randUniform(0, 1) FROM numbers(100000);"}},
Documentation::Categories{"Distribution"}
});
factory.registerFunction<FunctionRandomDistribution<NormalDistribution>>(
{
R"(
Returns a random number from the normal distribution.
Accepts two parameters - mean and variance.
Typical usage:
[example:typical]
)",
Documentation::Examples{
{"typical", "SELECT randNormal(0, 5) FROM numbers(100000);"}},
Documentation::Categories{"Distribution"}
});
factory.registerFunction<FunctionRandomDistribution<LogNormalDistribution>>(
{
R"(
Returns a random number from the lognormal distribution (a distribution of a random variable whose logarithm is normally distributed).
Accepts two parameters - mean and variance.
Typical usage:
[example:typical]
)",
Documentation::Examples{
{"typical", "SELECT randLogNormal(0, 5) FROM numbers(100000);"}},
Documentation::Categories{"Distribution"}
});
factory.registerFunction<FunctionRandomDistribution<ExponentialDistribution>>(
{
R"(
Returns a random number from the exponential distribution.
Accepts one parameter.
Typical usage:
[example:typical]
)",
Documentation::Examples{
{"typical", "SELECT randExponential(0, 5) FROM numbers(100000);"}},
Documentation::Categories{"Distribution"}
});
factory.registerFunction<FunctionRandomDistribution<ChiSquaredDistribution>>(
{
R"(
Returns a random number from the chi-squared distribution (a distribution of a sum of the squares of k independent standard normal random variables).
Accepts one parameter - degree of freedom.
Typical usage:
[example:typical]
)",
Documentation::Examples{
{"typical", "SELECT randChiSquared(5) FROM numbers(100000);"}},
Documentation::Categories{"Distribution"}
});
factory.registerFunction<FunctionRandomDistribution<StudentTDistribution>>(
{
R"(
Returns a random number from the t-distribution.
Accepts one parameter - degree of freedom.
Typical usage:
[example:typical]
)",
Documentation::Examples{
{"typical", "SELECT randStudentT(5) FROM numbers(100000);"}},
Documentation::Categories{"Distribution"}
});
factory.registerFunction<FunctionRandomDistribution<FisherFDistribution>>(
{
R"(
Returns a random number from the f-distribution.
The F-distribution is the distribution of X = (S1 / d1) / (S2 / d2) where d1 and d2 are degrees of freedom.
Accepts two parameters - degrees of freedom.
Typical usage:
[example:typical]
)",
Documentation::Examples{
{"typical", "SELECT randFisherF(5) FROM numbers(100000);"}},
Documentation::Categories{"Distribution"}
});
factory.registerFunction<FunctionRandomDistribution<BernoulliDistribution>>(
{
R"(
Returns a random number from the Bernoulli distribution.
Accepts two parameters - probability of success.
Typical usage:
[example:typical]
)",
Documentation::Examples{
{"typical", "SELECT randBernoulli(0.1) FROM numbers(100000);"}},
Documentation::Categories{"Distribution"}
});
factory.registerFunction<FunctionRandomDistribution<BinomialDistribution>>(
{
R"(
Returns a random number from the binomial distribution.
Accepts two parameters - number of experiments and probability of success in each experiment.
Typical usage:
[example:typical]
)",
Documentation::Examples{
{"typical", "SELECT randBinomial(10, 0.1) FROM numbers(100000);"}},
Documentation::Categories{"Distribution"}
});
factory.registerFunction<FunctionRandomDistribution<NegativeBinomialDistribution>>(
{
R"(
Returns a random number from the negative binomial distribution.
Accepts two parameters - number of experiments and probability of success in each experiment.
Typical usage:
[example:typical]
)",
Documentation::Examples{
{"typical", "SELECT randNegativeBinomial(10, 0.1) FROM numbers(100000);"}},
Documentation::Categories{"Distribution"}
});
factory.registerFunction<FunctionRandomDistribution<PoissonDistribution>>(
{
R"(
Returns a random number from the poisson distribution.
Accepts two parameters - the mean number of occurrences.
Typical usage:
[example:typical]
)",
Documentation::Examples{
{"typical", "SELECT randPoisson(3) FROM numbers(100000);"}},
Documentation::Categories{"Distribution"}
});
}
}

View File

@ -0,0 +1,12 @@
Ok
Ok
Ok
Ok
Ok
Ok
Ok
0
1
Ok
Ok
Ok

View File

@ -0,0 +1,24 @@
# Values should be between 0 and 1
SELECT DISTINCT if (a >= toFloat64(0) AND a <= toFloat64(1), 'Ok', 'Fail') FROM (SELECT randUniform(0, 1) AS a FROM numbers(100000));
# Mean should be around 0
SELECT DISTINCT if (m >= toFloat64(-0.2) AND m <= toFloat64(0.2), 'Ok', 'Fail') FROM (SELECT avg(a) as m FROM (SELECT randNormal(0, 5) AS a FROM numbers(100000)));
# Values should be >= 0
SELECT DISTINCT if (a >= toFloat64(0), 'Ok', 'Fail') FROM (SELECT randLogNormal(0, 5) AS a FROM numbers(100000));
# Values should be >= 0
SELECT DISTINCT if (a >= toFloat64(0), 'Ok', 'Fail') FROM (SELECT randExponential(15) AS a FROM numbers(100000));
# Values should be >= 0
SELECT DISTINCT if (a >= toFloat64(0), 'Ok', 'Fail') FROM (SELECT randChiSquared(3) AS a FROM numbers(100000));
# Mean should be around 0
SELECT DISTINCT if (m > toFloat64(-0.2) AND m < toFloat64(0.2), 'Ok', 'Fail') FROM (SELECT avg(a) as m FROM (SELECT randStudentT(5) AS a FROM numbers(100000)));
# Values should be >= 0
SELECT DISTINCT if (a >= toFloat64(0), 'Ok', 'Fail') FROM (SELECT randFisherF(3, 4) AS a FROM numbers(100000));
# There should be only 0s and 1s
SELECT a FROM (SELECT DISTINCT randBernoulli(0.5) AS a FROM numbers(100000)) ORDER BY a;
# Values should be >= 0
SELECT DISTINCT if (a >= toFloat64(0), 'Ok', 'Fail') FROM (SELECT randBinomial(3, 0.5) AS a FROM numbers(100000));
# Values should be >= 0
SELECT DISTINCT if (a >= toFloat64(0), 'Ok', 'Fail') FROM (SELECT randNegativeBinomial(3, 0.5) AS a FROM numbers(100000));
# Values should be >= 0
SELECT DISTINCT if (a >= toFloat64(0), 'Ok', 'Fail') FROM (SELECT randPoisson(44) AS a FROM numbers(100000));
# No errors
SELECT randUniform(1, 2, 1), randNormal(0, 1, 'abacaba'), randLogNormal(0, 10, 'b'), randChiSquared(1, 1), randStudentT(7, '8'), randFisherF(23, 42, 100), randBernoulli(0.5, 2), randBinomial(3, 0.5, 1), randNegativeBinomial(3, 0.5, 2), randPoisson(44, 44) FORMAT Null;