mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-21 15:12:02 +00:00
Nesterov and Adam + tests
This commit is contained in:
parent
5defd7a77c
commit
67b28c2240
@ -61,6 +61,12 @@ AggregateFunctionPtr createAggregateFunctionMLMethod(
|
||||
} else if (applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[2]) == Float64{2.0})
|
||||
{
|
||||
wu = std::make_shared<Momentum>();
|
||||
} else if (applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[2]) == Float64{3.0})
|
||||
{
|
||||
wu = std::make_shared<Nesterov>();
|
||||
} else if (applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[2]) == Float64{4.0})
|
||||
{
|
||||
wu = std::make_shared<Adam>();
|
||||
} else
|
||||
{
|
||||
throw Exception("Such weights updater is not implemented yet", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
@ -210,6 +210,9 @@ public:
|
||||
|
||||
virtual void update(UInt32 batch_size, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & gradient) = 0;
|
||||
virtual void merge(const IWeightsUpdater &, Float64, Float64) {}
|
||||
virtual std::vector<Float64> get_update(UInt32 sz, UInt32) {
|
||||
return std::vector<Float64>(sz, 0.0);
|
||||
}
|
||||
};
|
||||
|
||||
class StochasticGradientDescent : public IWeightsUpdater
|
||||
@ -260,6 +263,100 @@ private:
|
||||
Float64 alpha_{0.1};
|
||||
std::vector<Float64> accumulated_gradient;
|
||||
};
|
||||
class Nesterov : public IWeightsUpdater
|
||||
{
|
||||
public:
|
||||
Nesterov() {}
|
||||
Nesterov (Float64 alpha) : alpha_(alpha) {}
|
||||
void update(UInt32 batch_size, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & batch_gradient) override {
|
||||
if (accumulated_gradient.size() == 0)
|
||||
{
|
||||
accumulated_gradient.resize(batch_gradient.size(), Float64{0.0});
|
||||
}
|
||||
for (size_t i = 0; i < batch_gradient.size(); ++i)
|
||||
{
|
||||
accumulated_gradient[i] = accumulated_gradient[i] * alpha_ + batch_gradient[i];
|
||||
}
|
||||
for (size_t i = 0; i < weights.size(); ++i)
|
||||
{
|
||||
weights[i] += accumulated_gradient[i] / batch_size;
|
||||
}
|
||||
bias += accumulated_gradient[weights.size()] / batch_size;
|
||||
std::cout<<"BIAS " << bias<<'\n';
|
||||
}
|
||||
virtual void merge(const IWeightsUpdater & rhs, Float64 frac, Float64 rhs_frac) override {
|
||||
auto & nesterov_rhs = static_cast<const Nesterov &>(rhs);
|
||||
for (size_t i = 0; i < accumulated_gradient.size(); ++i)
|
||||
{
|
||||
accumulated_gradient[i] = accumulated_gradient[i] * frac + nesterov_rhs.accumulated_gradient[i] * rhs_frac;
|
||||
}
|
||||
}
|
||||
virtual std::vector<Float64> get_update(UInt32 sz, UInt32 batch_size) override {
|
||||
if (accumulated_gradient.size() == 0)
|
||||
{
|
||||
accumulated_gradient.resize(sz, Float64{0.0});
|
||||
return accumulated_gradient;
|
||||
}
|
||||
std::vector<Float64> delta(accumulated_gradient.size());
|
||||
// std::cout<<"\n\nHK\n\n";
|
||||
for (size_t i = 0; i < delta.size(); ++i)
|
||||
{
|
||||
delta[i] = accumulated_gradient[i] * alpha_ / batch_size;
|
||||
}
|
||||
return delta;
|
||||
}
|
||||
|
||||
Float64 alpha_{0.1};
|
||||
std::vector<Float64> accumulated_gradient;
|
||||
};
|
||||
class Adam : public IWeightsUpdater
|
||||
{
|
||||
public:
|
||||
Adam() {}
|
||||
Adam (Float64 betta1, Float64 betta2) : betta1_(betta1), betta2_(betta2), betta1t_(betta1), betta2t_(betta2) {}
|
||||
void update(UInt32 cur_batch, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & batch_gradient) override {
|
||||
if (mt_.size() == 0)
|
||||
{
|
||||
mt_.resize(batch_gradient.size(), Float64{0.0});
|
||||
vt_.resize(batch_gradient.size(), Float64{0.0});
|
||||
}
|
||||
Float64 eps = 0.01;
|
||||
for (size_t i = 0; i < batch_gradient.size(); ++i)
|
||||
{
|
||||
mt_[i] = mt_[i] * betta1_ + (1 - betta1_) * batch_gradient[i];
|
||||
vt_[i] = vt_[i] * betta2_ + (1 - betta2_) * batch_gradient[i] * batch_gradient[i];
|
||||
if (t < 8) {
|
||||
mt_[i] = mt_[i] / (1 - betta1t_);
|
||||
betta1t_ *= betta1_;
|
||||
}
|
||||
if (t < 850) {
|
||||
vt_[i] = vt_[i] / (1 - betta2t_);
|
||||
betta2t_ *= betta2_;
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < weights.size(); ++i)
|
||||
{
|
||||
weights[i] += (mt_[i] / (sqrt(vt_[i] + eps))) / cur_batch;
|
||||
}
|
||||
bias += (mt_[weights.size()] / (sqrt(vt_[weights.size()] + eps))) / cur_batch;
|
||||
t += 1;
|
||||
}
|
||||
virtual void merge(const IWeightsUpdater & rhs, Float64 frac, Float64 rhs_frac) override {
|
||||
auto & adam_rhs = static_cast<const Adam &>(rhs);
|
||||
for (size_t i = 0; i < mt_.size(); ++i)
|
||||
{
|
||||
mt_[i] = mt_[i] * frac + adam_rhs.mt_[i] * rhs_frac;
|
||||
vt_[i] = vt_[i] * frac + adam_rhs.vt_[i] * rhs_frac;
|
||||
}
|
||||
}
|
||||
Float64 betta1_{0.2};
|
||||
Float64 betta2_{0.3};
|
||||
Float64 betta1t_{0.3};
|
||||
Float64 betta2t_{0.3};
|
||||
UInt32 t = 0;
|
||||
std::vector<Float64> mt_;
|
||||
std::vector<Float64> vt_;
|
||||
};
|
||||
|
||||
class LinearModelData
|
||||
{
|
||||
@ -285,8 +382,16 @@ public:
|
||||
{
|
||||
/// first column stores target; features start from (columns + 1)
|
||||
const auto & target = static_cast<const ColumnVector<Float64> &>(*columns[0]).getData()[row_num];
|
||||
|
||||
gradient_computer->compute(weights, bias, learning_rate, target, columns + 1, row_num);
|
||||
|
||||
auto delta = weights_updater->get_update(weights.size() + 1, batch_capacity);
|
||||
Float64 delta_bias = bias + delta[weights.size()];
|
||||
delta.resize(weights.size());
|
||||
for (size_t i = 0; i < weights.size(); ++i) {
|
||||
delta[i] += weights[i];
|
||||
}
|
||||
|
||||
gradient_computer->compute(delta, delta_bias, learning_rate, target, columns + 1, row_num);
|
||||
// gradient_computer->compute(weights, bias, learning_rate, target, columns + 1, row_num);
|
||||
++batch_size;
|
||||
if (batch_size == batch_capacity)
|
||||
{
|
||||
@ -369,6 +474,7 @@ private:
|
||||
batch_size = 0;
|
||||
++iter_num;
|
||||
gradient_computer->reset();
|
||||
//TODO ask: для нестерова и адама не очень. Нужно добавить другую функцию
|
||||
}
|
||||
};
|
||||
|
||||
|
1
dbms/tests/queries/0_stateless/00954_ml_test.reference
Normal file
1
dbms/tests/queries/0_stateless/00954_ml_test.reference
Normal file
@ -0,0 +1 @@
|
||||
-66.98005053600168
|
17
dbms/tests/queries/0_stateless/00954_ml_test.sql
Normal file
17
dbms/tests/queries/0_stateless/00954_ml_test.sql
Normal file
File diff suppressed because one or more lines are too long
1
dbms/tests/queries/0_stateless/00955_ml_test.reference
Normal file
1
dbms/tests/queries/0_stateless/00955_ml_test.reference
Normal file
@ -0,0 +1 @@
|
||||
-70.73127165094067
|
17
dbms/tests/queries/0_stateless/00955_ml_test.sql
Normal file
17
dbms/tests/queries/0_stateless/00955_ml_test.sql
Normal file
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user