Adam is now the default weights updater

Alexander Kozhikhov 2019-07-17 00:11:10 +03:00
parent 7c54bb0956
commit 52007c96d9
5 changed files with 210 additions and 217 deletions
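In short: the learning rate is no longer folded into the gradients by the gradient computers; it is passed to the weights updaters and applied inside their update() methods (which in particular keeps Adam's squared-gradient estimate clean). "Adam" replaces "SGD" as the default weights_updater_name, the default learning_rate and l2_reg_coef are raised to 1.0 and 0.5, and the test reference values shift in the last decimal digit.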

View File

@ -45,11 +45,11 @@ namespace
/// Such default parameters were picked because they performed well in some tests,
/// though fitting the parameters is still required to achieve better results
auto learning_rate = Float64(0.01);
auto l2_reg_coef = Float64(0.1);
auto learning_rate = Float64(1.0);
auto l2_reg_coef = Float64(0.5);
UInt64 batch_size = 15;
std::string weights_updater_name = "SGD";
std::string weights_updater_name = "Adam";
std::unique_ptr<IGradientComputer> gradient_computer;
if (!parameters.empty())
@ -126,7 +126,7 @@ void LinearModelData::update_state()
if (batch_size == 0)
return;
weights_updater->update(batch_size, weights, bias, gradient_batch);
weights_updater->update(batch_size, weights, bias, learning_rate, gradient_batch);
batch_size = 0;
++iter_num;
gradient_batch.assign(gradient_batch.size(), Float64{0.0});
@ -211,7 +211,7 @@ void LinearModelData::add(const IColumn ** columns, size_t row_num)
/// Here we have columns + 1 because the first column corresponds to the target value, and the others to the features
weights_updater->add_to_batch(
gradient_batch, *gradient_computer, weights, bias, learning_rate, l2_reg_coef, target, columns + 1, row_num);
gradient_batch, *gradient_computer, weights, bias, l2_reg_coef, target, columns + 1, row_num);
++batch_size;
if (batch_size == batch_capacity)
@ -256,7 +256,7 @@ void Adam::merge(const IWeightsUpdater & rhs, Float64 frac, Float64 rhs_frac)
beta2_powered_ *= adam_rhs.beta2_powered_;
}
void Adam::update(UInt64 batch_size, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & batch_gradient)
void Adam::update(UInt64 batch_size, std::vector<Float64> & weights, Float64 & bias, Float64 learning_rate, const std::vector<Float64> & batch_gradient)
{
if (average_gradient.empty())
{
@ -267,7 +267,6 @@ void Adam::update(UInt64 batch_size, std::vector<Float64> & weights, Float64 & b
average_squared_gradient.resize(batch_gradient.size(), Float64{0.0});
}
/// batch_gradient already includes learning_rate - bad for squared gradient
for (size_t i = 0; i != average_gradient.size(); ++i)
{
Float64 normed_gradient = batch_gradient[i] / batch_size;
@ -278,10 +277,10 @@ void Adam::update(UInt64 batch_size, std::vector<Float64> & weights, Float64 & b
for (size_t i = 0; i < weights.size(); ++i)
{
weights[i] += average_gradient[i] /
weights[i] += (learning_rate * average_gradient[i]) /
((1 - beta1_powered_) * (sqrt(average_squared_gradient[i] / (1 - beta2_powered_)) + eps_));
}
bias += average_gradient[weights.size()] /
bias += (learning_rate * average_gradient[weights.size()]) /
((1 - beta1_powered_) * (sqrt(average_squared_gradient[weights.size()] / (1 - beta2_powered_)) + eps_));
beta1_powered_ *= beta1_;
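For orientation, here is a minimal standalone sketch of the step that the reworked Adam::update now performs. The moment-accumulation recurrences are the standard Adam ones (assumed here, since those lines are collapsed in the hunk), and the learning rate is applied only at this point, together with bias correction, so the squared-gradient estimate stays free of it. AdamState, adam_step, and the fixed hyperparameters are illustrative, not part of the patch; the bias is simply treated as the last parameter.

#include <cmath>
#include <cstddef>
#include <vector>

struct AdamState
{
    double beta1 = 0.9;
    double beta2 = 0.999;
    double eps = 1e-8;
    double beta1_powered = 0.9;   /// beta1^t, multiplied by beta1 after every step
    double beta2_powered = 0.999; /// beta2^t
    std::vector<double> m;        /// first moment (average gradient)
    std::vector<double> v;        /// second moment (average squared gradient)
};

void adam_step(AdamState & s, std::vector<double> & params, double learning_rate,
               const std::vector<double> & batch_gradient, size_t batch_size)
{
    if (s.m.empty())
    {
        s.m.assign(batch_gradient.size(), 0.0);
        s.v.assign(batch_gradient.size(), 0.0);
    }
    for (size_t i = 0; i < batch_gradient.size(); ++i)
    {
        double g = batch_gradient[i] / batch_size; /// average over the mini-batch
        s.m[i] = s.beta1 * s.m[i] + (1 - s.beta1) * g;
        s.v[i] = s.beta2 * s.v[i] + (1 - s.beta2) * g * g;
    }
    for (size_t i = 0; i < params.size(); ++i)
    {
        /// learning_rate is applied here, not inside the gradient computer,
        /// so the squared-gradient estimate is not polluted by it
        params[i] += (learning_rate * s.m[i])
            / ((1 - s.beta1_powered) * (std::sqrt(s.v[i] / (1 - s.beta2_powered)) + s.eps));
    }
    s.beta1_powered *= s.beta1;
    s.beta2_powered *= s.beta2;
}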
@ -293,7 +292,6 @@ void Adam::add_to_batch(
IGradientComputer & gradient_computer,
const std::vector<Float64> & weights,
Float64 bias,
Float64 learning_rate,
Float64 l2_reg_coef,
Float64 target,
const IColumn ** columns,
@ -304,7 +302,7 @@ void Adam::add_to_batch(
average_gradient.resize(batch_gradient.size(), Float64{0.0});
average_squared_gradient.resize(batch_gradient.size(), Float64{0.0});
}
gradient_computer.compute(batch_gradient, weights, bias, learning_rate, l2_reg_coef, target, columns, row_num);
gradient_computer.compute(batch_gradient, weights, bias, l2_reg_coef, target, columns, row_num);
}
void Nesterov::read(ReadBuffer & buf)
@ -329,7 +327,7 @@ void Nesterov::merge(const IWeightsUpdater & rhs, Float64 frac, Float64 rhs_frac
}
}
void Nesterov::update(UInt64 batch_size, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & batch_gradient)
void Nesterov::update(UInt64 batch_size, std::vector<Float64> & weights, Float64 & bias, Float64 learning_rate, const std::vector<Float64> & batch_gradient)
{
if (accumulated_gradient.empty())
{
@ -338,7 +336,7 @@ void Nesterov::update(UInt64 batch_size, std::vector<Float64> & weights, Float64
for (size_t i = 0; i < batch_gradient.size(); ++i)
{
accumulated_gradient[i] = accumulated_gradient[i] * alpha_ + batch_gradient[i] / batch_size;
accumulated_gradient[i] = accumulated_gradient[i] * alpha_ + (learning_rate * batch_gradient[i]) / batch_size;
}
for (size_t i = 0; i < weights.size(); ++i)
{
@ -352,7 +350,6 @@ void Nesterov::add_to_batch(
IGradientComputer & gradient_computer,
const std::vector<Float64> & weights,
Float64 bias,
Float64 learning_rate,
Float64 l2_reg_coef,
Float64 target,
const IColumn ** columns,
@ -370,7 +367,7 @@ void Nesterov::add_to_batch(
}
auto shifted_bias = bias + accumulated_gradient[weights.size()] * alpha_;
gradient_computer.compute(batch_gradient, shifted_weights, shifted_bias, learning_rate, l2_reg_coef, target, columns, row_num);
gradient_computer.compute(batch_gradient, shifted_weights, shifted_bias, l2_reg_coef, target, columns, row_num);
}
void Momentum::read(ReadBuffer & buf)
@ -392,7 +389,7 @@ void Momentum::merge(const IWeightsUpdater & rhs, Float64 frac, Float64 rhs_frac
}
}
void Momentum::update(UInt64 batch_size, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & batch_gradient)
void Momentum::update(UInt64 batch_size, std::vector<Float64> & weights, Float64 & bias, Float64 learning_rate, const std::vector<Float64> & batch_gradient)
{
/// batch_size is already checked to be greater than 0
if (accumulated_gradient.empty())
@ -402,7 +399,7 @@ void Momentum::update(UInt64 batch_size, std::vector<Float64> & weights, Float64
for (size_t i = 0; i < batch_gradient.size(); ++i)
{
accumulated_gradient[i] = accumulated_gradient[i] * alpha_ + batch_gradient[i] / batch_size;
accumulated_gradient[i] = accumulated_gradient[i] * alpha_ + (learning_rate * batch_gradient[i]) / batch_size;
}
for (size_t i = 0; i < weights.size(); ++i)
{
@ -412,14 +409,14 @@ void Momentum::update(UInt64 batch_size, std::vector<Float64> & weights, Float64
}
void StochasticGradientDescent::update(
UInt64 batch_size, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & batch_gradient)
UInt64 batch_size, std::vector<Float64> & weights, Float64 & bias, Float64 learning_rate, const std::vector<Float64> & batch_gradient)
{
/// batch_size is already checked to be greater than 0
for (size_t i = 0; i < weights.size(); ++i)
{
weights[i] += batch_gradient[i] / batch_size;
weights[i] += (learning_rate * batch_gradient[i]) / batch_size;
}
bias += batch_gradient[weights.size()] / batch_size;
bias += (learning_rate * batch_gradient[weights.size()]) / batch_size;
}
void IWeightsUpdater::add_to_batch(
@ -427,13 +424,12 @@ void IWeightsUpdater::add_to_batch(
IGradientComputer & gradient_computer,
const std::vector<Float64> & weights,
Float64 bias,
Float64 learning_rate,
Float64 l2_reg_coef,
Float64 target,
const IColumn ** columns,
size_t row_num)
{
gradient_computer.compute(batch_gradient, weights, bias, learning_rate, l2_reg_coef, target, columns, row_num);
gradient_computer.compute(batch_gradient, weights, bias, l2_reg_coef, target, columns, row_num);
}
/// Gradient computers
@ -479,7 +475,6 @@ void LogisticRegression::compute(
std::vector<Float64> & batch_gradient,
const std::vector<Float64> & weights,
Float64 bias,
Float64 learning_rate,
Float64 l2_reg_coef,
Float64 target,
const IColumn ** columns,
@ -494,11 +489,11 @@ void LogisticRegression::compute(
derivative *= target;
derivative = exp(derivative);
batch_gradient[weights.size()] += learning_rate * target / (derivative + 1);
batch_gradient[weights.size()] += target / (derivative + 1);
for (size_t i = 0; i < weights.size(); ++i)
{
auto value = (*columns[i]).getFloat64(row_num);
batch_gradient[i] += learning_rate * target * value / (derivative + 1) - 2 * learning_rate * l2_reg_coef * weights[i];
batch_gradient[i] += target * value / (derivative + 1) - 2 * l2_reg_coef * weights[i];
}
}
@ -551,7 +546,6 @@ void LinearRegression::compute(
std::vector<Float64> & batch_gradient,
const std::vector<Float64> & weights,
Float64 bias,
Float64 learning_rate,
Float64 l2_reg_coef,
Float64 target,
const IColumn ** columns,
@ -563,13 +557,13 @@ void LinearRegression::compute(
auto value = (*columns[i]).getFloat64(row_num);
derivative -= weights[i] * value;
}
derivative *= (2 * learning_rate);
derivative *= 2;
batch_gradient[weights.size()] += derivative;
for (size_t i = 0; i < weights.size(); ++i)
{
auto value = (*columns[i]).getFloat64(row_num);
batch_gradient[i] += derivative * value - 2 * learning_rate * l2_reg_coef * weights[i];
batch_gradient[i] += derivative * value - 2 * l2_reg_coef * weights[i];
}
}
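The gradient computers now emit raw, un-scaled gradients. As a small illustration, this sketch mirrors the new LinearRegression::compute body, but takes a plain feature vector instead of IColumn pointers (an assumption made only to keep the snippet self-contained); the function name and signature are hypothetical.

#include <cstddef>
#include <vector>

void add_linear_regression_gradient(
    std::vector<double> & batch_gradient, /// size = weights.size() + 1, last slot is the bias
    const std::vector<double> & weights,
    double bias,
    double l2_reg_coef,
    double target,
    const std::vector<double> & features)
{
    double derivative = target - bias;
    for (size_t i = 0; i < weights.size(); ++i)
        derivative -= weights[i] * features[i];
    derivative *= 2; /// twice the residual; the sign convention lets updaters apply it with +=

    batch_gradient[weights.size()] += derivative; /// bias term
    for (size_t i = 0; i < weights.size(); ++i)
        batch_gradient[i] += derivative * features[i] - 2 * l2_reg_coef * weights[i];
}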

View File

@ -33,7 +33,6 @@ public:
std::vector<Float64> & batch_gradient,
const std::vector<Float64> & weights,
Float64 bias,
Float64 learning_rate,
Float64 l2_reg_coef,
Float64 target,
const IColumn ** columns,
@ -60,7 +59,6 @@ public:
std::vector<Float64> & batch_gradient,
const std::vector<Float64> & weights,
Float64 bias,
Float64 learning_rate,
Float64 l2_reg_coef,
Float64 target,
const IColumn ** columns,
@ -87,7 +85,6 @@ public:
std::vector<Float64> & batch_gradient,
const std::vector<Float64> & weights,
Float64 bias,
Float64 learning_rate,
Float64 l2_reg_coef,
Float64 target,
const IColumn ** columns,
@ -120,14 +117,18 @@ public:
IGradientComputer & gradient_computer,
const std::vector<Float64> & weights,
Float64 bias,
Float64 learning_rate,
Float64 l2_reg_coef,
Float64 target,
const IColumn ** columns,
size_t row_num);
/// Updates current weights according to the gradient from the last mini-batch
virtual void update(UInt64 batch_size, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & gradient) = 0;
virtual void update(
UInt64 batch_size,
std::vector<Float64> & weights,
Float64 & bias,
Float64 learning_rate,
const std::vector<Float64> & gradient) = 0;
/// Used during the merge of two states
virtual void merge(const IWeightsUpdater &, Float64, Float64) {}
@ -143,7 +144,7 @@ public:
class StochasticGradientDescent : public IWeightsUpdater
{
public:
void update(UInt64 batch_size, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & batch_gradient) override;
void update(UInt64 batch_size, std::vector<Float64> & weights, Float64 & bias, Float64 learning_rate, const std::vector<Float64> & batch_gradient) override;
};
@ -154,7 +155,7 @@ public:
Momentum(Float64 alpha) : alpha_(alpha) {}
void update(UInt64 batch_size, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & batch_gradient) override;
void update(UInt64 batch_size, std::vector<Float64> & weights, Float64 & bias, Float64 learning_rate, const std::vector<Float64> & batch_gradient) override;
virtual void merge(const IWeightsUpdater & rhs, Float64 frac, Float64 rhs_frac) override;
@ -180,13 +181,12 @@ public:
IGradientComputer & gradient_computer,
const std::vector<Float64> & weights,
Float64 bias,
Float64 learning_rate,
Float64 l2_reg_coef,
Float64 target,
const IColumn ** columns,
size_t row_num) override;
void update(UInt64 batch_size, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & batch_gradient) override;
void update(UInt64 batch_size, std::vector<Float64> & weights, Float64 & bias, Float64 learning_rate, const std::vector<Float64> & batch_gradient) override;
virtual void merge(const IWeightsUpdater & rhs, Float64 frac, Float64 rhs_frac) override;
@ -195,7 +195,7 @@ public:
void read(ReadBuffer & buf) override;
private:
Float64 alpha_{0.1};
const Float64 alpha_ = 0.9;
std::vector<Float64> accumulated_gradient;
};
@ -214,13 +214,12 @@ public:
IGradientComputer & gradient_computer,
const std::vector<Float64> & weights,
Float64 bias,
Float64 learning_rate,
Float64 l2_reg_coef,
Float64 target,
const IColumn ** columns,
size_t row_num) override;
void update(UInt64 batch_size, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & batch_gradient) override;
void update(UInt64 batch_size, std::vector<Float64> & weights, Float64 & bias, Float64 learning_rate, const std::vector<Float64> & batch_gradient) override;
virtual void merge(const IWeightsUpdater & rhs, Float64 frac, Float64 rhs_frac) override;
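To summarize the reshaped interface in this header: learning_rate is now a parameter of update(), while add_to_batch() and IGradientComputer::compute() no longer receive it. A minimal sketch follows, with the gradient taken as a ready-made vector rather than computed from IColumn pointers; the *Sketch names are illustrative, not the real classes.

#include <cstddef>
#include <cstdint>
#include <vector>

class IWeightsUpdaterSketch
{
public:
    virtual ~IWeightsUpdaterSketch() = default;

    /// Updates current weights according to the gradient from the last mini-batch.
    /// learning_rate arrives here instead of being pre-multiplied into the gradient.
    virtual void update(
        uint64_t batch_size,
        std::vector<double> & weights,
        double & bias,
        double learning_rate,
        const std::vector<double> & gradient) = 0;
};

class StochasticGradientDescentSketch : public IWeightsUpdaterSketch
{
public:
    void update(
        uint64_t batch_size,
        std::vector<double> & weights,
        double & bias,
        double learning_rate,
        const std::vector<double> & gradient) override
    {
        /// Plain averaged step; the learning rate is applied here and only here.
        for (size_t i = 0; i < weights.size(); ++i)
            weights[i] += (learning_rate * gradient[i]) / batch_size;
        bias += (learning_rate * gradient[weights.size()]) / batch_size;
    }
};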

View File

@ -16,7 +16,7 @@ select ans < -61.374 and ans > -61.375 from
(with (select state from remote('127.0.0.1', currentDatabase(), model)) as model select evalMLMethod(model, predict1, predict2) as ans from remote('127.0.0.1', currentDatabase(), defaults));
SELECT 0 < ans[1] and ans[1] < 0.15 and 0.95 < ans[2] and ans[2] < 1.0 and 0 < ans[3] and ans[3] < 0.05 FROM
(SELECT stochasticLinearRegression(0.000001, 0.01, 100)(number, rand() % 100, number) AS ans FROM numbers(1000));
(SELECT stochasticLinearRegression(0.000001, 0.01, 100, 'SGD')(number, rand() % 100, number) AS ans FROM numbers(1000));
DROP TABLE model;
DROP TABLE defaults;
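The explicit 'SGD' argument pins the updater that used to be the implicit default; without it, this test would now run with Adam and its reference values would no longer match.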

View File

@ -11,43 +11,43 @@
0.6542885368159769
0.6542885368159769
0.6542885368159769
0.8444267125384497
0.8444267125384497
0.8444267125384497
0.8444267125384497
0.8444267125384497
0.8444267125384497
0.8444267125384497
0.8444267125384497
0.8444267125384497
0.8444267125384497
0.8444267125384497
0.8444267125384497
0.9683751248474649
0.9683751248474649
0.9683751248474649
0.9683751248474649
0.9683751248474649
0.9683751248474649
0.9683751248474649
0.9683751248474649
0.9683751248474649
0.9683751248474649
0.9683751248474649
0.9683751248474649
0.7836319925339996
0.7836319925339996
0.7836319925339996
0.7836319925339996
0.7836319925339996
0.7836319925339996
0.7836319925339996
0.7836319925339996
0.7836319925339996
0.7836319925339996
0.7836319925339996
0.7836319925339996
0.7836319925339996
0.8444267125384498
0.8444267125384498
0.8444267125384498
0.8444267125384498
0.8444267125384498
0.8444267125384498
0.8444267125384498
0.8444267125384498
0.8444267125384498
0.8444267125384498
0.8444267125384498
0.8444267125384498
0.968375124847465
0.968375124847465
0.968375124847465
0.968375124847465
0.968375124847465
0.968375124847465
0.968375124847465
0.968375124847465
0.968375124847465
0.968375124847465
0.968375124847465
0.968375124847465
0.7836319925339997
0.7836319925339997
0.7836319925339997
0.7836319925339997
0.7836319925339997
0.7836319925339997
0.7836319925339997
0.7836319925339997
0.7836319925339997
0.7836319925339997
0.7836319925339997
0.7836319925339997
0.7836319925339997
0.6375535053362572
0.6375535053362572
0.6375535053362572
@ -60,18 +60,18 @@
0.6375535053362572
0.6375535053362572
0.6375535053362572
0.5871709781307677
0.5871709781307677
0.5871709781307677
0.5871709781307677
0.5871709781307677
0.5871709781307677
0.5871709781307677
0.5871709781307677
0.5871709781307677
0.5871709781307677
0.5871709781307677
0.5871709781307677
0.5871709781307678
0.5871709781307678
0.5871709781307678
0.5871709781307678
0.5871709781307678
0.5871709781307678
0.5871709781307678
0.5871709781307678
0.5871709781307678
0.5871709781307678
0.5871709781307678
0.5871709781307678
0.5202091999924413
0.5202091999924413
0.5202091999924413
@ -97,18 +97,18 @@
0.5130525169352177
0.5130525169352177
0.5130525169352177
0.5581075117249047
0.5581075117249047
0.5581075117249047
0.5581075117249047
0.5581075117249047
0.5581075117249047
0.5581075117249047
0.5581075117249047
0.5581075117249047
0.5581075117249047
0.5581075117249047
0.5581075117249047
0.5581075117249046
0.5581075117249046
0.5581075117249046
0.5581075117249046
0.5581075117249046
0.5581075117249046
0.5581075117249046
0.5581075117249046
0.5581075117249046
0.5581075117249046
0.5581075117249046
0.5581075117249046
0.6798311701936688
0.6798311701936688
0.6798311701936688
@ -146,43 +146,43 @@
0.7773509726916165
0.7773509726916165
0.7773509726916165
0.8606987912604607
0.8606987912604607
0.8606987912604607
0.8606987912604607
0.8606987912604607
0.8606987912604607
0.8606987912604607
0.8606987912604607
0.8606987912604607
0.8606987912604607
0.8606987912604607
0.8606987912604607
0.8606987912604607
0.6352934050115681
0.6352934050115681
0.6352934050115681
0.6352934050115681
0.6352934050115681
0.6352934050115681
0.6352934050115681
0.6352934050115681
0.6352934050115681
0.6352934050115681
0.6352934050115681
0.6352934050115681
0.9771089703353684
0.9771089703353684
0.9771089703353684
0.9771089703353684
0.9771089703353684
0.9771089703353684
0.9771089703353684
0.9771089703353684
0.9771089703353684
0.9771089703353684
0.9771089703353684
0.9771089703353684
0.8606987912604608
0.8606987912604608
0.8606987912604608
0.8606987912604608
0.8606987912604608
0.8606987912604608
0.8606987912604608
0.8606987912604608
0.8606987912604608
0.8606987912604608
0.8606987912604608
0.8606987912604608
0.8606987912604608
0.635293405011568
0.635293405011568
0.635293405011568
0.635293405011568
0.635293405011568
0.635293405011568
0.635293405011568
0.635293405011568
0.635293405011568
0.635293405011568
0.635293405011568
0.635293405011568
0.9771089703353683
0.9771089703353683
0.9771089703353683
0.9771089703353683
0.9771089703353683
0.9771089703353683
0.9771089703353683
0.9771089703353683
0.9771089703353683
0.9771089703353683
0.9771089703353683
0.9771089703353683
0.9955717835823895
0.9955717835823895
0.9955717835823895
@ -196,18 +196,18 @@
0.9955717835823895
0.9955717835823895
0.9955717835823895
0.6124539775938347
0.6124539775938347
0.6124539775938347
0.6124539775938347
0.6124539775938347
0.6124539775938347
0.6124539775938347
0.6124539775938347
0.6124539775938347
0.6124539775938347
0.6124539775938347
0.6124539775938347
0.6124539775938348
0.6124539775938348
0.6124539775938348
0.6124539775938348
0.6124539775938348
0.6124539775938348
0.6124539775938348
0.6124539775938348
0.6124539775938348
0.6124539775938348
0.6124539775938348
0.6124539775938348
0.6564358792397615
0.6564358792397615
0.6564358792397615
@ -220,19 +220,19 @@
0.6564358792397615
0.6564358792397615
0.6564358792397615
0.552111558999158
0.552111558999158
0.552111558999158
0.552111558999158
0.552111558999158
0.552111558999158
0.552111558999158
0.552111558999158
0.552111558999158
0.552111558999158
0.552111558999158
0.552111558999158
0.552111558999158
0.5521115589991579
0.5521115589991579
0.5521115589991579
0.5521115589991579
0.5521115589991579
0.5521115589991579
0.5521115589991579
0.5521115589991579
0.5521115589991579
0.5521115589991579
0.5521115589991579
0.5521115589991579
0.5521115589991579
0.7792659923782862
0.7792659923782862
0.7792659923782862
@ -257,31 +257,31 @@
0.6656871036437929
0.6656871036437929
0.6656871036437929
0.7435137743371989
0.7435137743371989
0.7435137743371989
0.7435137743371989
0.7435137743371989
0.7435137743371989
0.7435137743371989
0.7435137743371989
0.7435137743371989
0.7435137743371989
0.7435137743371989
0.7435137743371989
0.7435137743371989
0.8688023472919777
0.8688023472919777
0.8688023472919777
0.8688023472919777
0.8688023472919777
0.8688023472919777
0.8688023472919777
0.8688023472919777
0.8688023472919777
0.8688023472919777
0.8688023472919777
0.8688023472919777
0.743513774337199
0.743513774337199
0.743513774337199
0.743513774337199
0.743513774337199
0.743513774337199
0.743513774337199
0.743513774337199
0.743513774337199
0.743513774337199
0.743513774337199
0.743513774337199
0.743513774337199
0.8688023472919778
0.8688023472919778
0.8688023472919778
0.8688023472919778
0.8688023472919778
0.8688023472919778
0.8688023472919778
0.8688023472919778
0.8688023472919778
0.8688023472919778
0.8688023472919778
0.8688023472919778
0.7225690042828818
0.7225690042828818
0.7225690042828818
@ -307,18 +307,18 @@
0.8866100282141612
0.8866100282141612
0.8866100282141612
0.8374461350184257
0.8374461350184257
0.8374461350184257
0.8374461350184257
0.8374461350184257
0.8374461350184257
0.8374461350184257
0.8374461350184257
0.8374461350184257
0.8374461350184257
0.8374461350184257
0.8374461350184257
0.8374461350184258
0.8374461350184258
0.8374461350184258
0.8374461350184258
0.8374461350184258
0.8374461350184258
0.8374461350184258
0.8374461350184258
0.8374461350184258
0.8374461350184258
0.8374461350184258
0.8374461350184258
0.8365104788783658
0.8365104788783658
0.8365104788783658
@ -344,18 +344,18 @@
0.928892180915439
0.928892180915439
0.928892180915439
0.7275019293899534
0.7275019293899534
0.7275019293899534
0.7275019293899534
0.7275019293899534
0.7275019293899534
0.7275019293899534
0.7275019293899534
0.7275019293899534
0.7275019293899534
0.7275019293899534
0.7275019293899534
0.7275019293899533
0.7275019293899533
0.7275019293899533
0.7275019293899533
0.7275019293899533
0.7275019293899533
0.7275019293899533
0.7275019293899533
0.7275019293899533
0.7275019293899533
0.7275019293899533
0.7275019293899533
0.9516437185963472
0.9516437185963472
0.9516437185963472

File diff suppressed because one or more lines are too long