small test code

This commit is contained in:
alexander kozhikhov 2019-02-11 00:16:16 +03:00
parent 7cea77b8c0
commit bfccafef49
5 changed files with 110 additions and 19 deletions

View File

@ -60,6 +60,7 @@ public:
return batch_gradient;
}
virtual Float64 predict(const std::vector<Float64> & predict_feature, const std::vector<Float64> & weights, Float64 bias) const = 0;
virtual void predict_for_all(ColumnVector<Float64>::Container & container, Block & block, const ColumnNumbers & arguments, const std::vector<Float64> & weights, Float64 bias) const = 0;
protected:
std::vector<Float64> batch_gradient; // gradient for bias lies in batch_gradient[batch_gradient.size() - 1]
@ -85,7 +86,7 @@ public:
batch_gradient[weights.size()] += derivative;
for (size_t i = 0; i < weights.size(); ++i)
{
batch_gradient[i] += derivative * static_cast<const ColumnVector<Float64> &>(*columns[i + 1]).getData()[row_num];;
batch_gradient[i] += derivative * static_cast<const ColumnVector<Float64> &>(*columns[i + 1]).getData()[row_num];
}
}
Float64 predict(const std::vector<Float64> & predict_feature, const std::vector<Float64> & weights, Float64 bias) const override
@ -106,6 +107,36 @@ public:
return res;
}
/// Computes predictions for every row of the block at once and appends one
/// result per row to `container`.
/// arguments[0] is the aggregate-state column; arguments[1..] are the feature
/// columns, matching weights[0..weights.size()-1] one-to-one.
/// Throws BAD_ARGUMENTS if any feature value is not a Float64.
void predict_for_all(ColumnVector<Float64>::Container & container, Block & block, const ColumnNumbers & arguments, const std::vector<Float64> & weights, Float64 bias) const override
{
    size_t rows_num = block.rows();
    /// Start every row's accumulator at the bias term.
    std::vector<Float64> results(rows_num, bias);
    for (size_t i = 1; i < arguments.size(); ++i)
    {
        ColumnPtr cur_col = block.getByPosition(arguments[i]).column;
        for (size_t row_num = 0; row_num != rows_num; ++row_num)
        {
            const auto & element = (*cur_col)[row_num];
            if (element.getType() != Field::Types::Float64)
                throw Exception("Prediction arguments must be values of type Float",
                                ErrorCodes::BAD_ARGUMENTS);

            /// The weight is indexed by feature (i - 1), not by row:
            /// `weights` holds one coefficient per feature column.
            results[row_num] += weights[i - 1] * element.get<Float64>();
        }
    }
    for (size_t row_num = 0; row_num != rows_num; ++row_num)
    {
        container.emplace_back(results[row_num]);
    }
}
};
class LogisticRegression : public IGradientComputer
{
@ -149,6 +180,14 @@ public:
res = 1 / (1 + exp(-res));
return res;
}
/// Batch prediction for logistic regression: intentionally a no-op stub.
/// Parameters are suppressed until a real implementation lands.
void predict_for_all(ColumnVector<Float64>::Container & container, Block & block, const ColumnNumbers & arguments, const std::vector<Float64> & weights, Float64 bias) const override
{
    (void)container;
    (void)block;
    (void)arguments;
    (void)weights;
    (void)bias;
}
};
class IWeightsUpdater
@ -191,6 +230,7 @@ public:
}
bias += hk_[weights.size()] / cur_batch;
}
/// virtual?
virtual void merge(const std::shared_ptr<IWeightsUpdater> rhs, Float64 frac, Float64 rhs_frac) override {
auto momentum_rhs = std::dynamic_pointer_cast<Momentum>(rhs);
for (size_t i = 0; i < hk_.size(); ++i)
@ -199,9 +239,10 @@ public:
}
}
Float64 alpha_{0.1};
std::vector<Float64> hk_;
Float64 alpha_{0.1};
std::vector<Float64> hk_;
};
class LinearModelData
{
public:
@ -223,7 +264,6 @@ public:
}
void add(Float64 target, const IColumn ** columns, size_t row_num)
{
gradient_computer->compute(weights, bias, learning_rate, target, columns, row_num);
@ -285,6 +325,10 @@ public:
return gradient_computer->predict(predict_feature, weights, bias);
}
/// Delegates batch prediction to the configured gradient computer,
/// passing along the model's trained weights and bias.
void predict_for_all(ColumnVector<Float64>::Container & container, Block & block, const ColumnNumbers & arguments) const
{
gradient_computer->predict_for_all(container, block, arguments, weights, bias);
}
private:
std::vector<Float64> weights;
@ -366,23 +410,28 @@ public:
/// Fills `to` with one prediction per row of `block`, using the trained model
/// state stored at `place`.
/// arguments[0] is the aggregate-state column; arguments[1..param_num] are the
/// feature columns.
/// Throws NUMBER_OF_ARGUMENTS_DOESNT_MATCH when the caller passed the wrong
/// number of columns.
void predictResultInto(ConstAggregateDataPtr place, IColumn & to, Block & block, size_t row_num, const ColumnNumbers & arguments) const
{
    /// Predictions are produced for the whole block at once below, so the
    /// per-row index is not needed here.
    std::ignore = row_num;

    if (arguments.size() != param_num + 1)
        throw Exception("Predict got incorrect number of arguments. Got: " + std::to_string(arguments.size()) + ". Required: " + std::to_string(param_num + 1),
                        ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

    auto & column = dynamic_cast<ColumnVector<Float64> &>(to);
    this->data(place).predict_for_all(column.getData(), block, arguments);
}
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override

View File

@ -33,9 +33,9 @@ void ColumnAggregateFunction::addArena(ArenaPtr arena_)
arenas.push_back(arena_);
}
bool ColumnAggregateFunction::convertion(MutableColumnPtr* res_) const
bool ColumnAggregateFunction::convertion(MutableColumnPtr *res_) const
{
if (const AggregateFunctionState * function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
if (const AggregateFunctionState *function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
{
auto res = createView();
res->set(function_state->getNestedFunction());
@ -122,16 +122,25 @@ MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const Col
return res;
}
std::cout << "\n\nHELLO: " << data.size() << "\n\n";
/// NOTE(review): in my tests this function is entered twice with something whose data.size() == 0, but it effectively does nothing in the lines below.
if (1 != data.size())
return res;
// auto ML_function = typeid_cast<AggregateFunctionMLMethod<LinearRegressionData, NameLinearRegression> *>(func.get());
auto ML_function_Linear = typeid_cast<AggregateFunctionMLMethod<LinearModelData, NameLinearRegression> *>(func.get());
auto ML_function_Logistic = typeid_cast<AggregateFunctionMLMethod<LinearModelData, NameLogisticRegression> *>(func.get());
if (ML_function_Linear)
{
size_t row_num = 0;
std::cout << "\n\nIM HERE\n" << data.size() << "\n";
for (auto val : data) {
std::cout << "HIII\n\n";
ML_function_Linear->predictResultInto(val, *res, block, row_num, arguments);
++row_num;
}
// ML_function_Linear->predictResultInto(data[0], *res, block, row_num, arguments);
} else if (ML_function_Logistic)
{
size_t row_num = 0;

View File

@ -59,9 +59,20 @@ namespace DB
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override
{
// const ColumnAggregateFunction * column_with_states =
// typeid_cast<const ColumnAggregateFunction *>(static_cast<const ColumnConst *>(&*block.getByPosition(arguments.at(0)).column));
// std::cout << "\n\n\nHELOOOOO\n\n\n";
// TODO: make the ML aggregate functions a separate class so this can be checked right here instead of inside predictValues()
const ColumnAggregateFunction * column_with_states
= typeid_cast<const ColumnAggregateFunction *>(&*block.getByPosition(arguments.at(0)).column);
// std::cout << "\n\n\nHELOOOOO 2\n\n\n";
// const ColumnConst * column_with_states
// = typeid_cast<const ColumnConst *>(&*block.getByPosition(arguments.at(0)).column);
if (!column_with_states)
throw Exception("Illegal column " + block.getByPosition(arguments.at(0)).column->getName()
+ " of first argument of function "

View File

@ -41,7 +41,17 @@ create table test.model engine = Memory as select LinearRegressionState(0.1, 5,
-- -- select multiply(param1, param2) from test.defaults;
--
-- -- select evalLinReg(LinRegState(0.01)(target, param1, param2), 20.0, 40.0) from test.defaults;
select evalMLMethod(state, predict1, predict2) from test.model cross join test.defaults;
-- select evalMLMethod(state, predict1, predict2) from test.model cross join test.defaults;
-- select evalMLMethod(state, predict1, predict2) from test.model cross join test.defaults;
-- select evalMLMethod(state, 0.1, 0.2) from test.model;
-- with (select state from test.model) as model select evalMLMethod(model, predict1, predict2) from test.defaults;
-- -- select evalLinReg(state, predict1, predict2) from test.model inner join (select * from test.tests) using state;
-- -- select evalLinReg(state1, predict1, predict2) from test.tests;
--