diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.cpp b/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.cpp
index 814f60f544c..4a8d7a43172 100644
--- a/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.cpp
+++ b/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.cpp
@@ -42,7 +42,6 @@ AggregateFunctionPtr createAggregateFunctionMLMethod(
 }
 
-    /// Gradient_Computer for LinearRegression has LinearRegression gradient computer
     if (std::is_same<Method, FuncLinearRegression>::value)
     {
         gc = std::make_shared<LinearRegression>(argument_types.size());
@@ -64,11 +63,12 @@ AggregateFunctionPtr createAggregateFunctionMLMethod(
         } else if (applyVisitor(FieldVisitorConvertToNumber<Float64>(), parameters[2]) == Float64{3.0})
         {
             wu = std::make_shared<Nesterov>();
+        } else if (applyVisitor(FieldVisitorConvertToNumber<Float64>(), parameters[2]) == Float64{4.0})
         {
-            wu = std::make_shared<Adam>();
-        } else
-        {
+            /// Adam should be here
+            wu = std::make_shared<Nesterov>();
+        } else {
             throw Exception("Such weights updater is not implemented yet", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
         }
     } else
diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.h b/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.h
index d815313a67c..c295679fab4 100644
--- a/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.h
+++ b/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.h
@@ -29,52 +29,66 @@ namespace ErrorCodes
     extern const int BAD_ARGUMENTS;
 }
 
+/** IGradientComputer class computes the gradient according to its loss function.
+  * The mini-batch gradient buffer is now owned by the caller and passed into compute().
+  */
 class IGradientComputer
 {
 public:
-    IGradientComputer(UInt32 sz)
-    : batch_gradient(sz, 0)
+    IGradientComputer(UInt32)
     {}
+
     virtual ~IGradientComputer() = default;
 
-    virtual void compute(const std::vector<Float64> & weights, Float64 bias, Float64 learning_rate,
-                         Float64 target, const IColumn ** columns, size_t row_num) = 0;
+    /// Adds the gradient computed at the point (weights, bias) with the corresponding loss function to batch_gradient
+    virtual void compute(std::vector<Float64> * batch_gradient, const std::vector<Float64> & weights, Float64 bias,
+                         Float64 learning_rate, Float64 target, const IColumn ** columns, size_t row_num) = 0;
 
-    void reset()
-    {
-        batch_gradient.assign(batch_gradient.size(), 0);
-    }
-
-    void write(WriteBuffer & buf) const
-    {
-        writeBinary(batch_gradient, buf);
-    }
-
-    void read(ReadBuffer & buf)
-    {
-        readBinary(batch_gradient, buf);
-    }
-
-    const std::vector<Float64> & get() const
-    {
-        return batch_gradient;
-    }
-
-    virtual Float64 predict(const std::vector<Float64> & predict_feature, const std::vector<Float64> & weights, Float64 bias) const = 0;
-    virtual void predict_for_all(ColumnVector<Float64>::Container & container, Block & block, const ColumnNumbers & arguments, const std::vector<Float64> & weights, Float64 bias) const = 0;
+    virtual Float64 predict(const std::vector<Float64> & predict_feature,
+                            const std::vector<Float64> & weights,
+                            Float64 bias) const = 0;
 
-protected:
-    std::vector<Float64> batch_gradient;  // gradient for bias lies in batch_gradient[batch_gradient.size() - 1]
+    /// predict_for_all should now be used instead of predict
+    virtual void predict_for_all(ColumnVector<Float64>::Container & container,
+                                 Block & block, const ColumnNumbers & arguments,
+                                 const std::vector<Float64> & weights,
+                                 Float64 bias) const = 0;
 };
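The refactoring above moves the mini-batch gradient out of the gradient computer and into its caller (LinearModelData further down). A minimal sketch of the intended call sequence, assuming the interfaces in this header; it is an illustration, not part of the patch:

    // The caller owns the batch gradient; the last entry holds the bias component.
    std::vector<Float64> gradient_batch(weights.size() + 1, 0.0);
    // Per row: the weights updater routes to IGradientComputer::compute
    // (Nesterov first shifts the evaluation point, see its add_to_batch below).
    weights_updater->add_to_batch(&gradient_batch, gradient_computer,
                                  weights, bias, learning_rate, target, columns, row_num);
    // Per batch: make one step and reset the buffer.
    weights_updater->update(batch_size, weights, bias, gradient_batch);
    gradient_batch.assign(gradient_batch.size(), 0.0);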
 
 
 class LinearRegression : public IGradientComputer
 {
 public:
     LinearRegression(UInt32 sz)
     : IGradientComputer(sz)
     {}
 
-    void compute(const std::vector<Float64> & weights, Float64 bias, Float64 learning_rate,
-                 Float64 target, const IColumn ** columns, size_t row_num) override
+    void compute(std::vector<Float64> * batch_gradient, const std::vector<Float64> & weights, Float64 bias,
+                 Float64 learning_rate, Float64 target, const IColumn ** columns, size_t row_num) override
     {
         Float64 derivative = (target - bias);
         for (size_t i = 0; i < weights.size(); ++i)
@@ -83,20 +97,19 @@ public:
         }
         derivative *= (2 * learning_rate);
 
-        batch_gradient[weights.size()] += derivative;
+        (*batch_gradient)[weights.size()] += derivative;
         for (size_t i = 0; i < weights.size(); ++i)
         {
-            batch_gradient[i] += derivative * static_cast<const ColumnVector<Float64> &>(*columns[i]).getData()[row_num];;
+            (*batch_gradient)[i] +=
+                derivative * static_cast<const ColumnVector<Float64> &>(*columns[i]).getData()[row_num];
         }
     }
 
     Float64 predict(const std::vector<Float64> & predict_feature,
                     const std::vector<Float64> & weights, Float64 bias) const override
     {
         /// We do not update the weights during prediction, because that could slow the prediction down.
         /// They could, however, be updated on every merge, regardless of how many elements are in the batch.
-// if (cur_batch)
-// {
-// update_weights();
-// }
         Float64 res{0.0};
         for (size_t i = 0; i < predict_feature.size(); ++i)
@@ -107,7 +120,11 @@ public:
         return res;
     }
 
-    void predict_for_all(ColumnVector<Float64>::Container & container, Block & block, const ColumnNumbers & arguments, const std::vector<Float64> & weights, Float64 bias) const override
+    void predict_for_all(ColumnVector<Float64>::Container & container,
+                         Block & block,
+                         const ColumnNumbers & arguments,
+                         const std::vector<Float64> & weights, Float64 bias) const override
     {
         size_t rows_num = block.rows();
         std::vector<Float64> results(rows_num, bias);
@@ -134,6 +151,7 @@ public:
     }
 };
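For reference (not part of the patch): LinearRegression::compute is plain least-squares gradient descent with the learning rate folded into the accumulated values. For the squared loss on one row,

\[ \ell(w, b) = (y - \hat y)^2, \qquad \hat y = \sum_i w_i x_i + b, \]
\[ \frac{\partial \ell}{\partial w_i} = -2\,(y - \hat y)\,x_i, \qquad \frac{\partial \ell}{\partial b} = -2\,(y - \hat y), \]

so the code's derivative = 2 * learning_rate * (y - ŷ) equals \(-\eta\, \partial\ell/\partial b\), and batch_gradient accumulates \(-\eta\, \partial\ell/\partial\theta\) per row; StochasticGradientDescent::update later applies \(\theta \leftarrow \theta + \frac{1}{m}\sum(\cdot)\), i.e. averaged gradient descent.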
 
 
 class LogisticRegression : public IGradientComputer
 {
 public:
@@ -141,31 +159,32 @@ public:
     : IGradientComputer(sz)
     {}
 
-    void compute(const std::vector<Float64> & weights, Float64 bias, Float64 learning_rate,
-                 Float64 target, const IColumn ** columns, size_t row_num) override
+    void compute(std::vector<Float64> * batch_gradient, const std::vector<Float64> & weights, Float64 bias,
+                 Float64 learning_rate, Float64 target, const IColumn ** columns, size_t row_num) override
     {
         Float64 derivative = bias;
         for (size_t i = 0; i < weights.size(); ++i)
         {
-            derivative += weights[i] * static_cast<const ColumnVector<Float64> &>(*columns[i]).getData()[row_num];;
+            derivative += weights[i] * static_cast<const ColumnVector<Float64> &>(*columns[i]).getData()[row_num];
         }
         derivative *= target;
         derivative = exp(derivative);
 
-        batch_gradient[weights.size()] += learning_rate * target / (derivative + 1);;
+        (*batch_gradient)[weights.size()] += learning_rate * target / (derivative + 1);
         for (size_t i = 0; i < weights.size(); ++i)
         {
-            batch_gradient[i] += learning_rate * target / (derivative + 1) * static_cast<const ColumnVector<Float64> &>(*columns[i]).getData()[row_num];
+            (*batch_gradient)[i] +=
+                learning_rate * target *
+                static_cast<const ColumnVector<Float64> &>(*columns[i]).getData()[row_num]
+                / (derivative + 1);
         }
     }
 
     Float64 predict(const std::vector<Float64> & predict_feature,
                     const std::vector<Float64> & weights, Float64 bias) const override
     {
         /// We do not update the weights during prediction, because that could slow the prediction down.
         /// They could, however, be updated on every merge, regardless of how many elements are in the batch.
-// if (cur_batch)
-// {
-// update_weights();
-// }
         Float64 res{0.0};
         for (size_t i = 0; i < predict_feature.size(); ++i)
@@ -176,7 +195,11 @@ public:
         res = 1 / (1 + exp(-res));
         return res;
     }
 
-    void predict_for_all(ColumnVector<Float64>::Container & container, Block & block, const ColumnNumbers & arguments, const std::vector<Float64> & weights, Float64 bias) const override
+    void predict_for_all(ColumnVector<Float64>::Container & container,
+                         Block & block,
+                         const ColumnNumbers & arguments,
+                         const std::vector<Float64> & weights, Float64 bias) const override
     {
         size_t rows_num = block.rows();
         std::vector<Float64> results(rows_num, bias);
 
@@ -186,7 +209,6 @@ public:
             ColumnPtr cur_col = block.getByPosition(arguments[i]).column;
             for (size_t row_num = 0; row_num != rows_num; ++row_num)
             {
-
                 const auto & element = (*cur_col)[row_num];
                 if (element.getType() != Field::Types::Float64)
                     throw Exception("Prediction arguments must be values of type Float",
@@ -197,124 +219,226 @@ public:
         }
         for (size_t row_num = 0; row_num != rows_num; ++row_num)
         {
-            results[row_num] = 1 / (1 + exp(-results[row_num]));
-            container.emplace_back(results[row_num]);
+            container.emplace_back(1 / (1 + exp(-results[row_num])));
         }
     }
 };
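Again for reference: LogisticRegression::compute matches the gradient of the logistic loss with labels \(y \in \{-1, +1\}\),

\[ \ell(w, b) = \log\left(1 + e^{-y f}\right), \qquad f = \sum_i w_i x_i + b, \]
\[ \frac{\partial \ell}{\partial w_i} = \frac{-y\, x_i}{1 + e^{y f}}, \qquad \frac{\partial \ell}{\partial b} = \frac{-y}{1 + e^{y f}}, \]

which is what the code accumulates, scaled by learning_rate: the local variable derivative holds \(e^{y f}\), and each component adds learning_rate * target * x_i / (derivative + 1). predict then returns the probability \(\sigma(f) = 1 / (1 + e^{-f})\).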
 
 
+/** IWeightsUpdater class defines the way to update the current state,
+  * using the GradientComputer on each iteration.
+  */
 class IWeightsUpdater
 {
 public:
     virtual ~IWeightsUpdater() = default;
 
-    virtual void update(UInt32 batch_size, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & gradient) = 0;
-    virtual void merge(const IWeightsUpdater &, Float64, Float64) {}
-    virtual std::vector<Float64> get_update(UInt32 sz, UInt32)
-    {
-        return std::vector<Float64>(sz, 0.0);
-    }
+    virtual void add_to_batch(std::vector<Float64> * batch_gradient, std::shared_ptr<IGradientComputer> gc,
+                              const std::vector<Float64> & weights, Float64 bias,
+                              Float64 learning_rate, Float64 target, const IColumn ** columns, size_t row_num)
+    {
+        gc->compute(batch_gradient, weights, bias, learning_rate, target, columns, row_num);
+    }
+
+    virtual void update(UInt32 batch_size,
+                        std::vector<Float64> & weights, Float64 & bias,
+                        const std::vector<Float64> & gradient) = 0;
+
+    virtual void merge(const IWeightsUpdater &, Float64, Float64)
+    {}
+
+    virtual void write(WriteBuffer &) const
+    {}
+
+    virtual void read(ReadBuffer &)
+    {}
 };
 
 
 class StochasticGradientDescent : public IWeightsUpdater
 {
 public:
-    void update(UInt32 batch_size, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & batch_gradient) override {
+    void update(UInt32 batch_size,
+                std::vector<Float64> & weights, Float64 & bias,
+                const std::vector<Float64> & batch_gradient) override
+    {
         /// batch_size is already checked to be greater than 0
-
         for (size_t i = 0; i < weights.size(); ++i)
         {
             weights[i] += batch_gradient[i] / batch_size;
         }
         bias += batch_gradient[weights.size()] / batch_size;
     }
 };
 
 
 class Momentum : public IWeightsUpdater
 {
 public:
-    Momentum() {}
-    Momentum (Float64 alpha) : alpha_(alpha) {}
-    void update(UInt32 batch_size, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & batch_gradient) override {
-        /// batch_size is already checked to be greater than 0
-        if (accumulated_gradient.size() == 0)
+    Momentum()
+    {}
+
+    Momentum(Float64 alpha) : alpha_(alpha)
+    {}
+
+    void update(UInt32 batch_size,
+                std::vector<Float64> & weights, Float64 & bias,
+                const std::vector<Float64> & batch_gradient) override
+    {
+        /// batch_size is already checked to be greater than 0
+        if (accumulated_gradient.empty())
         {
             accumulated_gradient.resize(batch_gradient.size(), Float64{0.0});
         }
+
         for (size_t i = 0; i < batch_gradient.size(); ++i)
         {
-            accumulated_gradient[i] = accumulated_gradient[i] * alpha_ + batch_gradient[i];
+            accumulated_gradient[i] = accumulated_gradient[i] * alpha_ + batch_gradient[i] / batch_size;
         }
         for (size_t i = 0; i < weights.size(); ++i)
        {
-            weights[i] += accumulated_gradient[i] / batch_size;
+            weights[i] += accumulated_gradient[i];
         }
-        bias += accumulated_gradient[weights.size()] / batch_size;
+        bias += accumulated_gradient[weights.size()];
     }
 
-    virtual void merge(const IWeightsUpdater & rhs, Float64 frac, Float64 rhs_frac) override {
+    virtual void merge(const IWeightsUpdater & rhs, Float64 frac, Float64 rhs_frac) override
+    {
         auto & momentum_rhs = static_cast<const Momentum &>(rhs);
         for (size_t i = 0; i < accumulated_gradient.size(); ++i)
         {
-            accumulated_gradient[i] = accumulated_gradient[i] * frac + momentum_rhs.accumulated_gradient[i] * rhs_frac;
+            accumulated_gradient[i] = accumulated_gradient[i] * frac +
+                                      momentum_rhs.accumulated_gradient[i] * rhs_frac;
        }
     }
 
+    void write(WriteBuffer & buf) const override
+    {
+        writeBinary(accumulated_gradient, buf);
+    }
+
+    void read(ReadBuffer & buf) override
+    {
+        readBinary(accumulated_gradient, buf);
+    }
+
 private:
     Float64 alpha_{0.1};
     std::vector<Float64> accumulated_gradient;
 };
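Momentum keeps a velocity vector on top of plain SGD. Note that this revision moves the division by batch_size into the accumulation step, so the stored velocity is already batch-normalized; that keeps merge() and Nesterov's look-ahead (below) consistent. In standard notation, with decay \(\alpha\) (alpha_ defaults to 0.1 here) and \(g_j\) the per-row values accumulated by compute() (already scaled by the learning rate and pointing in the descent direction):

\[ v_t = \alpha\, v_{t-1} + \frac{1}{m} \sum_{j=1}^{m} g_j, \qquad \theta_t = \theta_{t-1} + v_t. \]

Nesterov differs only in where the gradient is evaluated: its add_to_batch computes each \(g_j\) at the shifted point \((w + \alpha v,\; b + \alpha v_b)\) instead of \((w, b)\).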
 
 
 class Nesterov : public IWeightsUpdater
 {
 public:
-    Nesterov() {}
-    Nesterov (Float64 alpha) : alpha_(alpha) {}
-    void update(UInt32 batch_size, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & batch_gradient) override {
-        if (accumulated_gradient.size() == 0)
+    Nesterov()
+    {}
+
+    Nesterov(Float64 alpha) : alpha_(alpha)
+    {}
+
+    void add_to_batch(std::vector<Float64> * batch_gradient, std::shared_ptr<IGradientComputer> gc,
+                      const std::vector<Float64> & weights, Float64 bias,
+                      Float64 learning_rate, Float64 target, const IColumn ** columns, size_t row_num) override
+    {
+        if (accumulated_gradient.empty())
+        {
+            accumulated_gradient.resize(batch_gradient->size(), Float64{0.0});
+        }
+
+        std::vector<Float64> shifted_weights(weights.size());
+        for (size_t i = 0; i != shifted_weights.size(); ++i)
+        {
+            shifted_weights[i] = weights[i] + accumulated_gradient[i] * alpha_;
+        }
+        auto shifted_bias = bias + accumulated_gradient[weights.size()] * alpha_;
+
+        gc->compute(batch_gradient, shifted_weights, shifted_bias, learning_rate, target, columns, row_num);
+    }
+
+    void update(UInt32 batch_size,
+                std::vector<Float64> & weights, Float64 & bias,
+                const std::vector<Float64> & batch_gradient) override
+    {
+        if (accumulated_gradient.empty())
         {
             accumulated_gradient.resize(batch_gradient.size(), Float64{0.0});
         }
+
         for (size_t i = 0; i < batch_gradient.size(); ++i)
         {
-            accumulated_gradient[i] = accumulated_gradient[i] * alpha_ + batch_gradient[i];
+            accumulated_gradient[i] = accumulated_gradient[i] * alpha_ + batch_gradient[i] / batch_size;
         }
         for (size_t i = 0; i < weights.size(); ++i)
         {
-            weights[i] += accumulated_gradient[i] / batch_size;
+            weights[i] += accumulated_gradient[i];
         }
-        bias += accumulated_gradient[weights.size()] / batch_size;
-        std::cout<<"BIAS " << bias<<'\n';
+        bias += accumulated_gradient[weights.size()];
     }
 
-    virtual void merge(const IWeightsUpdater & rhs, Float64 frac, Float64 rhs_frac) override {
+    virtual void merge(const IWeightsUpdater & rhs, Float64 frac, Float64 rhs_frac) override
+    {
         auto & nesterov_rhs = static_cast<const Nesterov &>(rhs);
         for (size_t i = 0; i < accumulated_gradient.size(); ++i)
         {
-            accumulated_gradient[i] = accumulated_gradient[i] * frac + nesterov_rhs.accumulated_gradient[i] * rhs_frac;
+            accumulated_gradient[i] =
+                accumulated_gradient[i] * frac + nesterov_rhs.accumulated_gradient[i] * rhs_frac;
         }
     }
 
-    virtual std::vector<Float64> get_update(UInt32 sz, UInt32 batch_size) override {
-        if (accumulated_gradient.size() == 0)
-        {
-            accumulated_gradient.resize(sz, Float64{0.0});
-            return accumulated_gradient;
-        }
-        std::vector<Float64> delta(accumulated_gradient.size());
-        for (size_t i = 0; i < delta.size(); ++i)
-        {
-            delta[i] = accumulated_gradient[i] * alpha_ / batch_size;
-        }
-        return delta;
-    }
-    Float64 alpha_{0.1};
-    std::vector<Float64> accumulated_gradient;
+    void write(WriteBuffer & buf) const override
+    {
+        writeBinary(accumulated_gradient, buf);
+    }
+
+    void read(ReadBuffer & buf) override
+    {
+        readBinary(accumulated_gradient, buf);
+    }
+
+private:
+    Float64 alpha_{0.1};
+    std::vector<Float64> accumulated_gradient;
 };
 
 
+// TODO: re-check after the momentum logic has changed
+/*
 class Adam : public IWeightsUpdater
 {
 public:
-    Adam() {}
-    Adam (Float64 betta1, Float64 betta2) : betta1_(betta1), betta2_(betta2), betta1t_(betta1), betta2t_(betta2) {}
-    void update(UInt32 cur_batch, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & batch_gradient) override {
+    Adam()
+    {}
+
+    Adam(Float64 betta1, Float64 betta2) : betta1_(betta1), betta2_(betta2), betta1t_(betta1), betta2t_(betta2)
+    {}
+
+    void update(UInt32 cur_batch,
+                std::vector<Float64> & weights, Float64 & bias,
+                const std::vector<Float64> & batch_gradient) override
+    {
         if (mt_.size() == 0)
         {
             mt_.resize(batch_gradient.size(), Float64{0.0});
@@ -325,11 +449,13 @@ public:
         {
             mt_[i] = mt_[i] * betta1_ + (1 - betta1_) * batch_gradient[i];
             vt_[i] = vt_[i] * betta2_ + (1 - betta2_) * batch_gradient[i] * batch_gradient[i];
-            if (t < 8) {
+            if (t < 8)
+            {
                 mt_[i] = mt_[i] / (1 - betta1t_);
                 betta1t_ *= betta1_;
             }
-            if (t < 850) {
+            if (t < 850)
+            {
                 vt_[i] = vt_[i] / (1 - betta2t_);
                 betta2t_ *= betta2_;
             }
@@ -341,23 +467,33 @@ public:
         bias += (mt_[weights.size()] / (sqrt(vt_[weights.size()] + eps))) / cur_batch;
         t += 1;
     }
 
-    virtual void merge(const IWeightsUpdater & rhs, Float64 frac, Float64 rhs_frac) override {
+    virtual void merge(const IWeightsUpdater & rhs, Float64 frac, Float64 rhs_frac) override
+    {
         auto & adam_rhs = static_cast<const Adam &>(rhs);
         for (size_t i = 0; i < mt_.size(); ++i)
         {
             mt_[i] = mt_[i] * frac + adam_rhs.mt_[i] * rhs_frac;
             vt_[i] = vt_[i] * frac + adam_rhs.vt_[i] * rhs_frac;
         }
     }
 
-Float64 betta1_{0.2};
-Float64 betta2_{0.3};
-Float64 betta1t_{0.3};
-Float64 betta2t_{0.3};
-UInt32 t = 0;
-std::vector<Float64> mt_;
-std::vector<Float64> vt_;
+private:
+    Float64 betta1_{0.2};
+    Float64 betta2_{0.3};
+    Float64 betta1t_{0.3};
+    Float64 betta2t_{0.3};
+    UInt32 t = 0;
+    std::vector<Float64> mt_;
+    std::vector<Float64> vt_;
 };
+*/
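For reference against the disabled class: textbook Adam (Kingma & Ba) keeps exponentially decayed first and second moments with bias correction,

\[ m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t, \qquad v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2, \]
\[ \hat m_t = \frac{m_t}{1 - \beta_1^t}, \qquad \hat v_t = \frac{v_t}{1 - \beta_2^t}, \qquad \theta_t = \theta_{t-1} + \eta\, \frac{\hat m_t}{\sqrt{\hat v_t} + \varepsilon}. \]

(The sign is "+" here because, as above, the accumulated gradients already point in the descent direction.) The commented-out code approximates the correction by applying it only while \(\beta^t\) is non-negligible (t < 8 for the first moment, t < 850 for the second), computes sqrt(vt + eps) rather than sqrt(vt) + eps, and references an eps member it never declares; hence the TODO above.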
 
 
+/** LinearModelData is a class which manages the current state of learning
+  * and is stored as an AggregateFunctionState.
+  */
 class LinearModelData
 {
 public:
@@ -365,33 +501,53 @@ public:
     {}
 
     LinearModelData(Float64 learning_rate,
                     UInt32 param_num,
                     UInt32 batch_capacity,
                     std::shared_ptr<IGradientComputer> gc,
                     std::shared_ptr<IWeightsUpdater> wu)
         : learning_rate(learning_rate),
         batch_capacity(batch_capacity),
+        batch_size(0),
         gradient_computer(std::move(gc)),
         weights_updater(std::move(wu))
     {
         weights.resize(param_num, Float64{0.0});
-        batch_size = 0;
+        gradient_batch.resize(param_num + 1, Float64{0.0});
     }
 
     void add(const IColumn ** columns, size_t row_num)
     {
         /// first column stores target; features start from (columns + 1)
         const auto & target = static_cast<const ColumnVector<Float64> &>(*columns[0]).getData()[row_num];
 
-        auto delta = weights_updater->get_update(weights.size() + 1, batch_capacity);
-        Float64 delta_bias = bias + delta[weights.size()];
-        delta.resize(weights.size());
-        for (size_t i = 0; i < weights.size(); ++i)
-        {
-            delta[i] += weights[i];
-        }
-
-        gradient_computer->compute(delta, delta_bias, learning_rate, target, columns + 1, row_num);
+        weights_updater->add_to_batch(&gradient_batch, gradient_computer,
+                                      weights, bias, learning_rate, target, columns + 1, row_num);
 
         ++batch_size;
         if (batch_size == batch_capacity)
         {
@@ -399,7 +555,7 @@ public:
         }
     }
 
     void merge(const LinearModelData & rhs)
     {
         if (iter_num == 0 && rhs.iter_num == 0)
             return;
@@ -421,63 +577,90 @@ public:
         weights_updater->merge(*rhs.weights_updater, frac, rhs_frac);
     }
 
     void write(WriteBuffer & buf) const
     {
         writeBinary(bias, buf);
         writeBinary(weights, buf);
         writeBinary(iter_num, buf);
+        writeBinary(gradient_batch, buf);
         writeBinary(batch_size, buf);
-        gradient_computer->write(buf);
+        weights_updater->write(buf);
     }
 
     void read(ReadBuffer & buf)
     {
         readBinary(bias, buf);
         readBinary(weights, buf);
         readBinary(iter_num, buf);
+        readBinary(gradient_batch, buf);
         readBinary(batch_size, buf);
-        gradient_computer->read(buf);
+        weights_updater->read(buf);
     }
 
     Float64 predict(const std::vector<Float64> & predict_feature) const
     {
         /// We do not update the weights during prediction, because that could slow the prediction down.
         /// They could, however, be updated on every merge, regardless of how many elements are in the batch.
-// if (cur_batch)
-// {
-// update_weights();
-// }
         return gradient_computer->predict(predict_feature, weights, bias);
     }
 
     void predict_for_all(ColumnVector<Float64>::Container & container, Block & block, const ColumnNumbers & arguments) const
     {
         gradient_computer->predict_for_all(container, block, arguments, weights, bias);
     }
 
 private:
     std::vector<Float64> weights;
+    Float64 bias{0.0};
+
     Float64 learning_rate;
     UInt32 batch_capacity;
-    Float64 bias{0.0};
+
+    UInt32 iter_num = 0;
+    std::vector<Float64> gradient_batch;
     UInt32 batch_size;
+
     std::shared_ptr<IGradientComputer> gradient_computer;
     std::shared_ptr<IWeightsUpdater> weights_updater;
 
+    /** The function is called when we want to flush the current batch and make a step with it. */
     void update_state()
     {
         if (batch_size == 0)
             return;
 
-        weights_updater->update(batch_size, weights, bias, gradient_computer->get());
+        weights_updater->update(batch_size, weights, bias, gradient_batch);
         batch_size = 0;
         ++iter_num;
-        gradient_computer->reset();
-        //TODO ask: this is not good for Nesterov and Adam; another function needs to be added
+        gradient_batch.assign(gradient_batch.size(), Float64{0.0});
     }
 };
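merge() combines two partially trained states, e.g. from different shards. The weighting code falls between the hunks shown here, but the visible structure is a convex combination

\[ \theta \leftarrow frac \cdot \theta + rhs\_frac \cdot \theta_{rhs} \]

applied to weights and bias, with the same (frac, rhs_frac) pair forwarded to IWeightsUpdater::merge so that the accumulated velocities are averaged the same way. Presumably frac is proportional to each side's iter_num (both counters are checked just above), but that computation is outside this diff.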
 
+
 template <
     /// Implemented Machine Learning method
     typename Data,
@@ -490,19 +673,19 @@ public:
     String getName() const override { return Name::name; }
 
     explicit AggregateFunctionMLMethod(UInt32 param_num,
                                        std::shared_ptr<IGradientComputer> gradient_computer,
                                        std::shared_ptr<IWeightsUpdater> weights_updater,
                                        Float64 learning_rate,
                                        UInt32 batch_size,
                                        const DataTypes & argument_types,
                                        const Array & params)
         : IAggregateFunctionDataHelper<Data, AggregateFunctionMLMethod<Data, Name>>(argument_types, params),
         param_num(param_num),
         learning_rate(learning_rate),
         batch_size(batch_size),
         gc(std::move(gradient_computer)),
         wu(std::move(weights_updater))
     {}
 
     DataTypePtr getReturnType() const override
     {
@@ -538,7 +721,8 @@ public:
     void predictResultInto(ConstAggregateDataPtr place, IColumn & to, Block & block, const ColumnNumbers & arguments) const
     {
         if (arguments.size() != param_num + 1)
-            throw Exception("Predict got incorrect number of arguments. Got: " + std::to_string(arguments.size()) + ". Required: " + std::to_string(param_num + 1),
+            throw Exception("Predict got incorrect number of arguments. Got: " +
+                            std::to_string(arguments.size()) + ". 
Required: " + std::to_string(param_num + 1), ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); auto &column = dynamic_cast &>(to); diff --git a/dbms/tests/queries/0_stateless/00950_ml_test.sql b/dbms/tests/queries/0_stateless/00950_ml_test.sql index 0a009c29d29..46f64a1f1dd 100644 --- a/dbms/tests/queries/0_stateless/00950_ml_test.sql +++ b/dbms/tests/queries/0_stateless/00950_ml_test.sql @@ -11,7 +11,7 @@ CREATE TABLE IF NOT EXISTS test.defaults insert into test.defaults values (-3.273, -1.452, 4.267, 20.0, 40.0), (0.121, -0.615, 4.290, 20.0, 40.0), (-1.099, 2.755, -3.060, 20.0, 40.0), (1.090, 2.945, -2.346, 20.0, 40.0), (0.305, 2.179, -1.205, 20.0, 40.0), (-0.925, 0.702, 1.134, 20.0, 40.0), (3.178, -1.316, 7.221, 20.0, 40.0), (-2.756, -0.473, 2.569, 20.0, 40.0), (3.665, 2.303, 0.226, 20.0, 40.0), (1.662, 1.951, -0.070, 20.0, 40.0), (2.869, 0.593, 3.249, 20.0, 40.0), (0.818, -0.593, 4.594, 20.0, 40.0), (-1.917, 0.916, 0.209, 20.0, 40.0), (2.706, 1.523, 1.307, 20.0, 40.0), (0.219, 2.162, -1.214, 20.0, 40.0), (-4.510, 1.376, -2.007, 20.0, 40.0), (4.284, -0.515, 6.173, 20.0, 40.0), (-1.101, 2.810, -3.170, 20.0, 40.0), (-1.810, -1.117, 4.329, 20.0, 40.0), (0.055, 1.115, 0.797, 20.0, 40.0), (-2.178, 2.904, -3.898, 20.0, 40.0), (-3.494, -1.814, 4.882, 20.0, 40.0), (3.027, 0.476, 3.562, 20.0, 40.0), (-1.434, 1.151, -0.018, 20.0, 40.0), (1.180, 0.992, 1.606, 20.0, 40.0), (0.015, 0.971, 1.067, 20.0, 40.0), (-0.511, -0.875, 4.495, 20.0, 40.0), (0.961, 2.348, -1.216, 20.0, 40.0), (-2.279, 0.038, 1.785, 20.0, 40.0), (-1.568, -0.248, 2.712, 20.0, 40.0), (-0.496, 0.366, 2.020, 20.0, 40.0), (1.177, -1.401, 6.390, 20.0, 40.0), (2.882, -1.442, 7.325, 20.0, 40.0), (-1.066, 1.817, -1.167, 20.0, 40.0), (-2.144, 2.791, -3.655, 20.0, 40.0), (-4.370, 2.228, -3.642, 20.0, 40.0), (3.996, 2.775, -0.553, 20.0, 40.0), (0.289, 2.055, -0.965, 20.0, 40.0), (-0.588, -1.601, 5.908, 20.0, 40.0), (-1.801, 0.417, 1.265, 20.0, 40.0), (4.375, -1.499, 8.186, 20.0, 40.0), (-2.618, 0.038, 1.615, 20.0, 40.0), (3.616, -0.833, 6.475, 20.0, 40.0), (-4.045, -1.558, 4.094, 20.0, 40.0), (-3.962, 0.636, -0.253, 20.0, 40.0), (3.505, 2.625, -0.497, 20.0, 40.0), (3.029, -0.523, 5.560, 20.0, 40.0), (-3.520, -0.474, 2.188, 20.0, 40.0), (2.430, -1.469, 7.154, 20.0, 40.0), (1.547, -1.654, 7.082, 20.0, 40.0), (-1.370, 0.575, 1.165, 20.0, 40.0), (-1.869, -1.555, 5.176, 20.0, 40.0), (3.536, 2.841, -0.913, 20.0, 40.0), (-3.810, 1.220, -1.344, 20.0, 40.0), (-1.971, 1.462, -0.910, 20.0, 40.0), (-0.243, 0.167, 2.545, 20.0, 40.0), (-1.403, 2.645, -2.991, 20.0, 40.0), (0.532, -0.114, 3.494, 20.0, 40.0), (-1.678, 0.975, 0.212, 20.0, 40.0), (-0.656, 2.140, -1.609, 20.0, 40.0), (1.743, 2.631, -1.390, 20.0, 40.0), (2.586, 2.943, -1.593, 20.0, 40.0), (-0.512, 2.969, -3.195, 20.0, 40.0), (2.283, -0.100, 4.342, 20.0, 40.0), (-4.293, 0.872, -0.890, 20.0, 40.0), (3.411, 1.300, 2.106, 20.0, 40.0), (-0.281, 2.951, -3.042, 20.0, 40.0), (-4.442, 0.384, 0.012, 20.0, 40.0), (1.194, 1.746, 0.104, 20.0, 40.0), (-1.152, 1.862, -1.300, 20.0, 40.0), (1.362, -1.341, 6.363, 20.0, 40.0), (-4.488, 2.618, -4.481, 20.0, 40.0), (3.419, -0.564, 5.837, 20.0, 40.0), (-3.392, 0.396, 0.512, 20.0, 40.0), (-1.629, -0.909, 4.003, 20.0, 40.0), (4.447, -1.088, 7.399, 20.0, 40.0), (-1.232, 1.699, -1.014, 20.0, 40.0), (-1.286, -0.609, 3.575, 20.0, 40.0), (2.437, 2.796, -1.374, 20.0, 40.0), (-4.864, 1.989, -3.410, 20.0, 40.0), (-1.716, -1.399, 4.940, 20.0, 40.0), (-3.084, 1.858, -2.259, 20.0, 40.0), (2.828, -0.319, 5.053, 20.0, 40.0), (-1.226, 2.586, -2.786, 20.0, 40.0), (2.456, 0.092, 4.044, 
20.0, 40.0), (-0.989, 2.375, -2.245, 20.0, 40.0), (3.268, 0.935, 2.765, 20.0, 40.0), (-4.128, -1.995, 4.927, 20.0, 40.0), (-1.083, 2.197, -1.935, 20.0, 40.0), (-3.471, -1.198, 3.660, 20.0, 40.0), (4.617, -1.136, 7.579, 20.0, 40.0), (2.054, -1.675, 7.378, 20.0, 40.0), (4.106, 2.326, 0.402, 20.0, 40.0), (1.558, 0.310, 3.158, 20.0, 40.0), (0.792, 0.900, 1.596, 20.0, 40.0), (-3.229, 0.300, 0.785, 20.0, 40.0), (3.787, -0.793, 6.479, 20.0, 40.0), (1.786, 2.288, -0.684, 20.0, 40.0), (2.643, 0.223, 3.875, 20.0, 40.0), (-3.592, 2.122, -3.040, 20.0, 40.0), (4.519, -1.760, 8.779, 20.0, 40.0), (3.221, 2.255, 0.101, 20.0, 40.0), (4.151, 1.788, 1.500, 20.0, 40.0), (-1.033, -1.195, 4.874, 20.0, 40.0), (-1.636, -1.037, 4.257, 20.0, 40.0), (-3.548, 1.911, -2.596, 20.0, 40.0), (4.829, -0.293, 6.001, 20.0, 40.0), (-4.684, -1.664, 3.986, 20.0, 40.0), (4.531, -0.503, 6.271, 20.0, 40.0), (-3.503, -1.606, 4.460, 20.0, 40.0), (-2.036, -1.522, 5.027, 20.0, 40.0), (-0.473, -0.617, 3.997, 20.0, 40.0), (-1.554, -1.630, 5.483, 20.0, 40.0), (-3.567, -1.043, 3.302, 20.0, 40.0), (-2.038, 0.579, 0.823, 20.0, 40.0), (-3.040, 0.857, -0.233, 20.0, 40.0), (4.610, 0.562, 4.181, 20.0, 40.0), (-3.323, -1.938, 5.215, 20.0, 40.0), (4.314, 1.720, 1.717, 20.0, 40.0), (-1.220, 0.615, 1.161, 20.0, 40.0), (-2.556, 1.120, -0.519, 20.0, 40.0), (-3.717, -0.108, 1.358, 20.0, 40.0), (4.689, -1.826, 8.996, 20.0, 40.0), (3.452, 0.506, 3.713, 20.0, 40.0), (2.472, 0.612, 3.012, 20.0, 40.0), (3.452, 0.450, 3.826, 20.0, 40.0), (1.207, 2.585, -1.567, 20.0, 40.0), (-4.826, 1.090, -1.593, 20.0, 40.0), (3.116, -1.118, 6.794, 20.0, 40.0), (0.448, 2.732, -2.240, 20.0, 40.0), (-1.096, -0.525, 3.503, 20.0, 40.0), (-4.680, -0.238, 1.137, 20.0, 40.0), (2.552, -1.403, 7.082, 20.0, 40.0), (0.719, 2.997, -2.635, 20.0, 40.0), (0.347, -1.966, 7.105, 20.0, 40.0), (2.958, -0.404, 5.288, 20.0, 40.0), (0.722, -1.950, 7.261, 20.0, 40.0), (-2.851, -0.986, 3.546, 20.0, 40.0), (-4.316, -0.439, 1.721, 20.0, 40.0), (-1.685, -0.201, 2.560, 20.0, 40.0), (1.856, 0.190, 3.549, 20.0, 40.0), (-2.052, 0.206, 1.562, 20.0, 40.0), (-2.504, -0.646, 3.041, 20.0, 40.0), (3.235, 0.882, 2.854, 20.0, 40.0), (-1.366, -1.573, 5.463, 20.0, 40.0), (-3.447, 2.419, -3.562, 20.0, 40.0), (4.155, 2.092, 0.893, 20.0, 40.0), (-0.935, 0.209, 2.116, 20.0, 40.0), (3.117, -1.821, 8.201, 20.0, 40.0), (3.759, 0.577, 3.725, 20.0, 40.0), (-0.938, 2.992, -3.453, 20.0, 40.0), (-0.525, 2.341, -1.945, 20.0, 40.0), (4.540, 2.625, 0.019, 20.0, 40.0), (-2.097, 1.190, -0.429, 20.0, 40.0), (-2.672, 1.983, -2.302, 20.0, 40.0), (-3.038, -1.490, 4.460, 20.0, 40.0), (-0.943, 2.149, -1.770, 20.0, 40.0), (0.739, 1.598, 0.174, 20.0, 40.0), (1.828, 1.853, 0.208, 20.0, 40.0), (4.856, 0.137, 5.153, 20.0, 40.0), (-1.617, 0.468, 1.255, 20.0, 40.0), (-1.972, 2.053, -2.092, 20.0, 40.0), (-4.633, 1.389, -2.094, 20.0, 40.0), (-3.628, -1.156, 3.498, 20.0, 40.0), (3.597, 1.034, 2.731, 20.0, 40.0), (-1.488, -0.002, 2.261, 20.0, 40.0), (0.749, 1.921, -0.468, 20.0, 40.0), (1.304, -1.371, 6.394, 20.0, 40.0), (4.587, 2.936, -0.579, 20.0, 40.0), (-2.241, 1.791, -1.703, 20.0, 40.0), (-2.945, 1.372, -1.216, 20.0, 40.0), (1.375, 0.395, 2.898, 20.0, 40.0), (-1.281, -0.641, 3.642, 20.0, 40.0), (2.178, 0.895, 2.299, 20.0, 40.0), (3.031, -0.786, 6.087, 20.0, 40.0), (-1.385, -0.375, 3.058, 20.0, 40.0), (4.041, -0.431, 5.882, 20.0, 40.0), (0.480, -0.507, 4.254, 20.0, 40.0), (-3.797, 0.140, 0.822, 20.0, 40.0), (2.355, 2.502, -0.827, 20.0, 40.0), (1.376, -1.583, 6.854, 20.0, 40.0), (0.164, 1.405, 0.273, 20.0, 40.0), (-1.273, 1.471, -0.579, 20.0, 
40.0), (0.770, 2.246, -1.107, 20.0, 40.0), (4.552, 2.904, -0.533, 20.0, 40.0), (4.259, -1.772, 8.674, 20.0, 40.0), (-0.309, 1.159, 0.528, 20.0, 40.0), (3.581, 2.700, -0.610, 20.0, 40.0), (-3.202, 0.346, 0.707, 20.0, 40.0), (-1.575, 1.242, -0.271, 20.0, 40.0), (-1.584, -0.493, 3.194, 20.0, 40.0), (-3.778, 0.150, 0.810, 20.0, 40.0), (-4.675, 1.749, -2.835, 20.0, 40.0), (3.567, -0.792, 6.367, 20.0, 40.0), (-0.417, 1.399, -0.006, 20.0, 40.0), (-4.672, 2.007, -3.349, 20.0, 40.0), (-1.034, 0.196, 2.090, 20.0, 40.0), (-3.796, 2.496, -3.890, 20.0, 40.0), (3.532, -0.497, 5.759, 20.0, 40.0), (4.868, -1.359, 8.151, 20.0, 40.0), (-0.769, 0.302, 2.011, 20.0, 40.0), (4.475, 2.612, 0.014, 20.0, 40.0), (-3.532, -0.395, 2.024, 20.0, 40.0), (0.322, 0.675, 1.812, 20.0, 40.0), (-2.028, -1.942, 5.870, 20.0, 40.0), (1.810, -1.244, 6.392, 20.0, 40.0), (-0.783, 1.242, 0.124, 20.0, 40.0), (-4.745, -1.300, 3.227, 20.0, 40.0), (1.902, 1.973, 0.005, 20.0, 40.0), (-3.453, -1.429, 4.132, 20.0, 40.0), (1.559, 0.986, 1.808, 20.0, 40.0), (0.128, 2.754, -2.443, 20.0, 40.0), (2.759, 1.727, 0.926, 20.0, 40.0), (-4.468, 1.690, -2.614, 20.0, 40.0), (-2.368, -1.922, 5.659, 20.0, 40.0), (-2.766, 2.128, -2.640, 20.0, 40.0), (0.967, -1.825, 7.133, 20.0, 40.0), (-2.854, 2.855, -4.136, 20.0, 40.0), (-2.944, 1.875, -2.222, 20.0, 40.0), (-2.632, -0.983, 3.649, 20.0, 40.0), (2.427, 2.239, -0.266, 20.0, 40.0), (-1.726, -0.838, 3.812, 20.0, 40.0), (0.007, -0.903, 4.809, 20.0, 40.0), (-2.013, 1.092, -0.191, 20.0, 40.0), (-0.449, 0.970, 0.836, 20.0, 40.0), (1.396, 0.411, 2.876, 20.0, 40.0), (-1.115, -1.790, 6.023, 20.0, 40.0), (3.748, 1.917, 1.039, 20.0, 40.0), (2.978, 1.043, 2.404, 20.0, 40.0), (-3.969, 2.514, -4.013, 20.0, 40.0), (4.455, -0.050, 5.328, 20.0, 40.0), (-3.065, -0.846, 3.160, 20.0, 40.0), (-1.069, 2.167, -1.869, 20.0, 40.0), (3.016, -1.393, 7.294, 20.0, 40.0), (0.045, -1.928, 6.879, 20.0, 40.0), (-2.555, -0.984, 3.690, 20.0, 40.0), (-1.995, -0.054, 2.111, 20.0, 40.0), (4.600, -0.509, 6.318, 20.0, 40.0), (-1.942, 1.215, -0.402, 20.0, 40.0), (1.262, 2.765, -1.899, 20.0, 40.0), (2.617, -1.106, 6.521, 20.0, 40.0), (1.737, 0.554, 2.761, 20.0, 40.0), (-2.197, 0.632, 0.638, 20.0, 40.0), (4.768, 2.618, 0.147, 20.0, 40.0), (-3.737, -0.939, 3.010, 20.0, 40.0), (-2.623, 0.595, 0.499, 20.0, 40.0), (4.752, -0.340, 6.057, 20.0, 40.0), (2.333, -1.037, 6.240, 20.0, 40.0), (4.234, -1.882, 8.881, 20.0, 40.0), (-3.393, -0.812, 2.927, 20.0, 40.0), (0.885, 1.383, 0.678, 20.0, 40.0), (0.123, 2.937, -2.812, 20.0, 40.0), (2.969, 0.760, 2.964, 20.0, 40.0), (-4.929, 1.251, -1.967, 20.0, 40.0), (1.916, 2.223, -0.488, 20.0, 40.0), (-0.020, -1.740, 6.469, 20.0, 40.0), (0.702, -1.272, 5.895, 20.0, 40.0), (2.496, 2.648, -1.048, 20.0, 40.0), (4.067, -1.475, 7.984, 20.0, 40.0), (-3.717, 1.851, -2.561, 20.0, 40.0), (1.678, -0.624, 5.088, 20.0, 40.0), (1.073, 0.695, 2.146, 20.0, 40.0), (1.842, -0.749, 5.419, 20.0, 40.0), (-3.518, 1.909, -2.578, 20.0, 40.0), (2.229, 1.189, 1.737, 20.0, 40.0), (4.987, 2.893, -0.292, 20.0, 40.0), (-4.809, 1.043, -1.490, 20.0, 40.0), (-0.241, -0.728, 4.334, 20.0, 40.0), (-3.331, 0.590, 0.156, 20.0, 40.0), (-0.455, 2.621, -2.470, 20.0, 40.0), (1.492, 1.223, 1.301, 20.0, 40.0), (3.948, 2.841, -0.709, 20.0, 40.0), (0.732, 0.446, 2.475, 20.0, 40.0), (2.400, 2.390, -0.579, 20.0, 40.0), (-2.718, 1.427, -1.213, 20.0, 40.0), (-1.826, 1.451, -0.815, 20.0, 40.0), (1.125, 0.438, 2.686, 20.0, 40.0), (-4.918, 1.880, -3.219, 20.0, 40.0), (3.068, -0.442, 5.418, 20.0, 40.0), (1.982, 1.201, 1.589, 20.0, 40.0), (0.701, -1.709, 6.768, 20.0, 40.0), 
(-1.496, 2.564, -2.877, 20.0, 40.0), (-3.812, 0.974, -0.853, 20.0, 40.0), (-3.405, 2.018, -2.739, 20.0, 40.0), (2.211, 2.889, -1.674, 20.0, 40.0), (-2.481, 2.931, -4.103, 20.0, 40.0), (-3.721, 2.765, -4.391, 20.0, 40.0), (-1.768, -1.292, 4.699, 20.0, 40.0), (-4.462, 1.058, -1.347, 20.0, 40.0), (-3.516, -1.942, 5.126, 20.0, 40.0), (0.485, 2.420, -1.597, 20.0, 40.0), (-0.492, 0.242, 2.270, 20.0, 40.0), (4.245, 1.689, 1.744, 20.0, 40.0), (2.234, 0.364, 3.389, 20.0, 40.0), (2.629, 2.224, -0.134, 20.0, 40.0), (-4.375, 1.221, -1.630, 20.0, 40.0), (-0.618, 1.374, -0.057, 20.0, 40.0), (-2.580, -1.604, 4.918, 20.0, 40.0), (0.159, 1.104, 0.871, 20.0, 40.0), (-3.597, 0.975, -0.749, 20.0, 40.0); DROP TABLE IF EXISTS test.model; -create table test.model engine = Memory as select LinearRegressionState(0.1, 5, 2.0)(target, param1, param2) as state from test.defaults; +create table test.model engine = Memory as select LinearRegressionState(0.01, 5, 1.0)(target, param1, param2) as state from test.defaults; with (select state from test.model) as model select evalMLMethod(model, predict1, predict2) from test.defaults; diff --git a/dbms/tests/queries/0_stateless/00954_ml_test.sql b/dbms/tests/queries/0_stateless/00954_ml_test.sql index b1bce54036f..2df84c00d3f 100644 --- a/dbms/tests/queries/0_stateless/00954_ml_test.sql +++ b/dbms/tests/queries/0_stateless/00954_ml_test.sql @@ -11,7 +11,7 @@ CREATE TABLE IF NOT EXISTS test.defaults insert into test.defaults values (-3.273, -1.452, 4.267, 20.0, 40.0), (0.121, -0.615, 4.290, 20.0, 40.0), (-1.099, 2.755, -3.060, 20.0, 40.0), (1.090, 2.945, -2.346, 20.0, 40.0), (0.305, 2.179, -1.205, 20.0, 40.0), (-0.925, 0.702, 1.134, 20.0, 40.0), (3.178, -1.316, 7.221, 20.0, 40.0), (-2.756, -0.473, 2.569, 20.0, 40.0), (3.665, 2.303, 0.226, 20.0, 40.0), (1.662, 1.951, -0.070, 20.0, 40.0), (2.869, 0.593, 3.249, 20.0, 40.0), (0.818, -0.593, 4.594, 20.0, 40.0), (-1.917, 0.916, 0.209, 20.0, 40.0), (2.706, 1.523, 1.307, 20.0, 40.0), (0.219, 2.162, -1.214, 20.0, 40.0), (-4.510, 1.376, -2.007, 20.0, 40.0), (4.284, -0.515, 6.173, 20.0, 40.0), (-1.101, 2.810, -3.170, 20.0, 40.0), (-1.810, -1.117, 4.329, 20.0, 40.0), (0.055, 1.115, 0.797, 20.0, 40.0), (-2.178, 2.904, -3.898, 20.0, 40.0), (-3.494, -1.814, 4.882, 20.0, 40.0), (3.027, 0.476, 3.562, 20.0, 40.0), (-1.434, 1.151, -0.018, 20.0, 40.0), (1.180, 0.992, 1.606, 20.0, 40.0), (0.015, 0.971, 1.067, 20.0, 40.0), (-0.511, -0.875, 4.495, 20.0, 40.0), (0.961, 2.348, -1.216, 20.0, 40.0), (-2.279, 0.038, 1.785, 20.0, 40.0), (-1.568, -0.248, 2.712, 20.0, 40.0), (-0.496, 0.366, 2.020, 20.0, 40.0), (1.177, -1.401, 6.390, 20.0, 40.0), (2.882, -1.442, 7.325, 20.0, 40.0), (-1.066, 1.817, -1.167, 20.0, 40.0), (-2.144, 2.791, -3.655, 20.0, 40.0), (-4.370, 2.228, -3.642, 20.0, 40.0), (3.996, 2.775, -0.553, 20.0, 40.0), (0.289, 2.055, -0.965, 20.0, 40.0), (-0.588, -1.601, 5.908, 20.0, 40.0), (-1.801, 0.417, 1.265, 20.0, 40.0), (4.375, -1.499, 8.186, 20.0, 40.0), (-2.618, 0.038, 1.615, 20.0, 40.0), (3.616, -0.833, 6.475, 20.0, 40.0), (-4.045, -1.558, 4.094, 20.0, 40.0), (-3.962, 0.636, -0.253, 20.0, 40.0), (3.505, 2.625, -0.497, 20.0, 40.0), (3.029, -0.523, 5.560, 20.0, 40.0), (-3.520, -0.474, 2.188, 20.0, 40.0), (2.430, -1.469, 7.154, 20.0, 40.0), (1.547, -1.654, 7.082, 20.0, 40.0), (-1.370, 0.575, 1.165, 20.0, 40.0), (-1.869, -1.555, 5.176, 20.0, 40.0), (3.536, 2.841, -0.913, 20.0, 40.0), (-3.810, 1.220, -1.344, 20.0, 40.0), (-1.971, 1.462, -0.910, 20.0, 40.0), (-0.243, 0.167, 2.545, 20.0, 40.0), (-1.403, 2.645, -2.991, 20.0, 40.0), (0.532, -0.114, 
3.494, 20.0, 40.0), (-1.678, 0.975, 0.212, 20.0, 40.0), (-0.656, 2.140, -1.609, 20.0, 40.0), (1.743, 2.631, -1.390, 20.0, 40.0), (2.586, 2.943, -1.593, 20.0, 40.0), (-0.512, 2.969, -3.195, 20.0, 40.0), (2.283, -0.100, 4.342, 20.0, 40.0), (-4.293, 0.872, -0.890, 20.0, 40.0), (3.411, 1.300, 2.106, 20.0, 40.0), (-0.281, 2.951, -3.042, 20.0, 40.0), (-4.442, 0.384, 0.012, 20.0, 40.0), (1.194, 1.746, 0.104, 20.0, 40.0), (-1.152, 1.862, -1.300, 20.0, 40.0), (1.362, -1.341, 6.363, 20.0, 40.0), (-4.488, 2.618, -4.481, 20.0, 40.0), (3.419, -0.564, 5.837, 20.0, 40.0), (-3.392, 0.396, 0.512, 20.0, 40.0), (-1.629, -0.909, 4.003, 20.0, 40.0), (4.447, -1.088, 7.399, 20.0, 40.0), (-1.232, 1.699, -1.014, 20.0, 40.0), (-1.286, -0.609, 3.575, 20.0, 40.0), (2.437, 2.796, -1.374, 20.0, 40.0), (-4.864, 1.989, -3.410, 20.0, 40.0), (-1.716, -1.399, 4.940, 20.0, 40.0), (-3.084, 1.858, -2.259, 20.0, 40.0), (2.828, -0.319, 5.053, 20.0, 40.0), (-1.226, 2.586, -2.786, 20.0, 40.0), (2.456, 0.092, 4.044, 20.0, 40.0), (-0.989, 2.375, -2.245, 20.0, 40.0), (3.268, 0.935, 2.765, 20.0, 40.0), (-4.128, -1.995, 4.927, 20.0, 40.0), (-1.083, 2.197, -1.935, 20.0, 40.0), (-3.471, -1.198, 3.660, 20.0, 40.0), (4.617, -1.136, 7.579, 20.0, 40.0), (2.054, -1.675, 7.378, 20.0, 40.0), (4.106, 2.326, 0.402, 20.0, 40.0), (1.558, 0.310, 3.158, 20.0, 40.0), (0.792, 0.900, 1.596, 20.0, 40.0), (-3.229, 0.300, 0.785, 20.0, 40.0), (3.787, -0.793, 6.479, 20.0, 40.0), (1.786, 2.288, -0.684, 20.0, 40.0), (2.643, 0.223, 3.875, 20.0, 40.0), (-3.592, 2.122, -3.040, 20.0, 40.0), (4.519, -1.760, 8.779, 20.0, 40.0), (3.221, 2.255, 0.101, 20.0, 40.0), (4.151, 1.788, 1.500, 20.0, 40.0), (-1.033, -1.195, 4.874, 20.0, 40.0), (-1.636, -1.037, 4.257, 20.0, 40.0), (-3.548, 1.911, -2.596, 20.0, 40.0), (4.829, -0.293, 6.001, 20.0, 40.0), (-4.684, -1.664, 3.986, 20.0, 40.0), (4.531, -0.503, 6.271, 20.0, 40.0), (-3.503, -1.606, 4.460, 20.0, 40.0), (-2.036, -1.522, 5.027, 20.0, 40.0), (-0.473, -0.617, 3.997, 20.0, 40.0), (-1.554, -1.630, 5.483, 20.0, 40.0), (-3.567, -1.043, 3.302, 20.0, 40.0), (-2.038, 0.579, 0.823, 20.0, 40.0), (-3.040, 0.857, -0.233, 20.0, 40.0), (4.610, 0.562, 4.181, 20.0, 40.0), (-3.323, -1.938, 5.215, 20.0, 40.0), (4.314, 1.720, 1.717, 20.0, 40.0), (-1.220, 0.615, 1.161, 20.0, 40.0), (-2.556, 1.120, -0.519, 20.0, 40.0), (-3.717, -0.108, 1.358, 20.0, 40.0), (4.689, -1.826, 8.996, 20.0, 40.0), (3.452, 0.506, 3.713, 20.0, 40.0), (2.472, 0.612, 3.012, 20.0, 40.0), (3.452, 0.450, 3.826, 20.0, 40.0), (1.207, 2.585, -1.567, 20.0, 40.0), (-4.826, 1.090, -1.593, 20.0, 40.0), (3.116, -1.118, 6.794, 20.0, 40.0), (0.448, 2.732, -2.240, 20.0, 40.0), (-1.096, -0.525, 3.503, 20.0, 40.0), (-4.680, -0.238, 1.137, 20.0, 40.0), (2.552, -1.403, 7.082, 20.0, 40.0), (0.719, 2.997, -2.635, 20.0, 40.0), (0.347, -1.966, 7.105, 20.0, 40.0), (2.958, -0.404, 5.288, 20.0, 40.0), (0.722, -1.950, 7.261, 20.0, 40.0), (-2.851, -0.986, 3.546, 20.0, 40.0), (-4.316, -0.439, 1.721, 20.0, 40.0), (-1.685, -0.201, 2.560, 20.0, 40.0), (1.856, 0.190, 3.549, 20.0, 40.0), (-2.052, 0.206, 1.562, 20.0, 40.0), (-2.504, -0.646, 3.041, 20.0, 40.0), (3.235, 0.882, 2.854, 20.0, 40.0), (-1.366, -1.573, 5.463, 20.0, 40.0), (-3.447, 2.419, -3.562, 20.0, 40.0), (4.155, 2.092, 0.893, 20.0, 40.0), (-0.935, 0.209, 2.116, 20.0, 40.0), (3.117, -1.821, 8.201, 20.0, 40.0), (3.759, 0.577, 3.725, 20.0, 40.0), (-0.938, 2.992, -3.453, 20.0, 40.0), (-0.525, 2.341, -1.945, 20.0, 40.0), (4.540, 2.625, 0.019, 20.0, 40.0), (-2.097, 1.190, -0.429, 20.0, 40.0), (-2.672, 1.983, -2.302, 20.0, 40.0), (-3.038, -1.490, 
4.460, 20.0, 40.0), (-0.943, 2.149, -1.770, 20.0, 40.0), (0.739, 1.598, 0.174, 20.0, 40.0), (1.828, 1.853, 0.208, 20.0, 40.0), (4.856, 0.137, 5.153, 20.0, 40.0), (-1.617, 0.468, 1.255, 20.0, 40.0), (-1.972, 2.053, -2.092, 20.0, 40.0), (-4.633, 1.389, -2.094, 20.0, 40.0), (-3.628, -1.156, 3.498, 20.0, 40.0), (3.597, 1.034, 2.731, 20.0, 40.0), (-1.488, -0.002, 2.261, 20.0, 40.0), (0.749, 1.921, -0.468, 20.0, 40.0), (1.304, -1.371, 6.394, 20.0, 40.0), (4.587, 2.936, -0.579, 20.0, 40.0), (-2.241, 1.791, -1.703, 20.0, 40.0), (-2.945, 1.372, -1.216, 20.0, 40.0), (1.375, 0.395, 2.898, 20.0, 40.0), (-1.281, -0.641, 3.642, 20.0, 40.0), (2.178, 0.895, 2.299, 20.0, 40.0), (3.031, -0.786, 6.087, 20.0, 40.0), (-1.385, -0.375, 3.058, 20.0, 40.0), (4.041, -0.431, 5.882, 20.0, 40.0), (0.480, -0.507, 4.254, 20.0, 40.0), (-3.797, 0.140, 0.822, 20.0, 40.0), (2.355, 2.502, -0.827, 20.0, 40.0), (1.376, -1.583, 6.854, 20.0, 40.0), (0.164, 1.405, 0.273, 20.0, 40.0), (-1.273, 1.471, -0.579, 20.0, 40.0), (0.770, 2.246, -1.107, 20.0, 40.0), (4.552, 2.904, -0.533, 20.0, 40.0), (4.259, -1.772, 8.674, 20.0, 40.0), (-0.309, 1.159, 0.528, 20.0, 40.0), (3.581, 2.700, -0.610, 20.0, 40.0), (-3.202, 0.346, 0.707, 20.0, 40.0), (-1.575, 1.242, -0.271, 20.0, 40.0), (-1.584, -0.493, 3.194, 20.0, 40.0), (-3.778, 0.150, 0.810, 20.0, 40.0), (-4.675, 1.749, -2.835, 20.0, 40.0), (3.567, -0.792, 6.367, 20.0, 40.0), (-0.417, 1.399, -0.006, 20.0, 40.0), (-4.672, 2.007, -3.349, 20.0, 40.0), (-1.034, 0.196, 2.090, 20.0, 40.0), (-3.796, 2.496, -3.890, 20.0, 40.0), (3.532, -0.497, 5.759, 20.0, 40.0), (4.868, -1.359, 8.151, 20.0, 40.0), (-0.769, 0.302, 2.011, 20.0, 40.0), (4.475, 2.612, 0.014, 20.0, 40.0), (-3.532, -0.395, 2.024, 20.0, 40.0), (0.322, 0.675, 1.812, 20.0, 40.0), (-2.028, -1.942, 5.870, 20.0, 40.0), (1.810, -1.244, 6.392, 20.0, 40.0), (-0.783, 1.242, 0.124, 20.0, 40.0), (-4.745, -1.300, 3.227, 20.0, 40.0), (1.902, 1.973, 0.005, 20.0, 40.0), (-3.453, -1.429, 4.132, 20.0, 40.0), (1.559, 0.986, 1.808, 20.0, 40.0), (0.128, 2.754, -2.443, 20.0, 40.0), (2.759, 1.727, 0.926, 20.0, 40.0), (-4.468, 1.690, -2.614, 20.0, 40.0), (-2.368, -1.922, 5.659, 20.0, 40.0), (-2.766, 2.128, -2.640, 20.0, 40.0), (0.967, -1.825, 7.133, 20.0, 40.0), (-2.854, 2.855, -4.136, 20.0, 40.0), (-2.944, 1.875, -2.222, 20.0, 40.0), (-2.632, -0.983, 3.649, 20.0, 40.0), (2.427, 2.239, -0.266, 20.0, 40.0), (-1.726, -0.838, 3.812, 20.0, 40.0), (0.007, -0.903, 4.809, 20.0, 40.0), (-2.013, 1.092, -0.191, 20.0, 40.0), (-0.449, 0.970, 0.836, 20.0, 40.0), (1.396, 0.411, 2.876, 20.0, 40.0), (-1.115, -1.790, 6.023, 20.0, 40.0), (3.748, 1.917, 1.039, 20.0, 40.0), (2.978, 1.043, 2.404, 20.0, 40.0), (-3.969, 2.514, -4.013, 20.0, 40.0), (4.455, -0.050, 5.328, 20.0, 40.0), (-3.065, -0.846, 3.160, 20.0, 40.0), (-1.069, 2.167, -1.869, 20.0, 40.0), (3.016, -1.393, 7.294, 20.0, 40.0), (0.045, -1.928, 6.879, 20.0, 40.0), (-2.555, -0.984, 3.690, 20.0, 40.0), (-1.995, -0.054, 2.111, 20.0, 40.0), (4.600, -0.509, 6.318, 20.0, 40.0), (-1.942, 1.215, -0.402, 20.0, 40.0), (1.262, 2.765, -1.899, 20.0, 40.0), (2.617, -1.106, 6.521, 20.0, 40.0), (1.737, 0.554, 2.761, 20.0, 40.0), (-2.197, 0.632, 0.638, 20.0, 40.0), (4.768, 2.618, 0.147, 20.0, 40.0), (-3.737, -0.939, 3.010, 20.0, 40.0), (-2.623, 0.595, 0.499, 20.0, 40.0), (4.752, -0.340, 6.057, 20.0, 40.0), (2.333, -1.037, 6.240, 20.0, 40.0), (4.234, -1.882, 8.881, 20.0, 40.0), (-3.393, -0.812, 2.927, 20.0, 40.0), (0.885, 1.383, 0.678, 20.0, 40.0), (0.123, 2.937, -2.812, 20.0, 40.0), (2.969, 0.760, 2.964, 20.0, 40.0), (-4.929, 1.251, -1.967, 
20.0, 40.0), (1.916, 2.223, -0.488, 20.0, 40.0), (-0.020, -1.740, 6.469, 20.0, 40.0), (0.702, -1.272, 5.895, 20.0, 40.0), (2.496, 2.648, -1.048, 20.0, 40.0), (4.067, -1.475, 7.984, 20.0, 40.0), (-3.717, 1.851, -2.561, 20.0, 40.0), (1.678, -0.624, 5.088, 20.0, 40.0), (1.073, 0.695, 2.146, 20.0, 40.0), (1.842, -0.749, 5.419, 20.0, 40.0), (-3.518, 1.909, -2.578, 20.0, 40.0), (2.229, 1.189, 1.737, 20.0, 40.0), (4.987, 2.893, -0.292, 20.0, 40.0), (-4.809, 1.043, -1.490, 20.0, 40.0), (-0.241, -0.728, 4.334, 20.0, 40.0), (-3.331, 0.590, 0.156, 20.0, 40.0), (-0.455, 2.621, -2.470, 20.0, 40.0), (1.492, 1.223, 1.301, 20.0, 40.0), (3.948, 2.841, -0.709, 20.0, 40.0), (0.732, 0.446, 2.475, 20.0, 40.0), (2.400, 2.390, -0.579, 20.0, 40.0), (-2.718, 1.427, -1.213, 20.0, 40.0), (-1.826, 1.451, -0.815, 20.0, 40.0), (1.125, 0.438, 2.686, 20.0, 40.0), (-4.918, 1.880, -3.219, 20.0, 40.0), (3.068, -0.442, 5.418, 20.0, 40.0), (1.982, 1.201, 1.589, 20.0, 40.0), (0.701, -1.709, 6.768, 20.0, 40.0), (-1.496, 2.564, -2.877, 20.0, 40.0), (-3.812, 0.974, -0.853, 20.0, 40.0), (-3.405, 2.018, -2.739, 20.0, 40.0), (2.211, 2.889, -1.674, 20.0, 40.0), (-2.481, 2.931, -4.103, 20.0, 40.0), (-3.721, 2.765, -4.391, 20.0, 40.0), (-1.768, -1.292, 4.699, 20.0, 40.0), (-4.462, 1.058, -1.347, 20.0, 40.0), (-3.516, -1.942, 5.126, 20.0, 40.0), (0.485, 2.420, -1.597, 20.0, 40.0), (-0.492, 0.242, 2.270, 20.0, 40.0), (4.245, 1.689, 1.744, 20.0, 40.0), (2.234, 0.364, 3.389, 20.0, 40.0), (2.629, 2.224, -0.134, 20.0, 40.0), (-4.375, 1.221, -1.630, 20.0, 40.0), (-0.618, 1.374, -0.057, 20.0, 40.0), (-2.580, -1.604, 4.918, 20.0, 40.0), (0.159, 1.104, 0.871, 20.0, 40.0), (-3.597, 0.975, -0.749, 20.0, 40.0); DROP TABLE IF EXISTS test.model; -create table test.model engine = Memory as select LinearRegressionState(0.1, 5, 3.0)(target, param1, param2) as state from test.defaults; +create table test.model engine = Memory as select LinearRegressionState(0.1, 5, 1.0)(target, param1, param2) as state from test.defaults; with (select state from test.model) as model select evalMLMethod(model, predict1, predict2) from test.defaults; diff --git a/dbms/tests/queries/0_stateless/00955_ml_test.sql b/dbms/tests/queries/0_stateless/00955_ml_test.sql index fb919e5aa57..2df84c00d3f 100644 --- a/dbms/tests/queries/0_stateless/00955_ml_test.sql +++ b/dbms/tests/queries/0_stateless/00955_ml_test.sql @@ -11,7 +11,7 @@ CREATE TABLE IF NOT EXISTS test.defaults insert into test.defaults values (-3.273, -1.452, 4.267, 20.0, 40.0), (0.121, -0.615, 4.290, 20.0, 40.0), (-1.099, 2.755, -3.060, 20.0, 40.0), (1.090, 2.945, -2.346, 20.0, 40.0), (0.305, 2.179, -1.205, 20.0, 40.0), (-0.925, 0.702, 1.134, 20.0, 40.0), (3.178, -1.316, 7.221, 20.0, 40.0), (-2.756, -0.473, 2.569, 20.0, 40.0), (3.665, 2.303, 0.226, 20.0, 40.0), (1.662, 1.951, -0.070, 20.0, 40.0), (2.869, 0.593, 3.249, 20.0, 40.0), (0.818, -0.593, 4.594, 20.0, 40.0), (-1.917, 0.916, 0.209, 20.0, 40.0), (2.706, 1.523, 1.307, 20.0, 40.0), (0.219, 2.162, -1.214, 20.0, 40.0), (-4.510, 1.376, -2.007, 20.0, 40.0), (4.284, -0.515, 6.173, 20.0, 40.0), (-1.101, 2.810, -3.170, 20.0, 40.0), (-1.810, -1.117, 4.329, 20.0, 40.0), (0.055, 1.115, 0.797, 20.0, 40.0), (-2.178, 2.904, -3.898, 20.0, 40.0), (-3.494, -1.814, 4.882, 20.0, 40.0), (3.027, 0.476, 3.562, 20.0, 40.0), (-1.434, 1.151, -0.018, 20.0, 40.0), (1.180, 0.992, 1.606, 20.0, 40.0), (0.015, 0.971, 1.067, 20.0, 40.0), (-0.511, -0.875, 4.495, 20.0, 40.0), (0.961, 2.348, -1.216, 20.0, 40.0), (-2.279, 0.038, 1.785, 20.0, 40.0), (-1.568, -0.248, 2.712, 20.0, 40.0), (-0.496, 
0.366, 2.020, 20.0, 40.0), (1.177, -1.401, 6.390, 20.0, 40.0), (2.882, -1.442, 7.325, 20.0, 40.0), (-1.066, 1.817, -1.167, 20.0, 40.0), (-2.144, 2.791, -3.655, 20.0, 40.0), (-4.370, 2.228, -3.642, 20.0, 40.0), (3.996, 2.775, -0.553, 20.0, 40.0), (0.289, 2.055, -0.965, 20.0, 40.0), (-0.588, -1.601, 5.908, 20.0, 40.0), (-1.801, 0.417, 1.265, 20.0, 40.0), (4.375, -1.499, 8.186, 20.0, 40.0), (-2.618, 0.038, 1.615, 20.0, 40.0), (3.616, -0.833, 6.475, 20.0, 40.0), (-4.045, -1.558, 4.094, 20.0, 40.0), (-3.962, 0.636, -0.253, 20.0, 40.0), (3.505, 2.625, -0.497, 20.0, 40.0), (3.029, -0.523, 5.560, 20.0, 40.0), (-3.520, -0.474, 2.188, 20.0, 40.0), (2.430, -1.469, 7.154, 20.0, 40.0), (1.547, -1.654, 7.082, 20.0, 40.0), (-1.370, 0.575, 1.165, 20.0, 40.0), (-1.869, -1.555, 5.176, 20.0, 40.0), (3.536, 2.841, -0.913, 20.0, 40.0), (-3.810, 1.220, -1.344, 20.0, 40.0), (-1.971, 1.462, -0.910, 20.0, 40.0), (-0.243, 0.167, 2.545, 20.0, 40.0), (-1.403, 2.645, -2.991, 20.0, 40.0), (0.532, -0.114, 3.494, 20.0, 40.0), (-1.678, 0.975, 0.212, 20.0, 40.0), (-0.656, 2.140, -1.609, 20.0, 40.0), (1.743, 2.631, -1.390, 20.0, 40.0), (2.586, 2.943, -1.593, 20.0, 40.0), (-0.512, 2.969, -3.195, 20.0, 40.0), (2.283, -0.100, 4.342, 20.0, 40.0), (-4.293, 0.872, -0.890, 20.0, 40.0), (3.411, 1.300, 2.106, 20.0, 40.0), (-0.281, 2.951, -3.042, 20.0, 40.0), (-4.442, 0.384, 0.012, 20.0, 40.0), (1.194, 1.746, 0.104, 20.0, 40.0), (-1.152, 1.862, -1.300, 20.0, 40.0), (1.362, -1.341, 6.363, 20.0, 40.0), (-4.488, 2.618, -4.481, 20.0, 40.0), (3.419, -0.564, 5.837, 20.0, 40.0), (-3.392, 0.396, 0.512, 20.0, 40.0), (-1.629, -0.909, 4.003, 20.0, 40.0), (4.447, -1.088, 7.399, 20.0, 40.0), (-1.232, 1.699, -1.014, 20.0, 40.0), (-1.286, -0.609, 3.575, 20.0, 40.0), (2.437, 2.796, -1.374, 20.0, 40.0), (-4.864, 1.989, -3.410, 20.0, 40.0), (-1.716, -1.399, 4.940, 20.0, 40.0), (-3.084, 1.858, -2.259, 20.0, 40.0), (2.828, -0.319, 5.053, 20.0, 40.0), (-1.226, 2.586, -2.786, 20.0, 40.0), (2.456, 0.092, 4.044, 20.0, 40.0), (-0.989, 2.375, -2.245, 20.0, 40.0), (3.268, 0.935, 2.765, 20.0, 40.0), (-4.128, -1.995, 4.927, 20.0, 40.0), (-1.083, 2.197, -1.935, 20.0, 40.0), (-3.471, -1.198, 3.660, 20.0, 40.0), (4.617, -1.136, 7.579, 20.0, 40.0), (2.054, -1.675, 7.378, 20.0, 40.0), (4.106, 2.326, 0.402, 20.0, 40.0), (1.558, 0.310, 3.158, 20.0, 40.0), (0.792, 0.900, 1.596, 20.0, 40.0), (-3.229, 0.300, 0.785, 20.0, 40.0), (3.787, -0.793, 6.479, 20.0, 40.0), (1.786, 2.288, -0.684, 20.0, 40.0), (2.643, 0.223, 3.875, 20.0, 40.0), (-3.592, 2.122, -3.040, 20.0, 40.0), (4.519, -1.760, 8.779, 20.0, 40.0), (3.221, 2.255, 0.101, 20.0, 40.0), (4.151, 1.788, 1.500, 20.0, 40.0), (-1.033, -1.195, 4.874, 20.0, 40.0), (-1.636, -1.037, 4.257, 20.0, 40.0), (-3.548, 1.911, -2.596, 20.0, 40.0), (4.829, -0.293, 6.001, 20.0, 40.0), (-4.684, -1.664, 3.986, 20.0, 40.0), (4.531, -0.503, 6.271, 20.0, 40.0), (-3.503, -1.606, 4.460, 20.0, 40.0), (-2.036, -1.522, 5.027, 20.0, 40.0), (-0.473, -0.617, 3.997, 20.0, 40.0), (-1.554, -1.630, 5.483, 20.0, 40.0), (-3.567, -1.043, 3.302, 20.0, 40.0), (-2.038, 0.579, 0.823, 20.0, 40.0), (-3.040, 0.857, -0.233, 20.0, 40.0), (4.610, 0.562, 4.181, 20.0, 40.0), (-3.323, -1.938, 5.215, 20.0, 40.0), (4.314, 1.720, 1.717, 20.0, 40.0), (-1.220, 0.615, 1.161, 20.0, 40.0), (-2.556, 1.120, -0.519, 20.0, 40.0), (-3.717, -0.108, 1.358, 20.0, 40.0), (4.689, -1.826, 8.996, 20.0, 40.0), (3.452, 0.506, 3.713, 20.0, 40.0), (2.472, 0.612, 3.012, 20.0, 40.0), (3.452, 0.450, 3.826, 20.0, 40.0), (1.207, 2.585, -1.567, 20.0, 40.0), (-4.826, 1.090, -1.593, 20.0, 40.0), (3.116, 
-1.118, 6.794, 20.0, 40.0), (0.448, 2.732, -2.240, 20.0, 40.0), (-1.096, -0.525, 3.503, 20.0, 40.0), (-4.680, -0.238, 1.137, 20.0, 40.0), (2.552, -1.403, 7.082, 20.0, 40.0), (0.719, 2.997, -2.635, 20.0, 40.0), (0.347, -1.966, 7.105, 20.0, 40.0), (2.958, -0.404, 5.288, 20.0, 40.0), (0.722, -1.950, 7.261, 20.0, 40.0), (-2.851, -0.986, 3.546, 20.0, 40.0), (-4.316, -0.439, 1.721, 20.0, 40.0), (-1.685, -0.201, 2.560, 20.0, 40.0), (1.856, 0.190, 3.549, 20.0, 40.0), (-2.052, 0.206, 1.562, 20.0, 40.0), (-2.504, -0.646, 3.041, 20.0, 40.0), (3.235, 0.882, 2.854, 20.0, 40.0), (-1.366, -1.573, 5.463, 20.0, 40.0), (-3.447, 2.419, -3.562, 20.0, 40.0), (4.155, 2.092, 0.893, 20.0, 40.0), (-0.935, 0.209, 2.116, 20.0, 40.0), (3.117, -1.821, 8.201, 20.0, 40.0), (3.759, 0.577, 3.725, 20.0, 40.0), (-0.938, 2.992, -3.453, 20.0, 40.0), (-0.525, 2.341, -1.945, 20.0, 40.0), (4.540, 2.625, 0.019, 20.0, 40.0), (-2.097, 1.190, -0.429, 20.0, 40.0), (-2.672, 1.983, -2.302, 20.0, 40.0), (-3.038, -1.490, 4.460, 20.0, 40.0), (-0.943, 2.149, -1.770, 20.0, 40.0), (0.739, 1.598, 0.174, 20.0, 40.0), (1.828, 1.853, 0.208, 20.0, 40.0), (4.856, 0.137, 5.153, 20.0, 40.0), (-1.617, 0.468, 1.255, 20.0, 40.0), (-1.972, 2.053, -2.092, 20.0, 40.0), (-4.633, 1.389, -2.094, 20.0, 40.0), (-3.628, -1.156, 3.498, 20.0, 40.0), (3.597, 1.034, 2.731, 20.0, 40.0), (-1.488, -0.002, 2.261, 20.0, 40.0), (0.749, 1.921, -0.468, 20.0, 40.0), (1.304, -1.371, 6.394, 20.0, 40.0), (4.587, 2.936, -0.579, 20.0, 40.0), (-2.241, 1.791, -1.703, 20.0, 40.0), (-2.945, 1.372, -1.216, 20.0, 40.0), (1.375, 0.395, 2.898, 20.0, 40.0), (-1.281, -0.641, 3.642, 20.0, 40.0), (2.178, 0.895, 2.299, 20.0, 40.0), (3.031, -0.786, 6.087, 20.0, 40.0), (-1.385, -0.375, 3.058, 20.0, 40.0), (4.041, -0.431, 5.882, 20.0, 40.0), (0.480, -0.507, 4.254, 20.0, 40.0), (-3.797, 0.140, 0.822, 20.0, 40.0), (2.355, 2.502, -0.827, 20.0, 40.0), (1.376, -1.583, 6.854, 20.0, 40.0), (0.164, 1.405, 0.273, 20.0, 40.0), (-1.273, 1.471, -0.579, 20.0, 40.0), (0.770, 2.246, -1.107, 20.0, 40.0), (4.552, 2.904, -0.533, 20.0, 40.0), (4.259, -1.772, 8.674, 20.0, 40.0), (-0.309, 1.159, 0.528, 20.0, 40.0), (3.581, 2.700, -0.610, 20.0, 40.0), (-3.202, 0.346, 0.707, 20.0, 40.0), (-1.575, 1.242, -0.271, 20.0, 40.0), (-1.584, -0.493, 3.194, 20.0, 40.0), (-3.778, 0.150, 0.810, 20.0, 40.0), (-4.675, 1.749, -2.835, 20.0, 40.0), (3.567, -0.792, 6.367, 20.0, 40.0), (-0.417, 1.399, -0.006, 20.0, 40.0), (-4.672, 2.007, -3.349, 20.0, 40.0), (-1.034, 0.196, 2.090, 20.0, 40.0), (-3.796, 2.496, -3.890, 20.0, 40.0), (3.532, -0.497, 5.759, 20.0, 40.0), (4.868, -1.359, 8.151, 20.0, 40.0), (-0.769, 0.302, 2.011, 20.0, 40.0), (4.475, 2.612, 0.014, 20.0, 40.0), (-3.532, -0.395, 2.024, 20.0, 40.0), (0.322, 0.675, 1.812, 20.0, 40.0), (-2.028, -1.942, 5.870, 20.0, 40.0), (1.810, -1.244, 6.392, 20.0, 40.0), (-0.783, 1.242, 0.124, 20.0, 40.0), (-4.745, -1.300, 3.227, 20.0, 40.0), (1.902, 1.973, 0.005, 20.0, 40.0), (-3.453, -1.429, 4.132, 20.0, 40.0), (1.559, 0.986, 1.808, 20.0, 40.0), (0.128, 2.754, -2.443, 20.0, 40.0), (2.759, 1.727, 0.926, 20.0, 40.0), (-4.468, 1.690, -2.614, 20.0, 40.0), (-2.368, -1.922, 5.659, 20.0, 40.0), (-2.766, 2.128, -2.640, 20.0, 40.0), (0.967, -1.825, 7.133, 20.0, 40.0), (-2.854, 2.855, -4.136, 20.0, 40.0), (-2.944, 1.875, -2.222, 20.0, 40.0), (-2.632, -0.983, 3.649, 20.0, 40.0), (2.427, 2.239, -0.266, 20.0, 40.0), (-1.726, -0.838, 3.812, 20.0, 40.0), (0.007, -0.903, 4.809, 20.0, 40.0), (-2.013, 1.092, -0.191, 20.0, 40.0), (-0.449, 0.970, 0.836, 20.0, 40.0), (1.396, 0.411, 2.876, 20.0, 40.0), (-1.115, 
-1.790, 6.023, 20.0, 40.0), (3.748, 1.917, 1.039, 20.0, 40.0), (2.978, 1.043, 2.404, 20.0, 40.0), (-3.969, 2.514, -4.013, 20.0, 40.0), (4.455, -0.050, 5.328, 20.0, 40.0), (-3.065, -0.846, 3.160, 20.0, 40.0), (-1.069, 2.167, -1.869, 20.0, 40.0), (3.016, -1.393, 7.294, 20.0, 40.0), (0.045, -1.928, 6.879, 20.0, 40.0), (-2.555, -0.984, 3.690, 20.0, 40.0), (-1.995, -0.054, 2.111, 20.0, 40.0), (4.600, -0.509, 6.318, 20.0, 40.0), (-1.942, 1.215, -0.402, 20.0, 40.0), (1.262, 2.765, -1.899, 20.0, 40.0), (2.617, -1.106, 6.521, 20.0, 40.0), (1.737, 0.554, 2.761, 20.0, 40.0), (-2.197, 0.632, 0.638, 20.0, 40.0), (4.768, 2.618, 0.147, 20.0, 40.0), (-3.737, -0.939, 3.010, 20.0, 40.0), (-2.623, 0.595, 0.499, 20.0, 40.0), (4.752, -0.340, 6.057, 20.0, 40.0), (2.333, -1.037, 6.240, 20.0, 40.0), (4.234, -1.882, 8.881, 20.0, 40.0), (-3.393, -0.812, 2.927, 20.0, 40.0), (0.885, 1.383, 0.678, 20.0, 40.0), (0.123, 2.937, -2.812, 20.0, 40.0), (2.969, 0.760, 2.964, 20.0, 40.0), (-4.929, 1.251, -1.967, 20.0, 40.0), (1.916, 2.223, -0.488, 20.0, 40.0), (-0.020, -1.740, 6.469, 20.0, 40.0), (0.702, -1.272, 5.895, 20.0, 40.0), (2.496, 2.648, -1.048, 20.0, 40.0), (4.067, -1.475, 7.984, 20.0, 40.0), (-3.717, 1.851, -2.561, 20.0, 40.0), (1.678, -0.624, 5.088, 20.0, 40.0), (1.073, 0.695, 2.146, 20.0, 40.0), (1.842, -0.749, 5.419, 20.0, 40.0), (-3.518, 1.909, -2.578, 20.0, 40.0), (2.229, 1.189, 1.737, 20.0, 40.0), (4.987, 2.893, -0.292, 20.0, 40.0), (-4.809, 1.043, -1.490, 20.0, 40.0), (-0.241, -0.728, 4.334, 20.0, 40.0), (-3.331, 0.590, 0.156, 20.0, 40.0), (-0.455, 2.621, -2.470, 20.0, 40.0), (1.492, 1.223, 1.301, 20.0, 40.0), (3.948, 2.841, -0.709, 20.0, 40.0), (0.732, 0.446, 2.475, 20.0, 40.0), (2.400, 2.390, -0.579, 20.0, 40.0), (-2.718, 1.427, -1.213, 20.0, 40.0), (-1.826, 1.451, -0.815, 20.0, 40.0), (1.125, 0.438, 2.686, 20.0, 40.0), (-4.918, 1.880, -3.219, 20.0, 40.0), (3.068, -0.442, 5.418, 20.0, 40.0), (1.982, 1.201, 1.589, 20.0, 40.0), (0.701, -1.709, 6.768, 20.0, 40.0), (-1.496, 2.564, -2.877, 20.0, 40.0), (-3.812, 0.974, -0.853, 20.0, 40.0), (-3.405, 2.018, -2.739, 20.0, 40.0), (2.211, 2.889, -1.674, 20.0, 40.0), (-2.481, 2.931, -4.103, 20.0, 40.0), (-3.721, 2.765, -4.391, 20.0, 40.0), (-1.768, -1.292, 4.699, 20.0, 40.0), (-4.462, 1.058, -1.347, 20.0, 40.0), (-3.516, -1.942, 5.126, 20.0, 40.0), (0.485, 2.420, -1.597, 20.0, 40.0), (-0.492, 0.242, 2.270, 20.0, 40.0), (4.245, 1.689, 1.744, 20.0, 40.0), (2.234, 0.364, 3.389, 20.0, 40.0), (2.629, 2.224, -0.134, 20.0, 40.0), (-4.375, 1.221, -1.630, 20.0, 40.0), (-0.618, 1.374, -0.057, 20.0, 40.0), (-2.580, -1.604, 4.918, 20.0, 40.0), (0.159, 1.104, 0.871, 20.0, 40.0), (-3.597, 0.975, -0.749, 20.0, 40.0); DROP TABLE IF EXISTS test.model; -create table test.model engine = Memory as select LinearRegressionState(0.1, 5, 4.0)(target, param1, param2) as state from test.defaults; +create table test.model engine = Memory as select LinearRegressionState(0.1, 5, 1.0)(target, param1, param2) as state from test.defaults; with (select state from test.model) as model select evalMLMethod(model, predict1, predict2) from test.defaults;