mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-10 17:44:23 +00:00
review fixes
This commit is contained in:
parent
b6d2c9f4d2
commit
daf4690d37
@ -24,9 +24,6 @@ AggregateFunctionPtr createAggregateFunctionMLMethod(
|
||||
|
||||
for (size_t i = 0; i < argument_types.size(); ++i)
|
||||
{
|
||||
// if (!WhichDataType(argument_types[i]).isFloat64())
|
||||
// throw Exception("Illegal type " + argument_types[i]->getName() + " of argument " + std::to_string(i) + "for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
// if (!WhichDataType(argument_types[i]).isNumeric())
|
||||
if (!isNumber(argument_types[i]))
|
||||
throw Exception("Argument " + std::to_string(i) + " of type " + argument_types[i]->getName() + " must be numeric for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
}
|
||||
@ -55,28 +52,20 @@ AggregateFunctionPtr createAggregateFunctionMLMethod(
|
||||
}
|
||||
if (parameters.size() > 3)
|
||||
{
|
||||
if (applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[3]) == Float64{1.0})
|
||||
{
|
||||
if (applyVisitor(FieldVisitorToString(), parameters[3]) == "\'SGD\'") {
|
||||
weights_updater = std::make_shared<StochasticGradientDescent>();
|
||||
}
|
||||
else if (applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[3]) == Float64{2.0})
|
||||
{
|
||||
else if (applyVisitor(FieldVisitorToString(), parameters[3]) == "\'Momentum\'") {
|
||||
weights_updater = std::make_shared<Momentum>();
|
||||
}
|
||||
else if (applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[3]) == Float64{3.0})
|
||||
{
|
||||
else if (applyVisitor(FieldVisitorToString(), parameters[3]) == "\'Nesterov\'") {
|
||||
weights_updater = std::make_shared<Nesterov>();
|
||||
|
||||
}
|
||||
else
|
||||
{
|
||||
throw Exception("Invalid parameter for weights updater", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
}
|
||||
}
|
||||
// else
|
||||
// {
|
||||
// weights_updater = std::make_unique<StochasticGradientDescent>();
|
||||
// }
|
||||
|
||||
if (std::is_same<Method, FuncLinearRegression>::value)
|
||||
{
|
||||
@ -91,7 +80,6 @@ AggregateFunctionPtr createAggregateFunctionMLMethod(
|
||||
throw Exception("Such gradient computer is not implemented yet", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
}
|
||||
|
||||
|
||||
return std::make_shared<Method>(argument_types.size() - 1,
|
||||
gradient_computer, weights_updater,
|
||||
learning_rate, l2_reg_coef, batch_size, argument_types, parameters);
|
||||
|
@ -14,6 +14,7 @@
|
||||
#include <Columns/ColumnsNumber.h>
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <Functions/FunctionHelpers.h>
|
||||
#include <Interpreters/castColumn.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <exception>
|
||||
@ -45,7 +46,7 @@ public:
|
||||
virtual void predict(ColumnVector<Float64>::Container &container,
|
||||
Block &block, const ColumnNumbers &arguments,
|
||||
const std::vector<Float64> &weights,
|
||||
Float64 bias) const = 0;
|
||||
Float64 bias, const Context & context) const = 0;
|
||||
};
|
||||
|
||||
|
||||
@ -61,19 +62,14 @@ public:
|
||||
Float64 derivative = (target - bias);
|
||||
for (size_t i = 0; i < weights.size(); ++i)
|
||||
{
|
||||
// auto value = static_cast<const ColumnVector<Float64> &>(*columns[i]).getData()[row_num];
|
||||
auto value = (*columns[i])[row_num].get<Float64>();
|
||||
// if ((*columns[i])[row_num].getType() == Field::Types::Float64)
|
||||
derivative -= weights[i] * value;
|
||||
// else
|
||||
// derivative -= weights[i] * (*columns[i])[row_num].get<Float32>();
|
||||
}
|
||||
derivative *= (2 * learning_rate);
|
||||
|
||||
batch_gradient[weights.size()] += derivative;
|
||||
for (size_t i = 0; i < weights.size(); ++i)
|
||||
{
|
||||
// auto value = static_cast<const ColumnVector<Float64> &>(*columns[i]).getData()[row_num];
|
||||
auto value = (*columns[i])[row_num].get<Float64>();
|
||||
batch_gradient[i] += derivative * value - 2 * l2_reg_coef * weights[i];
|
||||
}
|
||||
@ -82,24 +78,30 @@ public:
|
||||
void predict(ColumnVector<Float64>::Container &container,
|
||||
Block &block,
|
||||
const ColumnNumbers &arguments,
|
||||
const std::vector<Float64> &weights, Float64 bias) const override
|
||||
const std::vector<Float64> &weights, Float64 bias, const Context & context) const override
|
||||
{
|
||||
size_t rows_num = block.rows();
|
||||
std::vector<Float64> results(rows_num, bias);
|
||||
|
||||
for (size_t i = 1; i < arguments.size(); ++i)
|
||||
{
|
||||
ColumnPtr cur_col = block.safeGetByPosition(arguments[i]).column;
|
||||
const ColumnWithTypeAndName & cur_col = block.getByPosition(arguments[i]);
|
||||
if (!isNumber(cur_col.type)) {
|
||||
throw Exception("Prediction arguments must be have numeric type", ErrorCodes::BAD_ARGUMENTS);
|
||||
}
|
||||
|
||||
/// If column type is already Float64 then castColumn simply returns it
|
||||
auto features_col_ptr = castColumn(cur_col, std::make_shared<DataTypeFloat64>(), context);
|
||||
auto features_column = typeid_cast<const ColumnFloat64 *>(features_col_ptr.get());
|
||||
|
||||
if (!features_column)
|
||||
{
|
||||
throw Exception("Unexpectedly cannot dynamically cast features column " + std::to_string(i), ErrorCodes::LOGICAL_ERROR);
|
||||
}
|
||||
|
||||
for (size_t row_num = 0; row_num != rows_num; ++row_num)
|
||||
{
|
||||
|
||||
const auto &element = (*cur_col)[row_num];
|
||||
// if (element.getType() != Field::Types::Float64)
|
||||
if (!DB::Field::isSimpleNumeric(element.getType()))
|
||||
throw Exception("Prediction arguments must be have numeric type",
|
||||
ErrorCodes::BAD_ARGUMENTS);
|
||||
|
||||
results[row_num] += weights[i - 1] * element.get<Float64>();
|
||||
results[row_num] += weights[i - 1] * features_column->getElement(row_num);
|
||||
}
|
||||
}
|
||||
|
||||
@ -124,7 +126,7 @@ public:
|
||||
Float64 derivative = bias;
|
||||
for (size_t i = 0; i < weights.size(); ++i)
|
||||
{
|
||||
auto value = static_cast<const ColumnVector<Float64> &>(*columns[i]).getData()[row_num];
|
||||
auto value = (*columns[i])[row_num].get<Float64>();
|
||||
derivative += weights[i] * value;
|
||||
}
|
||||
derivative *= target;
|
||||
@ -133,7 +135,7 @@ public:
|
||||
batch_gradient[weights.size()] += learning_rate * target / (derivative + 1);
|
||||
for (size_t i = 0; i < weights.size(); ++i)
|
||||
{
|
||||
auto value = static_cast<const ColumnVector<Float64> &>(*columns[i]).getData()[row_num];
|
||||
auto value = (*columns[i])[row_num].get<Float64>();
|
||||
batch_gradient[i] += learning_rate * target * value / (derivative + 1)
|
||||
- 2 * l2_reg_coef * weights[i];
|
||||
}
|
||||
@ -142,22 +144,30 @@ public:
|
||||
void predict(ColumnVector<Float64>::Container & container,
|
||||
Block & block,
|
||||
const ColumnNumbers & arguments,
|
||||
const std::vector<Float64> & weights, Float64 bias) const override
|
||||
const std::vector<Float64> & weights, Float64 bias, const Context & context) const override
|
||||
{
|
||||
size_t rows_num = block.rows();
|
||||
std::vector<Float64> results(rows_num, bias);
|
||||
|
||||
for (size_t i = 1; i < arguments.size(); ++i)
|
||||
{
|
||||
ColumnPtr cur_col = block.safeGetByPosition(arguments[i]).column;
|
||||
const ColumnWithTypeAndName & cur_col = block.getByPosition(arguments[i]);
|
||||
if (!isNumber(cur_col.type)) {
|
||||
throw Exception("Prediction arguments must be have numeric type", ErrorCodes::BAD_ARGUMENTS);
|
||||
}
|
||||
|
||||
/// If column type is already Float64 then castColumn simply returns it
|
||||
auto features_col_ptr = castColumn(cur_col, std::make_shared<DataTypeFloat64>(), context);
|
||||
auto features_column = typeid_cast<const ColumnFloat64 *>(features_col_ptr.get());
|
||||
|
||||
if (!features_column)
|
||||
{
|
||||
throw Exception("Unexpectedly cannot dynamically cast features column " + std::to_string(i), ErrorCodes::LOGICAL_ERROR);
|
||||
}
|
||||
|
||||
for (size_t row_num = 0; row_num != rows_num; ++row_num)
|
||||
{
|
||||
const auto &element = (*cur_col)[row_num];
|
||||
// if (element.getType() != Field::Types::Float64)
|
||||
if (!DB::Field::isSimpleNumeric(element.getType()))
|
||||
throw Exception("Prediction arguments must have numeric type", ErrorCodes::BAD_ARGUMENTS);
|
||||
|
||||
results[row_num] += weights[i - 1] * element.get<Float64>();
|
||||
results[row_num] += weights[i - 1] * features_column->getElement(row_num);
|
||||
}
|
||||
}
|
||||
|
||||
@ -382,8 +392,7 @@ public:
|
||||
void add(const IColumn **columns, size_t row_num)
|
||||
{
|
||||
/// first column stores target; features start from (columns + 1)
|
||||
const auto &target = static_cast<const ColumnVector<Float64> &>(*columns[0]).getData()[row_num];
|
||||
|
||||
const auto &target = (*columns[0])[row_num].get<Float64>();
|
||||
/// Here we have columns + 1 as first column corresponds to target value, and others - to features
|
||||
weights_updater->add_to_batch(gradient_batch, *gradient_computer,
|
||||
weights, bias, learning_rate, l2_reg_coef, target, columns + 1, row_num);
|
||||
@ -435,9 +444,9 @@ public:
|
||||
weights_updater->read(buf);
|
||||
}
|
||||
|
||||
void predict(ColumnVector<Float64>::Container &container, Block &block, const ColumnNumbers &arguments) const
|
||||
void predict(ColumnVector<Float64>::Container &container, Block &block, const ColumnNumbers &arguments, const Context & context) const
|
||||
{
|
||||
gradient_computer->predict(container, block, arguments, weights, bias);
|
||||
gradient_computer->predict(container, block, arguments, weights, bias, context);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -529,7 +538,7 @@ public:
|
||||
this->data(place).read(buf);
|
||||
}
|
||||
|
||||
void predictValues(ConstAggregateDataPtr place, IColumn & to, Block & block, const ColumnNumbers & arguments) const override
|
||||
void predictValues(ConstAggregateDataPtr place, IColumn & to, Block & block, const ColumnNumbers & arguments, const Context & context) const override
|
||||
{
|
||||
if (arguments.size() != param_num + 1)
|
||||
throw Exception("Predict got incorrect number of arguments. Got: " +
|
||||
@ -538,7 +547,7 @@ public:
|
||||
|
||||
auto &column = dynamic_cast<ColumnVector<Float64> &>(to);
|
||||
|
||||
this->data(place).predict(column.getData(), block, arguments);
|
||||
this->data(place).predict(column.getData(), block, arguments, context);
|
||||
}
|
||||
|
||||
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
|
||||
|
@ -96,8 +96,10 @@ public:
|
||||
|
||||
/// This function is used for machine learning methods
|
||||
virtual void predictValues(ConstAggregateDataPtr /* place */, IColumn & /*to*/,
|
||||
Block & /*block*/, const ColumnNumbers & /*arguments*/) const
|
||||
{}
|
||||
Block & /*block*/, const ColumnNumbers & /*arguments*/, const Context & /*context*/) const
|
||||
{
|
||||
throw Exception("Method predictValues is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
|
||||
}
|
||||
|
||||
/** Returns true for aggregate functions of type -State.
|
||||
* They are executed as other aggregate functions, but not finalized (return an aggregation state that can be combined with another).
|
||||
|
@ -103,36 +103,22 @@ MutableColumnPtr ColumnAggregateFunction::convertToValues() const
|
||||
return res;
|
||||
}
|
||||
|
||||
MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const ColumnNumbers & arguments) const
|
||||
MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const ColumnNumbers & arguments, const Context & context) const
|
||||
{
|
||||
MutableColumnPtr res;
|
||||
if (convertion(&res))
|
||||
{
|
||||
return res;
|
||||
}
|
||||
convertion(&res);
|
||||
|
||||
// auto ML_function_Linear = typeid_cast<AggregateFunctionMLMethod<LinearModelData, NameLinearRegression> *>(func.get());
|
||||
// auto ML_function_Logistic = typeid_cast<AggregateFunctionMLMethod<LinearModelData, NameLogisticRegression> *>(func.get());
|
||||
auto ML_function = func.get();
|
||||
if (ML_function)
|
||||
{
|
||||
size_t row_num = 0;
|
||||
for (auto val : data)
|
||||
{
|
||||
ML_function->predictValues(val, *res, block, arguments);
|
||||
ML_function->predictValues(val, *res, block, arguments, context);
|
||||
++row_num;
|
||||
}
|
||||
|
||||
}
|
||||
// else if (ML_function_Logistic)
|
||||
// {
|
||||
// size_t row_num = 0;
|
||||
// for (auto val : data)
|
||||
// {
|
||||
// ML_function_Logistic->predictValues(val, *res, block, arguments);
|
||||
// ++row_num;
|
||||
// }
|
||||
// }
|
||||
else
|
||||
{
|
||||
throw Exception("Illegal aggregate function is passed",
|
||||
|
@ -119,7 +119,7 @@ public:
|
||||
const char * getFamilyName() const override { return "AggregateFunction"; }
|
||||
|
||||
bool convertion(MutableColumnPtr* res_) const;
|
||||
MutableColumnPtr predictValues(Block & block, const ColumnNumbers & arguments) const;
|
||||
MutableColumnPtr predictValues(Block & block, const ColumnNumbers & arguments, const Context & context) const;
|
||||
|
||||
size_t size() const override
|
||||
{
|
||||
|
@ -200,7 +200,6 @@ public:
|
||||
template <typename T> struct TypeToEnum;
|
||||
template <Types::Which which> struct EnumToType;
|
||||
|
||||
static bool isSimpleNumeric(Types::Which which) { return which >= Types::UInt64 && which <= Types::Int128; }
|
||||
static bool IsDecimal(Types::Which which) { return which >= Types::Decimal32 && which <= Types::Decimal128; }
|
||||
|
||||
Field()
|
||||
|
@ -29,10 +29,12 @@ class FunctionEvalMLMethod : public IFunction
|
||||
{
|
||||
public:
|
||||
static constexpr auto name = "evalMLMethod";
|
||||
static FunctionPtr create(const Context &)
|
||||
static FunctionPtr create(const Context & context)
|
||||
{
|
||||
return std::make_shared<FunctionEvalMLMethod>();
|
||||
return std::make_shared<FunctionEvalMLMethod>(context);
|
||||
}
|
||||
FunctionEvalMLMethod(const Context & context) : context(context)
|
||||
{}
|
||||
|
||||
String getName() const override
|
||||
{
|
||||
@ -72,14 +74,13 @@ public:
|
||||
|
||||
if (!column_with_states)
|
||||
throw Exception("Illegal column " + block.getByPosition(arguments[0]).column->getName()
|
||||
+ " of first argument of function "
|
||||
+ getName(),
|
||||
ErrorCodes::ILLEGAL_COLUMN);
|
||||
+ " of first argument of function " + getName(), ErrorCodes::ILLEGAL_COLUMN);
|
||||
|
||||
block.getByPosition(result).column =
|
||||
typeid_cast<const ColumnAggregateFunction *>(&*column_with_states->getDataColumnPtr())->predictValues(block, arguments);
|
||||
typeid_cast<const ColumnAggregateFunction *>(&*column_with_states->getDataColumnPtr())->predictValues(block, arguments, context);
|
||||
}
|
||||
|
||||
const Context & context;
|
||||
};
|
||||
|
||||
void registerFunctionEvalMLMethod(FunctionFactory & factory)
|
||||
|
@ -11,7 +11,7 @@ CREATE TABLE IF NOT EXISTS test.defaults
|
||||
insert into test.defaults values (-3.273, -1.452, 4.267, 20.0, 40.0), (0.121, -0.615, 4.290, 20.0, 40.0);
|
||||
|
||||
DROP TABLE IF EXISTS test.model;
|
||||
create table test.model engine = Memory as select LinearRegressionState(0.1, 0.0, 2, 1.0)(target, param1, param2) as state from test.defaults;
|
||||
create table test.model engine = Memory as select LinearRegressionState(0.1, 0.0, 2, 'SGD')(target, param1, param2) as state from test.defaults;
|
||||
|
||||
select ans < -61.374 and ans > -61.375 from
|
||||
(with (select state from remote('127.0.0.1', test.model)) as model select evalMLMethod(model, predict1, predict2) as ans from remote('127.0.0.1', test.defaults));
|
||||
|
@ -10,7 +10,7 @@ CREATE TABLE IF NOT EXISTS test.defaults
|
||||
) ENGINE = Memory;
|
||||
insert into test.defaults values (1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2)
|
||||
DROP TABLE IF EXISTS test.model;
|
||||
create table test.model engine = Memory as select LogisticRegressionState(0.1, 0.0, 1.0, 1)(target, param1, param2) as state from test.defaults;
|
||||
create table test.model engine = Memory as select LogisticRegressionState(0.1, 0.0, 1.0, 'SGD')(target, param1, param2) as state from test.defaults;
|
||||
|
||||
select ans < 1.1 and ans > 0.9 from
|
||||
(with (select state from test.model) as model select evalMLMethod(model, predict1, predict2) as ans from test.defaults limit 2);
|
||||
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user