review fixes

Alexander Kozhikhov 2019-04-21 17:32:42 +03:00
parent b6d2c9f4d2
commit daf4690d37
12 changed files with 74 additions and 89 deletions

View File

@@ -24,9 +24,6 @@ AggregateFunctionPtr createAggregateFunctionMLMethod(
for (size_t i = 0; i < argument_types.size(); ++i)
{
// if (!WhichDataType(argument_types[i]).isFloat64())
// throw Exception("Illegal type " + argument_types[i]->getName() + " of argument " + std::to_string(i) + "for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
// if (!WhichDataType(argument_types[i]).isNumeric())
if (!isNumber(argument_types[i]))
throw Exception("Argument " + std::to_string(i) + " of type " + argument_types[i]->getName() + " must be numeric for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
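
Note: relaxing the check from Float64-only to isNumber() is consistent with the rest of this commit, which casts feature columns to Float64 via castColumn at prediction time and reads training values through Field, so any numeric argument type can be accepted here.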
@@ -55,28 +52,20 @@ AggregateFunctionPtr createAggregateFunctionMLMethod(
}
if (parameters.size() > 3)
{
if (applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[3]) == Float64{1.0})
{
if (applyVisitor(FieldVisitorToString(), parameters[3]) == "\'SGD\'") {
weights_updater = std::make_shared<StochasticGradientDescent>();
}
else if (applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[3]) == Float64{2.0})
{
else if (applyVisitor(FieldVisitorToString(), parameters[3]) == "\'Momentum\'") {
weights_updater = std::make_shared<Momentum>();
}
else if (applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[3]) == Float64{3.0})
{
else if (applyVisitor(FieldVisitorToString(), parameters[3]) == "\'Nesterov\'") {
weights_updater = std::make_shared<Nesterov>();
}
else
{
throw Exception("Invalid parameter for weights updater", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
}
// else
// {
// weights_updater = std::make_unique<StochasticGradientDescent>();
// }
if (std::is_same<Method, FuncLinearRegression>::value)
{
@@ -91,7 +80,6 @@ AggregateFunctionPtr createAggregateFunctionMLMethod(
throw Exception("Such gradient computer is not implemented yet", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
return std::make_shared<Method>(argument_types.size() - 1,
gradient_computer, weights_updater,
learning_rate, l2_reg_coef, batch_size, argument_types, parameters);
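
Note: the comparisons above match quoted literals ('SGD', etc.) because FieldVisitorToString serializes a String field together with its quotes. A minimal standalone sketch of the same dispatch as a lookup table (the type and function names here are illustrative, not part of this commit):

#include <functional>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>

struct IWeightsUpdater { virtual ~IWeightsUpdater() = default; };
struct StochasticGradientDescent : IWeightsUpdater {};
struct Momentum : IWeightsUpdater {};
struct Nesterov : IWeightsUpdater {};

/// Hypothetical helper: adding a new updater becomes one table entry
/// instead of another else-if branch.
inline std::shared_ptr<IWeightsUpdater> makeWeightsUpdater(const std::string & name)
{
    static const std::unordered_map<std::string,
        std::function<std::shared_ptr<IWeightsUpdater>()>> factories =
    {
        {"SGD",      [] { return std::make_shared<StochasticGradientDescent>(); }},
        {"Momentum", [] { return std::make_shared<Momentum>(); }},
        {"Nesterov", [] { return std::make_shared<Nesterov>(); }},
    };
    auto it = factories.find(name);
    if (it == factories.end())
        throw std::invalid_argument("Invalid parameter for weights updater: " + name);
    return it->second();
}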

View File

@@ -14,6 +14,7 @@
#include <Columns/ColumnsNumber.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <Functions/FunctionHelpers.h>
#include <Interpreters/castColumn.h>
#include <cmath>
#include <exception>
@@ -45,7 +46,7 @@ public:
virtual void predict(ColumnVector<Float64>::Container &container,
Block &block, const ColumnNumbers &arguments,
const std::vector<Float64> &weights,
Float64 bias) const = 0;
Float64 bias, const Context & context) const = 0;
};
@@ -61,19 +62,14 @@ public:
Float64 derivative = (target - bias);
for (size_t i = 0; i < weights.size(); ++i)
{
// auto value = static_cast<const ColumnVector<Float64> &>(*columns[i]).getData()[row_num];
auto value = (*columns[i])[row_num].get<Float64>();
// if ((*columns[i])[row_num].getType() == Field::Types::Float64)
derivative -= weights[i] * value;
// else
// derivative -= weights[i] * (*columns[i])[row_num].get<Float32>();
}
derivative *= (2 * learning_rate);
batch_gradient[weights.size()] += derivative;
for (size_t i = 0; i < weights.size(); ++i)
{
// auto value = static_cast<const ColumnVector<Float64> &>(*columns[i]).getData()[row_num];
auto value = (*columns[i])[row_num].get<Float64>();
batch_gradient[i] += derivative * value - 2 * l2_reg_coef * weights[i];
}
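
For context, the loop above accumulates the learning-rate-scaled gradient of squared error with an L2 penalty on the weights. A self-contained sketch of one row's contribution (names illustrative, logic mirroring the code shown):

#include <cstddef>
#include <vector>

using Float64 = double;

/// One row's contribution to the batch gradient of
///   (target - (w . x + bias))^2 + l2_reg_coef * |w|^2,
/// pre-multiplied by the learning rate, as in the loop above.
/// batch_gradient has weights.size() + 1 slots; the last one is the bias term.
void addRowGradient(std::vector<Float64> & batch_gradient,
                    const std::vector<Float64> & weights, Float64 bias,
                    Float64 learning_rate, Float64 l2_reg_coef,
                    Float64 target, const std::vector<Float64> & features)
{
    Float64 derivative = target - bias;
    for (size_t i = 0; i < weights.size(); ++i)
        derivative -= weights[i] * features[i];
    derivative *= 2 * learning_rate;

    batch_gradient[weights.size()] += derivative;
    for (size_t i = 0; i < weights.size(); ++i)
        batch_gradient[i] += derivative * features[i] - 2 * l2_reg_coef * weights[i];
}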
@@ -82,24 +78,30 @@ public:
void predict(ColumnVector<Float64>::Container &container,
Block &block,
const ColumnNumbers &arguments,
const std::vector<Float64> &weights, Float64 bias) const override
const std::vector<Float64> &weights, Float64 bias, const Context & context) const override
{
size_t rows_num = block.rows();
std::vector<Float64> results(rows_num, bias);
for (size_t i = 1; i < arguments.size(); ++i)
{
ColumnPtr cur_col = block.safeGetByPosition(arguments[i]).column;
const ColumnWithTypeAndName & cur_col = block.getByPosition(arguments[i]);
if (!isNumber(cur_col.type)) {
throw Exception("Prediction arguments must be have numeric type", ErrorCodes::BAD_ARGUMENTS);
}
/// If column type is already Float64 then castColumn simply returns it
auto features_col_ptr = castColumn(cur_col, std::make_shared<DataTypeFloat64>(), context);
auto features_column = typeid_cast<const ColumnFloat64 *>(features_col_ptr.get());
if (!features_column)
{
throw Exception("Unexpectedly cannot dynamically cast features column " + std::to_string(i), ErrorCodes::LOGICAL_ERROR);
}
for (size_t row_num = 0; row_num != rows_num; ++row_num)
{
const auto &element = (*cur_col)[row_num];
// if (element.getType() != Field::Types::Float64)
if (!DB::Field::isSimpleNumeric(element.getType()))
throw Exception("Prediction arguments must be have numeric type",
ErrorCodes::BAD_ARGUMENTS);
results[row_num] += weights[i - 1] * element.get<Float64>();
results[row_num] += weights[i - 1] * features_column->getElement(row_num);
}
}
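
Note on the new pattern: converting the whole column once via castColumn replaces the old per-row Field checks, and the returned ColumnPtr must be kept alive (here features_col_ptr) because typeid_cast hands back a raw pointer into it. A sketch of the same checked conversion, using only the calls already visible in this diff:

/// Sketch (assumes the ClickHouse headers already included in this file):
/// coerce one block column to Float64 and return a checked raw pointer.
/// `holder` must outlive every use of the returned pointer.
static const ColumnFloat64 * coerceToFloat64(
    const ColumnWithTypeAndName & col, const Context & context, ColumnPtr & holder)
{
    if (!isNumber(col.type))
        throw Exception("Prediction arguments must have numeric type", ErrorCodes::BAD_ARGUMENTS);
    holder = castColumn(col, std::make_shared<DataTypeFloat64>(), context); /// no-op if already Float64
    return typeid_cast<const ColumnFloat64 *>(holder.get());
}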
@@ -124,7 +126,7 @@ public:
Float64 derivative = bias;
for (size_t i = 0; i < weights.size(); ++i)
{
auto value = static_cast<const ColumnVector<Float64> &>(*columns[i]).getData()[row_num];
auto value = (*columns[i])[row_num].get<Float64>();
derivative += weights[i] * value;
}
derivative *= target;
@@ -133,7 +135,7 @@ public:
batch_gradient[weights.size()] += learning_rate * target / (derivative + 1);
for (size_t i = 0; i < weights.size(); ++i)
{
auto value = static_cast<const ColumnVector<Float64> &>(*columns[i]).getData()[row_num];
auto value = (*columns[i])[row_num].get<Float64>();
batch_gradient[i] += learning_rate * target * value / (derivative + 1)
- 2 * l2_reg_coef * weights[i];
}
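
For reference, these terms are gradient steps on the logistic (log-loss) objective for labels target in {-1, 1}. Assuming, as the (derivative + 1) denominator suggests, that exp(derivative) is applied on a line elided between these hunks, the per-row gradient for margin m = t (w . x + b) is:

\[
L = \log\bigl(1 + e^{-m}\bigr), \qquad
\frac{\partial L}{\partial w_i} = -\frac{t\,x_i}{1 + e^{m}}, \qquad
\frac{\partial L}{\partial b} = -\frac{t}{1 + e^{m}}
\]

which matches learning_rate * target * value / (derivative + 1) for the weights (plus the - 2 * l2_reg_coef * weights[i] ridge term) and learning_rate * target / (derivative + 1) for the bias.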
@@ -142,22 +144,30 @@ public:
void predict(ColumnVector<Float64>::Container & container,
Block & block,
const ColumnNumbers & arguments,
const std::vector<Float64> & weights, Float64 bias) const override
const std::vector<Float64> & weights, Float64 bias, const Context & context) const override
{
size_t rows_num = block.rows();
std::vector<Float64> results(rows_num, bias);
for (size_t i = 1; i < arguments.size(); ++i)
{
ColumnPtr cur_col = block.safeGetByPosition(arguments[i]).column;
const ColumnWithTypeAndName & cur_col = block.getByPosition(arguments[i]);
if (!isNumber(cur_col.type)) {
throw Exception("Prediction arguments must be have numeric type", ErrorCodes::BAD_ARGUMENTS);
}
/// If column type is already Float64 then castColumn simply returns it
auto features_col_ptr = castColumn(cur_col, std::make_shared<DataTypeFloat64>(), context);
auto features_column = typeid_cast<const ColumnFloat64 *>(features_col_ptr.get());
if (!features_column)
{
throw Exception("Unexpectedly cannot dynamically cast features column " + std::to_string(i), ErrorCodes::LOGICAL_ERROR);
}
for (size_t row_num = 0; row_num != rows_num; ++row_num)
{
const auto &element = (*cur_col)[row_num];
// if (element.getType() != Field::Types::Float64)
if (!DB::Field::isSimpleNumeric(element.getType()))
throw Exception("Prediction arguments must have numeric type", ErrorCodes::BAD_ARGUMENTS);
results[row_num] += weights[i - 1] * element.get<Float64>();
results[row_num] += weights[i - 1] * features_column->getElement(row_num);
}
}
@@ -382,8 +392,7 @@ public:
void add(const IColumn **columns, size_t row_num)
{
/// first column stores target; features start from (columns + 1)
const auto &target = static_cast<const ColumnVector<Float64> &>(*columns[0]).getData()[row_num];
const auto &target = (*columns[0])[row_num].get<Float64>();
/// Here we have columns + 1 as first column corresponds to target value, and others - to features
weights_updater->add_to_batch(gradient_batch, *gradient_computer,
weights, bias, learning_rate, l2_reg_coef, target, columns + 1, row_num);
@@ -435,9 +444,9 @@ public:
weights_updater->read(buf);
}
void predict(ColumnVector<Float64>::Container &container, Block &block, const ColumnNumbers &arguments) const
void predict(ColumnVector<Float64>::Container &container, Block &block, const ColumnNumbers &arguments, const Context & context) const
{
gradient_computer->predict(container, block, arguments, weights, bias);
gradient_computer->predict(container, block, arguments, weights, bias, context);
}
private:
@@ -529,7 +538,7 @@ public:
this->data(place).read(buf);
}
void predictValues(ConstAggregateDataPtr place, IColumn & to, Block & block, const ColumnNumbers & arguments) const override
void predictValues(ConstAggregateDataPtr place, IColumn & to, Block & block, const ColumnNumbers & arguments, const Context & context) const override
{
if (arguments.size() != param_num + 1)
throw Exception("Predict got incorrect number of arguments. Got: " +
@@ -538,7 +547,7 @@ public:
auto &column = dynamic_cast<ColumnVector<Float64> &>(to);
this->data(place).predict(column.getData(), block, arguments);
this->data(place).predict(column.getData(), block, arguments, context);
}
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override

View File

@@ -96,8 +96,10 @@ public:
/// This function is used for machine learning methods
virtual void predictValues(ConstAggregateDataPtr /* place */, IColumn & /*to*/,
Block & /*block*/, const ColumnNumbers & /*arguments*/) const
{}
Block & /*block*/, const ColumnNumbers & /*arguments*/, const Context & /*context*/) const
{
throw Exception("Method predictValues is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
}
/** Returns true for aggregate functions of type -State.
* They are executed as other aggregate functions, but not finalized (return an aggregation state that can be combined with another).
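
The default implementation changes from a silent no-op to an explicit failure, so calling evalMLMethod on the state of any non-ML aggregate now throws instead of leaving the result column untouched. A minimal standalone illustration of the pattern (simplified names, not the actual ClickHouse classes):

#include <stdexcept>
#include <string>

struct IAggregateFunctionLike
{
    virtual ~IAggregateFunctionLike() = default;
    virtual std::string getName() const = 0;

    /// Throwing default: only ML methods override this; every other
    /// aggregate function reports the unsupported call loudly.
    virtual void predictValues(/* place, to, block, arguments, context */) const
    {
        throw std::logic_error("Method predictValues is not supported for " + getName());
    }
};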

View File

@@ -103,36 +103,22 @@ MutableColumnPtr ColumnAggregateFunction::convertToValues() const
return res;
}
MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const ColumnNumbers & arguments) const
MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const ColumnNumbers & arguments, const Context & context) const
{
MutableColumnPtr res;
if (convertion(&res))
{
return res;
}
convertion(&res);
// auto ML_function_Linear = typeid_cast<AggregateFunctionMLMethod<LinearModelData, NameLinearRegression> *>(func.get());
// auto ML_function_Logistic = typeid_cast<AggregateFunctionMLMethod<LinearModelData, NameLogisticRegression> *>(func.get());
auto ML_function = func.get();
if (ML_function)
{
size_t row_num = 0;
for (auto val : data)
{
ML_function->predictValues(val, *res, block, arguments);
ML_function->predictValues(val, *res, block, arguments, context);
++row_num;
}
}
// else if (ML_function_Logistic)
// {
// size_t row_num = 0;
// for (auto val : data)
// {
// ML_function_Logistic->predictValues(val, *res, block, arguments);
// ++row_num;
// }
// }
else
{
throw Exception("Illegal aggregate function is passed",

View File

@@ -119,7 +119,7 @@ public:
const char * getFamilyName() const override { return "AggregateFunction"; }
bool convertion(MutableColumnPtr* res_) const;
MutableColumnPtr predictValues(Block & block, const ColumnNumbers & arguments) const;
MutableColumnPtr predictValues(Block & block, const ColumnNumbers & arguments, const Context & context) const;
size_t size() const override
{

View File

@@ -200,7 +200,6 @@ public:
template <typename T> struct TypeToEnum;
template <Types::Which which> struct EnumToType;
static bool isSimpleNumeric(Types::Which which) { return which >= Types::UInt64 && which <= Types::Int128; }
static bool IsDecimal(Types::Which which) { return which >= Types::Decimal32 && which <= Types::Decimal128; }
Field()

View File

@@ -29,10 +29,12 @@ class FunctionEvalMLMethod : public IFunction
{
public:
static constexpr auto name = "evalMLMethod";
static FunctionPtr create(const Context &)
static FunctionPtr create(const Context & context)
{
return std::make_shared<FunctionEvalMLMethod>();
return std::make_shared<FunctionEvalMLMethod>(context);
}
FunctionEvalMLMethod(const Context & context) : context(context)
{}
String getName() const override
{
@@ -72,14 +74,13 @@ public:
if (!column_with_states)
throw Exception("Illegal column " + block.getByPosition(arguments[0]).column->getName()
+ " of first argument of function "
+ getName(),
ErrorCodes::ILLEGAL_COLUMN);
+ " of first argument of function " + getName(), ErrorCodes::ILLEGAL_COLUMN);
block.getByPosition(result).column =
typeid_cast<const ColumnAggregateFunction *>(&*column_with_states->getDataColumnPtr())->predictValues(block, arguments);
typeid_cast<const ColumnAggregateFunction *>(&*column_with_states->getDataColumnPtr())->predictValues(block, arguments, context);
}
const Context & context;
};
void registerFunctionEvalMLMethod(FunctionFactory & factory)
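
The function now captures the per-query Context at creation so executeImpl can forward it down to predictValues and, ultimately, castColumn. A trimmed sketch of the pattern, assuming (as holds for ClickHouse functions) that the query Context outlives the function object:

#include <memory>

class Context; /// opaque here

class FunctionEvalMLMethodLike
{
public:
    explicit FunctionEvalMLMethodLike(const Context & context_) : context(context_) {}

    /// What the FunctionFactory invokes; it now forwards the Context
    /// instead of discarding it.
    static std::shared_ptr<FunctionEvalMLMethodLike> create(const Context & context)
    {
        return std::make_shared<FunctionEvalMLMethodLike>(context);
    }

private:
    const Context & context; /// reference member: valid only while the query runs
};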

View File

@@ -11,7 +11,7 @@ CREATE TABLE IF NOT EXISTS test.defaults
insert into test.defaults values (-3.273, -1.452, 4.267, 20.0, 40.0), (0.121, -0.615, 4.290, 20.0, 40.0);
DROP TABLE IF EXISTS test.model;
create table test.model engine = Memory as select LinearRegressionState(0.1, 0.0, 2, 1.0)(target, param1, param2) as state from test.defaults;
create table test.model engine = Memory as select LinearRegressionState(0.1, 0.0, 2, 'SGD')(target, param1, param2) as state from test.defaults;
select ans < -61.374 and ans > -61.375 from
(with (select state from remote('127.0.0.1', test.model)) as model select evalMLMethod(model, predict1, predict2) as ans from remote('127.0.0.1', test.defaults));

View File

@@ -10,7 +10,7 @@ CREATE TABLE IF NOT EXISTS test.defaults
) ENGINE = Memory;
insert into test.defaults values (1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2)
DROP TABLE IF EXISTS test.model;
create table test.model engine = Memory as select LogisticRegressionState(0.1, 0.0, 1.0, 1)(target, param1, param2) as state from test.defaults;
create table test.model engine = Memory as select LogisticRegressionState(0.1, 0.0, 1.0, 'SGD')(target, param1, param2) as state from test.defaults;
select ans < 1.1 and ans > 0.9 from
(with (select state from test.model) as model select evalMLMethod(model, predict1, predict2) as ans from test.defaults limit 2);

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long