Merge branch 'ml_methods'

This commit is contained in:
alexander kozhikhov 2019-02-13 01:36:37 +03:00
commit 2feee7ebe5
7 changed files with 393 additions and 381 deletions

View File

@ -60,6 +60,7 @@ public:
return batch_gradient; return batch_gradient;
} }
virtual Float64 predict(const std::vector<Float64> & predict_feature, const std::vector<Float64> & weights, Float64 bias) const = 0; virtual Float64 predict(const std::vector<Float64> & predict_feature, const std::vector<Float64> & weights, Float64 bias) const = 0;
virtual void predict_for_all(ColumnVector<Float64>::Container & container, Block & block, const ColumnNumbers & arguments, const std::vector<Float64> & weights, Float64 bias) const = 0;
protected: protected:
std::vector<Float64> batch_gradient; // gradient for bias lies in batch_gradient[batch_gradient.size() - 1] std::vector<Float64> batch_gradient; // gradient for bias lies in batch_gradient[batch_gradient.size() - 1]
@ -106,6 +107,31 @@ public:
return res; return res;
} }
/// Batch prediction for a linear model: for every row of `block`, appends
/// bias + sum(weights[k] * feature_k) to `container`.
/// The feature columns are the ones at positions arguments[1..]; arguments[0] is not read here.
/// Throws BAD_ARGUMENTS if any feature value is not stored as Float64.
/// Nothing is appended to `container` unless every value passed the type check.
void predict_for_all(ColumnVector<Float64>::Container & container, Block & block, const ColumnNumbers & arguments, const std::vector<Float64> & weights, Float64 bias) const override
{
    const size_t total_rows = block.rows();

    /// Each prediction starts at the bias; feature contributions are accumulated column by column.
    std::vector<Float64> predictions(total_rows, bias);

    /// Feature number `feature` lives in arguments[feature + 1] and is scaled by weights[feature].
    for (size_t feature = 0; feature + 1 < arguments.size(); ++feature)
    {
        const auto & feature_column = *block.getByPosition(arguments[feature + 1]).column;
        const Float64 weight = weights[feature];

        for (size_t row = 0; row < total_rows; ++row)
        {
            const auto value = feature_column[row];
            if (value.getType() != Field::Types::Float64)
            {
                throw Exception("Prediction arguments must be values of type Float",
                                ErrorCodes::BAD_ARGUMENTS);
            }
            predictions[row] += weight * value.get<Float64>();
        }
    }

    /// Publish the results only after all rows were computed successfully.
    for (const Float64 prediction : predictions)
        container.emplace_back(prediction);
}
}; };
class LogisticRegression : public IGradientComputer class LogisticRegression : public IGradientComputer
@ -150,6 +176,15 @@ public:
res = 1 / (1 + exp(-res)); res = 1 / (1 + exp(-res));
return res; return res;
} }
/// Batch prediction for logistic regression: for every row of `block`, appends
/// sigmoid(bias + sum(weights[k] * feature_k)) to `container`.
/// The feature columns are the ones at positions arguments[1..]; arguments[0] is not read here.
/// Throws BAD_ARGUMENTS if any feature value is not stored as Float64.
/// Previously this was a stub that ignored all of its arguments and produced no rows.
void predict_for_all(ColumnVector<Float64>::Container & container, Block & block, const ColumnNumbers & arguments, const std::vector<Float64> & weights, Float64 bias) const override
{
    const size_t rows_num = block.rows();

    /// Linear part of the model, accumulated column by column starting from the bias.
    std::vector<Float64> results(rows_num, bias);

    for (size_t i = 1; i < arguments.size(); ++i)
    {
        ColumnPtr cur_col = block.getByPosition(arguments[i]).column;
        for (size_t row_num = 0; row_num != rows_num; ++row_num)
        {
            const auto & element = (*cur_col)[row_num];
            if (element.getType() != Field::Types::Float64)
                throw Exception("Prediction arguments must be values of type Float",
                                ErrorCodes::BAD_ARGUMENTS);

            results[row_num] += weights[i - 1] * element.get<Float64>();
        }
    }

    /// Apply the same sigmoid as the single-row predict() of this class,
    /// and publish the results only after all rows were computed successfully.
    for (size_t row_num = 0; row_num != rows_num; ++row_num)
        container.emplace_back(1 / (1 + exp(-results[row_num])));
}
}; };
class IWeightsUpdater class IWeightsUpdater
@ -294,6 +329,10 @@ public:
return gradient_computer->predict(predict_feature, weights, bias); return gradient_computer->predict(predict_feature, weights, bias);
} }
/// Batch prediction entry point: forwards the output container, the block and the
/// argument positions to gradient_computer->predict_for_all(), supplying this object's
/// `weights` and `bias` as the model parameters.
void predict_for_all(ColumnVector<Float64>::Container & container, Block & block, const ColumnNumbers & arguments) const
{
gradient_computer->predict_for_all(container, block, arguments, weights, bias);
}
private: private:
std::vector<Float64> weights; std::vector<Float64> weights;
@ -374,7 +413,7 @@ public:
this->data(place).read(buf); this->data(place).read(buf);
} }
void predictResultInto(ConstAggregateDataPtr place, IColumn & to, Block & block, size_t row_num, const ColumnNumbers & arguments) const void predictResultInto(ConstAggregateDataPtr place, IColumn & to, Block & block, const ColumnNumbers & arguments) const
{ {
if (arguments.size() != param_num + 1) if (arguments.size() != param_num + 1)
throw Exception("Predict got incorrect number of arguments. Got: " + std::to_string(arguments.size()) + ". Required: " + std::to_string(param_num + 1), throw Exception("Predict got incorrect number of arguments. Got: " + std::to_string(arguments.size()) + ". Required: " + std::to_string(param_num + 1),
@ -382,17 +421,20 @@ public:
auto &column = dynamic_cast<ColumnVector<Float64> &>(to); auto &column = dynamic_cast<ColumnVector<Float64> &>(to);
std::vector<Float64> predict_features(arguments.size() - 1); /// Так делали с одним предиктом, пока пусть побудет тут
for (size_t i = 1; i < arguments.size(); ++i) // std::vector<Float64> predict_features(arguments.size() - 1);
{ // for (size_t i = 1; i < arguments.size(); ++i)
const auto& element = (*block.getByPosition(arguments[i]).column)[row_num]; // {
if (element.getType() != Field::Types::Float64) // const auto& element = (*block.getByPosition(arguments[i]).column)[row_num];
throw Exception("Prediction arguments must be values of type Float", // if (element.getType() != Field::Types::Float64)
ErrorCodes::BAD_ARGUMENTS); // throw Exception("Prediction arguments must be values of type Float",
// ErrorCodes::BAD_ARGUMENTS);
predict_features[i - 1] = element.get<Float64>(); //
} //// predict_features[i - 1] = element.get<Float64>();
column.getData().push_back(this->data(place).predict(predict_features)); // }
// column.getData().push_back(this->data(place).predict(predict_features));
// column.getData().push_back(this->data(place).predict_for_all());
this->data(place).predict_for_all(column.getData(), block, arguments);
} }
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override

View File

@ -35,9 +35,9 @@ void ColumnAggregateFunction::addArena(ArenaPtr arena_)
arenas.push_back(arena_); arenas.push_back(arena_);
} }
bool ColumnAggregateFunction::convertion(MutableColumnPtr* res_) const bool ColumnAggregateFunction::convertion(MutableColumnPtr *res_) const
{ {
if (const AggregateFunctionState * function_state = typeid_cast<const AggregateFunctionState *>(func.get())) if (const AggregateFunctionState *function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
{ {
auto res = createView(); auto res = createView();
res->set(function_state->getNestedFunction()); res->set(function_state->getNestedFunction());
@ -108,23 +108,14 @@ MutableColumnPtr ColumnAggregateFunction::convertToValues() const
MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const ColumnNumbers & arguments) const MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const ColumnNumbers & arguments) const
{ {
// if (const AggregateFunctionState * function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
// {
// auto res = createView();
// res->set(function_state->getNestedFunction());
// res->data.assign(data.begin(), data.end());
// return res;
// }
//
// MutableColumnPtr res = func->getReturnType()->createColumn();
// res->reserve(data.size());
MutableColumnPtr res; MutableColumnPtr res;
if (convertion(&res)) if (convertion(&res))
{ {
return res; return res;
} }
// auto ML_function = typeid_cast<AggregateFunctionMLMethod<LinearRegressionData, NameLinearRegression> *>(func.get()); /// In my tests, something with data.size() == 0 arrives in this function twice; however, it effectively does nothing in the following lines
auto ML_function_Linear = typeid_cast<AggregateFunctionMLMethod<LinearModelData, NameLinearRegression> *>(func.get()); auto ML_function_Linear = typeid_cast<AggregateFunctionMLMethod<LinearModelData, NameLinearRegression> *>(func.get());
auto ML_function_Logistic = typeid_cast<AggregateFunctionMLMethod<LinearModelData, NameLogisticRegression> *>(func.get()); auto ML_function_Logistic = typeid_cast<AggregateFunctionMLMethod<LinearModelData, NameLogisticRegression> *>(func.get());
if (ML_function_Linear) if (ML_function_Linear)
@ -132,15 +123,16 @@ MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const Col
size_t row_num = 0; size_t row_num = 0;
for (auto val : data) for (auto val : data)
{ {
ML_function_Linear->predictResultInto(val, *res, block, row_num, arguments); ML_function_Linear->predictResultInto(val, *res, block, arguments);
++row_num; ++row_num;
} }
} else if (ML_function_Logistic) } else if (ML_function_Logistic)
{ {
size_t row_num = 0; size_t row_num = 0;
for (auto val : data) for (auto val : data)
{ {
ML_function_Logistic->predictResultInto(val, *res, block, row_num, arguments); ML_function_Logistic->predictResultInto(val, *res, block, arguments);
++row_num; ++row_num;
} }
} else } else

View File

@ -59,9 +59,20 @@ namespace DB
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override
{ {
// const ColumnAggregateFunction * column_with_states =
// typeid_cast<const ColumnAggregateFunction *>(static_cast<const ColumnConst *>(&*block.getByPosition(arguments.at(0)).column));
// Introduce ML aggregate functions as a separate class, so that this can be checked right here rather than inside predictValues() // Introduce ML aggregate functions as a separate class, so that this can be checked right here rather than inside predictValues()
const ColumnAggregateFunction * column_with_states
= typeid_cast<const ColumnAggregateFunction *>(&*block.getByPosition(arguments.at(0)).column); // const ColumnAggregateFunction * column_with_states
// = typeid_cast<const ColumnAggregateFunction *>(&*block.getByPosition(arguments.at(0)).column);
const ColumnConst * column_with_states
= typeid_cast<const ColumnConst *>(&*block.getByPosition(arguments.at(0)).column);
if (!column_with_states) if (!column_with_states)
throw Exception("Illegal column " + block.getByPosition(arguments.at(0)).column->getName() throw Exception("Illegal column " + block.getByPosition(arguments.at(0)).column->getName()
+ " of first argument of function " + " of first argument of function "
@ -83,7 +94,8 @@ namespace DB
} }
block.getByPosition(result).column = column_with_states->predictValues(predict_features); block.getByPosition(result).column = column_with_states->predictValues(predict_features);
*/ */
block.getByPosition(result).column = column_with_states->predictValues(block, arguments); block.getByPosition(result).column =
typeid_cast<const ColumnAggregateFunction *>(&*column_with_states->getDataColumnPtr())->predictValues(block, arguments);
} }
}; };

View File

@ -0,0 +1,300 @@
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804

File diff suppressed because one or more lines are too long

View File

@ -1,300 +0,0 @@
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556

File diff suppressed because one or more lines are too long