mirror of https://github.com/ClickHouse/ClickHouse.git (synced 2024-11-28 02:21:59 +00:00)
small test code
This commit is contained in:
parent 7cea77b8c0
commit bfccafef49
@@ -60,6 +60,7 @@ public:
        return batch_gradient;
    }
    virtual Float64 predict(const std::vector<Float64> & predict_feature, const std::vector<Float64> & weights, Float64 bias) const = 0;
    virtual void predict_for_all(ColumnVector<Float64>::Container & container, Block & block, const ColumnNumbers & arguments, const std::vector<Float64> & weights, Float64 bias) const = 0;

protected:
    std::vector<Float64> batch_gradient; // gradient for bias lies in batch_gradient[batch_gradient.size() - 1]
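For context, here is a minimal sketch of the quantity a concrete IGradientComputer::predict is expected to return for a linear model: the bias plus a weighted sum of the features (the logistic variant later in this diff additionally applies a sigmoid). The free function and the name predict_linear are illustrative assumptions, not part of this commit.

#include <vector>

using Float64 = double; // stand-in for ClickHouse's Float64 alias

// Illustrative sketch only: bias plus the dot product of weights and features.
Float64 predict_linear(const std::vector<Float64> & predict_feature, const std::vector<Float64> & weights, Float64 bias)
{
    Float64 res = bias;
    for (size_t i = 0; i < weights.size() && i < predict_feature.size(); ++i)
        res += weights[i] * predict_feature[i]; // weighted sum of the features
    return res;
}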
@@ -85,7 +86,7 @@ public:
        batch_gradient[weights.size()] += derivative;
        for (size_t i = 0; i < weights.size(); ++i)
        {
            batch_gradient[i] += derivative * static_cast<const ColumnVector<Float64> &>(*columns[i + 1]).getData()[row_num];;
            batch_gradient[i] += derivative * static_cast<const ColumnVector<Float64> &>(*columns[i + 1]).getData()[row_num];
        }
    }
    Float64 predict(const std::vector<Float64> & predict_feature, const std::vector<Float64> & weights, Float64 bias) const override
@@ -106,6 +107,36 @@ public:
        return res;
    }
    void predict_for_all(ColumnVector<Float64>::Container & container, Block & block, const ColumnNumbers & arguments, const std::vector<Float64> & weights, Float64 bias) const override
    {
        size_t rows_num = block.rows();
        std::cout << "\n\nROWS NUM: " << rows_num << "\n\n";
        std::vector<Float64> results(rows_num, bias);


        for (size_t i = 1; i < arguments.size(); ++i)
        {
            ColumnPtr cur_col = block.getByPosition(arguments[i]).column;
            for (size_t row_num = 0; row_num != rows_num; ++row_num)
            {
                const auto &element = (*cur_col)[row_num];
                if (element.getType() != Field::Types::Float64)
                    throw Exception("Prediction arguments must be values of type Float",
                                    ErrorCodes::BAD_ARGUMENTS);

                results[row_num] += weights[row_num] * element.get<Float64>();
                // predict_features[i - 1] = element.get<Float64>();
            }
        }

        for (size_t row_num = 0; row_num != rows_num; ++row_num)
        {
            container.emplace_back(results[row_num]);
        }
        // column.getData().push_back(this->data(place).predict(predict_features));
        // column.getData().push_back(this->data(place).predict_for_all());
        // this->data(place).predict_for_all(column.getData(), block, arguments);
    }
};
class LogisticRegression : public IGradientComputer
{
@@ -149,6 +180,14 @@ public:
        res = 1 / (1 + exp(-res));
        return res;
    }
    void predict_for_all(ColumnVector<Float64>::Container & container, Block & block, const ColumnNumbers & arguments, const std::vector<Float64> & weights, Float64 bias) const override
    {
        std::ignore = container;
        std::ignore = block;
        std::ignore = arguments;
        std::ignore = weights;
        std::ignore = bias;
    }
};

class IWeightsUpdater
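The logistic predict_for_all above is still a stub (every parameter is discarded with std::ignore). Purely as an illustration of where this is headed, below is a minimal sketch of a batch logistic prediction: the same bias-plus-weighted-sum accumulation as the linear version, passed through the sigmoid used in LogisticRegression::predict. The free function, its name, and the plain-vector inputs are assumptions, not part of this commit.

#include <cmath>
#include <vector>

using Float64 = double; // stand-in for ClickHouse's Float64 alias

// Illustrative sketch only: for each row, sigmoid(bias + sum_i weights[i] * feature_i).
std::vector<Float64> predict_logistic_batch(
    const std::vector<std::vector<Float64>> & feature_columns, // one vector per feature, all of length rows_num
    const std::vector<Float64> & weights,
    Float64 bias,
    size_t rows_num)
{
    std::vector<Float64> results(rows_num, bias);
    for (size_t i = 0; i < feature_columns.size() && i < weights.size(); ++i)
        for (size_t row = 0; row < rows_num; ++row)
            results[row] += weights[i] * feature_columns[i][row]; // linear part

    for (size_t row = 0; row < rows_num; ++row)
        results[row] = 1 / (1 + std::exp(-results[row]));         // sigmoid
    return results;
}

Note that in this sketch the weight index follows the feature column, the same way the per-row compute() path earlier in the diff indexes batch_gradient by the feature index i.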
@@ -191,6 +230,7 @@ public:
        }
        bias += hk_[weights.size()] / cur_batch;
    }
    /// virtual?
    virtual void merge(const std::shared_ptr<IWeightsUpdater> rhs, Float64 frac, Float64 rhs_frac) override {
        auto momentum_rhs = std::dynamic_pointer_cast<Momentum>(rhs);
        for (size_t i = 0; i < hk_.size(); ++i)
@@ -199,9 +239,10 @@ public:
        }
    }

    Float64 alpha_{0.1};
    std::vector<Float64> hk_;
    Float64 alpha_{0.1};
    std::vector<Float64> hk_;
};

class LinearModelData
{
public:
@@ -223,7 +264,6 @@ public:
    }



    void add(Float64 target, const IColumn ** columns, size_t row_num)
    {
        gradient_computer->compute(weights, bias, learning_rate, target, columns, row_num);
@@ -285,6 +325,10 @@ public:

        return gradient_computer->predict(predict_feature, weights, bias);
    }
    void predict_for_all(ColumnVector<Float64>::Container & container, Block & block, const ColumnNumbers & arguments) const
    {
        gradient_computer->predict_for_all(container, block, arguments, weights, bias);
    }

private:
    std::vector<Float64> weights;
@@ -366,23 +410,28 @@ public:

    void predictResultInto(ConstAggregateDataPtr place, IColumn & to, Block & block, size_t row_num, const ColumnNumbers & arguments) const
    {
        std::ignore = row_num;
        std::cout << "\n\n IM CALLED \n\n";

        if (arguments.size() != param_num + 1)
            throw Exception("Predict got incorrect number of arguments. Got: " + std::to_string(arguments.size()) + ". Required: " + std::to_string(param_num + 1),
                            ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

        auto &column = dynamic_cast<ColumnVector<Float64> &>(to);

        std::vector<Float64> predict_features(arguments.size() - 1);
        for (size_t i = 1; i < arguments.size(); ++i)
        {
            const auto& element = (*block.getByPosition(arguments[i]).column)[row_num];
            if (element.getType() != Field::Types::Float64)
                throw Exception("Prediction arguments must be values of type Float",
                                ErrorCodes::BAD_ARGUMENTS);

            predict_features[i - 1] = element.get<Float64>();
        }
        column.getData().push_back(this->data(place).predict(predict_features));
        // std::vector<Float64> predict_features(arguments.size() - 1);
        // for (size_t i = 1; i < arguments.size(); ++i)
        // {
        //     const auto& element = (*block.getByPosition(arguments[i]).column)[row_num];
        //     if (element.getType() != Field::Types::Float64)
        //         throw Exception("Prediction arguments must be values of type Float",
        //                         ErrorCodes::BAD_ARGUMENTS);
        //
        ////     predict_features[i - 1] = element.get<Float64>();
        // }
        // column.getData().push_back(this->data(place).predict(predict_features));
        // column.getData().push_back(this->data(place).predict_for_all());
        this->data(place).predict_for_all(column.getData(), block, arguments);
    }

    void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
@@ -33,9 +33,9 @@ void ColumnAggregateFunction::addArena(ArenaPtr arena_)
    arenas.push_back(arena_);
}

bool ColumnAggregateFunction::convertion(MutableColumnPtr* res_) const
bool ColumnAggregateFunction::convertion(MutableColumnPtr *res_) const
{
    if (const AggregateFunctionState * function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
    if (const AggregateFunctionState *function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
    {
        auto res = createView();
        res->set(function_state->getNestedFunction());
@@ -122,16 +122,25 @@ MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const Col
        return res;
    }

    std::cout << "\n\nHELLO: " << data.size() << "\n\n";
    /// In my tests, something with data.size() == 0 reaches this function twice, but it essentially does nothing in the lines below
    if (1 != data.size())
        return res;

    // auto ML_function = typeid_cast<AggregateFunctionMLMethod<LinearRegressionData, NameLinearRegression> *>(func.get());
    auto ML_function_Linear = typeid_cast<AggregateFunctionMLMethod<LinearModelData, NameLinearRegression> *>(func.get());
    auto ML_function_Logistic = typeid_cast<AggregateFunctionMLMethod<LinearModelData, NameLogisticRegression> *>(func.get());
    if (ML_function_Linear)
    {
        size_t row_num = 0;
        std::cout << "\n\nIM HERE\n" << data.size() << "\n";
        for (auto val : data) {
            std::cout << "HIII\n\n";
            ML_function_Linear->predictResultInto(val, *res, block, row_num, arguments);
            ++row_num;
        }
        // ML_function_Linear->predictResultInto(data[0], *res, block, row_num, arguments);

    } else if (ML_function_Logistic)
    {
        size_t row_num = 0;
@@ -59,9 +59,20 @@ namespace DB

    void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override
    {
        // const ColumnAggregateFunction * column_with_states =
        //     typeid_cast<const ColumnAggregateFunction *>(static_cast<const ColumnConst *>(&*block.getByPosition(arguments.at(0)).column));


        // std::cout << "\n\n\nHELOOOOO\n\n\n";
        // introduce the ML aggregate functions as a separate class, so that this can be checked right here instead of inside predictValues()
        const ColumnAggregateFunction * column_with_states
            = typeid_cast<const ColumnAggregateFunction *>(&*block.getByPosition(arguments.at(0)).column);
        // std::cout << "\n\n\nHELOOOOO 2\n\n\n";

        // const ColumnConst * column_with_states
        //     = typeid_cast<const ColumnConst *>(&*block.getByPosition(arguments.at(0)).column);


        if (!column_with_states)
            throw Exception("Illegal column " + block.getByPosition(arguments.at(0)).column->getName()
                            + " of first argument of function "
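The comment in the hunk above proposes factoring the ML aggregate functions into their own class so that executeImpl can validate the first argument up front instead of typeid_cast-ing to each concrete AggregateFunctionMLMethod instantiation inside predictValues(). A minimal sketch of what such an interface could look like follows; the name IMLAggregateFunction, its single virtual method, and the forward-declared stand-ins for ClickHouse types are all hypothetical, not something this commit introduces.

// Hypothetical sketch only (not in the commit): a common base for ML
// aggregate functions, so callers can detect "predictable" states with one
// cast and dispatch without knowing the concrete model type.
class IColumn;
class Block;
class ColumnNumbers;                       // stand-ins for the real ClickHouse types
using ConstAggregateDataPtr = const char *;

class IMLAggregateFunction
{
public:
    virtual ~IMLAggregateFunction() = default;

    /// Append predictions for the argument columns of `block` into `to`,
    /// using the trained state `place`.
    virtual void predictResultInto(ConstAggregateDataPtr place, IColumn & to,
                                   Block & block, size_t row_num,
                                   const ColumnNumbers & arguments) const = 0;
};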
@@ -41,7 +41,17 @@ create table test.model engine = Memory as select LinearRegressionState(0.1, 5,
-- -- select multiply(param1, param2) from test.defaults;
--
-- -- select evalLinReg(LinRegState(0.01)(target, param1, param2), 20.0, 40.0) from test.defaults;
select evalMLMethod(state, predict1, predict2) from test.model cross join test.defaults;

-- select evalMLMethod(state, predict1, predict2) from test.model cross join test.defaults;


-- select evalMLMethod(state, predict1, predict2) from test.model cross join test.defaults;

-- select evalMLMethod(state, 0.1, 0.2) from test.model;

-- with (select state from test.model) as model select evalMLMethod(model, predict1, predict2) from test.defaults;


-- -- select evalLinReg(state, predict1, predict2) from test.model inner join (select * from test.tests) using state;
-- -- select evalLinReg(state1, predict1, predict2) from test.tests;
--
File diff suppressed because one or more lines are too long