Merge branch 'ml_methods'

This commit is contained in:
alexander kozhikhov 2019-02-13 01:36:37 +03:00
commit 2feee7ebe5
7 changed files with 393 additions and 381 deletions

View File

@ -60,6 +60,7 @@ public:
return batch_gradient;
}
virtual Float64 predict(const std::vector<Float64> & predict_feature, const std::vector<Float64> & weights, Float64 bias) const = 0;
virtual void predict_for_all(ColumnVector<Float64>::Container & container, Block & block, const ColumnNumbers & arguments, const std::vector<Float64> & weights, Float64 bias) const = 0;
protected:
std::vector<Float64> batch_gradient; // gradient for bias lies in batch_gradient[batch_gradient.size() - 1]
@ -106,6 +107,31 @@ public:
return res;
}
void predict_for_all(ColumnVector<Float64>::Container & container, Block & block, const ColumnNumbers & arguments, const std::vector<Float64> & weights, Float64 bias) const override
{
size_t rows_num = block.rows();
std::vector<Float64> results(rows_num, bias);
for (size_t i = 1; i < arguments.size(); ++i)
{
ColumnPtr cur_col = block.getByPosition(arguments[i]).column;
for (size_t row_num = 0; row_num != rows_num; ++row_num)
{
const auto &element = (*cur_col)[row_num];
if (element.getType() != Field::Types::Float64)
throw Exception("Prediction arguments must be values of type Float",
ErrorCodes::BAD_ARGUMENTS);
results[row_num] += weights[i - 1] * element.get<Float64>();
}
}
for (size_t row_num = 0; row_num != rows_num; ++row_num)
{
container.emplace_back(results[row_num]);
}
}
};
class LogisticRegression : public IGradientComputer
@ -150,6 +176,15 @@ public:
res = 1 / (1 + exp(-res));
return res;
}
void predict_for_all(ColumnVector<Float64>::Container & container, Block & block, const ColumnNumbers & arguments, const std::vector<Float64> & weights, Float64 bias) const override
{
// TODO
std::ignore = container;
std::ignore = block;
std::ignore = arguments;
std::ignore = weights;
std::ignore = bias;
}
};
class IWeightsUpdater
@ -294,6 +329,10 @@ public:
return gradient_computer->predict(predict_feature, weights, bias);
}
void predict_for_all(ColumnVector<Float64>::Container & container, Block & block, const ColumnNumbers & arguments) const
{
gradient_computer->predict_for_all(container, block, arguments, weights, bias);
}
private:
std::vector<Float64> weights;
@ -374,7 +413,7 @@ public:
this->data(place).read(buf);
}
void predictResultInto(ConstAggregateDataPtr place, IColumn & to, Block & block, size_t row_num, const ColumnNumbers & arguments) const
void predictResultInto(ConstAggregateDataPtr place, IColumn & to, Block & block, const ColumnNumbers & arguments) const
{
if (arguments.size() != param_num + 1)
throw Exception("Predict got incorrect number of arguments. Got: " + std::to_string(arguments.size()) + ". Required: " + std::to_string(param_num + 1),
@ -382,17 +421,20 @@ public:
auto &column = dynamic_cast<ColumnVector<Float64> &>(to);
std::vector<Float64> predict_features(arguments.size() - 1);
for (size_t i = 1; i < arguments.size(); ++i)
{
const auto& element = (*block.getByPosition(arguments[i]).column)[row_num];
if (element.getType() != Field::Types::Float64)
throw Exception("Prediction arguments must be values of type Float",
ErrorCodes::BAD_ARGUMENTS);
predict_features[i - 1] = element.get<Float64>();
}
column.getData().push_back(this->data(place).predict(predict_features));
/// Так делали с одним предиктом, пока пусть побудет тут
// std::vector<Float64> predict_features(arguments.size() - 1);
// for (size_t i = 1; i < arguments.size(); ++i)
// {
// const auto& element = (*block.getByPosition(arguments[i]).column)[row_num];
// if (element.getType() != Field::Types::Float64)
// throw Exception("Prediction arguments must be values of type Float",
// ErrorCodes::BAD_ARGUMENTS);
//
//// predict_features[i - 1] = element.get<Float64>();
// }
// column.getData().push_back(this->data(place).predict(predict_features));
// column.getData().push_back(this->data(place).predict_for_all());
this->data(place).predict_for_all(column.getData(), block, arguments);
}
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override

View File

@ -35,9 +35,9 @@ void ColumnAggregateFunction::addArena(ArenaPtr arena_)
arenas.push_back(arena_);
}
bool ColumnAggregateFunction::convertion(MutableColumnPtr* res_) const
bool ColumnAggregateFunction::convertion(MutableColumnPtr *res_) const
{
if (const AggregateFunctionState * function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
if (const AggregateFunctionState *function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
{
auto res = createView();
res->set(function_state->getNestedFunction());
@ -108,23 +108,14 @@ MutableColumnPtr ColumnAggregateFunction::convertToValues() const
MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const ColumnNumbers & arguments) const
{
// if (const AggregateFunctionState * function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
// {
// auto res = createView();
// res->set(function_state->getNestedFunction());
// res->data.assign(data.begin(), data.end());
// return res;
// }
//
// MutableColumnPtr res = func->getReturnType()->createColumn();
// res->reserve(data.size());
MutableColumnPtr res;
if (convertion(&res))
{
return res;
}
// auto ML_function = typeid_cast<AggregateFunctionMLMethod<LinearRegressionData, NameLinearRegression> *>(func.get());
/// На моих тестах дважды в эту функцию приходит нечтно, имеющее data.size() == 0 однако оно по сути ничего не делает в следующих строках
auto ML_function_Linear = typeid_cast<AggregateFunctionMLMethod<LinearModelData, NameLinearRegression> *>(func.get());
auto ML_function_Logistic = typeid_cast<AggregateFunctionMLMethod<LinearModelData, NameLogisticRegression> *>(func.get());
if (ML_function_Linear)
@ -132,15 +123,16 @@ MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const Col
size_t row_num = 0;
for (auto val : data)
{
ML_function_Linear->predictResultInto(val, *res, block, row_num, arguments);
ML_function_Linear->predictResultInto(val, *res, block, arguments);
++row_num;
}
} else if (ML_function_Logistic)
{
size_t row_num = 0;
for (auto val : data)
{
ML_function_Logistic->predictResultInto(val, *res, block, row_num, arguments);
ML_function_Logistic->predictResultInto(val, *res, block, arguments);
++row_num;
}
} else

View File

@ -59,9 +59,20 @@ namespace DB
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override
{
// const ColumnAggregateFunction * column_with_states =
// typeid_cast<const ColumnAggregateFunction *>(static_cast<const ColumnConst *>(&*block.getByPosition(arguments.at(0)).column));
// завести МЛ_аггрункции как отдельный класс, чтобы тут сразу это проверять, а не делать это внутри predictValues()
const ColumnAggregateFunction * column_with_states
= typeid_cast<const ColumnAggregateFunction *>(&*block.getByPosition(arguments.at(0)).column);
// const ColumnAggregateFunction * column_with_states
// = typeid_cast<const ColumnAggregateFunction *>(&*block.getByPosition(arguments.at(0)).column);
const ColumnConst * column_with_states
= typeid_cast<const ColumnConst *>(&*block.getByPosition(arguments.at(0)).column);
if (!column_with_states)
throw Exception("Illegal column " + block.getByPosition(arguments.at(0)).column->getName()
+ " of first argument of function "
@ -83,7 +94,8 @@ namespace DB
}
block.getByPosition(result).column = column_with_states->predictValues(predict_features);
*/
block.getByPosition(result).column = column_with_states->predictValues(block, arguments);
block.getByPosition(result).column =
typeid_cast<const ColumnAggregateFunction *>(&*column_with_states->getDataColumnPtr())->predictValues(block, arguments);
}
};

View File

@ -0,0 +1,300 @@
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804
-67.00423606399804

File diff suppressed because one or more lines are too long

View File

@ -1,300 +0,0 @@
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556
-66.99407003753556

File diff suppressed because one or more lines are too long