some review fixes

alexander kozhikhov 2019-02-11 01:07:47 +03:00
parent a948f223bc
commit fc4c721fa5
3 changed files with 48 additions and 87 deletions

View File

@@ -78,14 +78,14 @@ public:
         Float64 derivative = (target - bias);
         for (size_t i = 0; i < weights.size(); ++i)
         {
-            derivative -= weights[i] * static_cast<const ColumnVector<Float64> &>(*columns[i + 1]).getData()[row_num];
+            derivative -= weights[i] * static_cast<const ColumnVector<Float64> &>(*columns[i]).getData()[row_num];
         }
         derivative *= (2 * learning_rate);
         batch_gradient[weights.size()] += derivative;
         for (size_t i = 0; i < weights.size(); ++i)
         {
-            batch_gradient[i] += derivative * static_cast<const ColumnVector<Float64> &>(*columns[i + 1]).getData()[row_num];;
+            batch_gradient[i] += derivative * static_cast<const ColumnVector<Float64> &>(*columns[i]).getData()[row_num];;
         }
     }
     Float64 predict(const std::vector<Float64> & predict_feature, const std::vector<Float64> & weights, Float64 bias) const override
@@ -107,6 +107,7 @@ public:
         return res;
     }
 };
+
 class LogisticRegression : public IGradientComputer
 {
 public:
@@ -120,7 +121,7 @@ public:
         Float64 derivative = bias;
         for (size_t i = 0; i < weights.size(); ++i)
         {
-            derivative += weights[i] * static_cast<const ColumnVector<Float64> &>(*columns[i + 1]).getData()[row_num];;
+            derivative += weights[i] * static_cast<const ColumnVector<Float64> &>(*columns[i]).getData()[row_num];;
         }
         derivative *= target;
         derivative = learning_rate * exp(derivative);
@@ -128,7 +129,7 @@ public:
         batch_gradient[weights.size()] += target / (derivative + 1);;
         for (size_t i = 0; i < weights.size(); ++i)
         {
-            batch_gradient[i] += target / (derivative + 1) * static_cast<const ColumnVector<Float64> &>(*columns[i + 1]).getData()[row_num];
+            batch_gradient[i] += target / (derivative + 1) * static_cast<const ColumnVector<Float64> &>(*columns[i]).getData()[row_num];
         }
     }
     Float64 predict(const std::vector<Float64> & predict_feature, const std::vector<Float64> & weights, Float64 bias) const override
@@ -156,27 +157,32 @@ class IWeightsUpdater
 public:
     virtual ~IWeightsUpdater() = default;
-    virtual void update(UInt32 cur_batch, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & gradient) = 0;
-    virtual void merge(const std::shared_ptr<IWeightsUpdater>, Float64, Float64) {}
+    virtual void update(UInt32 batch_size, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & gradient) = 0;
+    virtual void merge(const IWeightsUpdater &, Float64, Float64) {}
 };
 class StochasticGradientDescent : public IWeightsUpdater
 {
 public:
-    void update(UInt32 cur_batch, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & batch_gradient) override {
+    void update(UInt32 batch_size, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & batch_gradient) override {
+        /// batch_size is already checked to be greater than 0
         for (size_t i = 0; i < weights.size(); ++i)
         {
-            weights[i] += batch_gradient[i] / cur_batch;
+            weights[i] += batch_gradient[i] / batch_size;
         }
-        bias += batch_gradient[weights.size()] / cur_batch;
+        bias += batch_gradient[weights.size()] / batch_size;
     }
 };
 class Momentum : public IWeightsUpdater
 {
 public:
     Momentum() {}
     Momentum (Float64 alpha) : alpha_(alpha) {}
-    void update(UInt32 cur_batch, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & batch_gradient) override {
+    void update(UInt32 batch_size, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & batch_gradient) override {
+        /// batch_size is already checked to be greater than 0
         if (hk_.size() == 0)
         {
             hk_.resize(batch_gradient.size(), Float64{0.0});
@@ -187,21 +193,23 @@ public:
         }
         for (size_t i = 0; i < weights.size(); ++i)
         {
-            weights[i] += hk_[i] / cur_batch;
+            weights[i] += hk_[i] / batch_size;
         }
-        bias += hk_[weights.size()] / cur_batch;
+        bias += hk_[weights.size()] / batch_size;
     }
-    virtual void merge(const std::shared_ptr<IWeightsUpdater> rhs, Float64 frac, Float64 rhs_frac) override {
-        auto momentum_rhs = std::dynamic_pointer_cast<Momentum>(rhs);
+    virtual void merge(const IWeightsUpdater & rhs, Float64 frac, Float64 rhs_frac) override {
+        const auto & momentum_rhs = dynamic_cast<const Momentum &>(rhs);
         for (size_t i = 0; i < hk_.size(); ++i)
         {
-            hk_[i] = hk_[i] * frac + momentum_rhs->hk_[i] * rhs_frac;
+            hk_[i] = hk_[i] * frac + momentum_rhs.hk_[i] * rhs_frac;
         }
     }
+private:
     Float64 alpha_{0.1};
     std::vector<Float64> hk_;
 };
 class LinearModelData
 {
 public:
@@ -210,25 +218,26 @@ public:
     LinearModelData(Float64 learning_rate,
                     UInt32 param_num,
-                    UInt32 batch_size,
+                    UInt32 batch_capacity,
                     std::shared_ptr<IGradientComputer> gc,
                     std::shared_ptr<IWeightsUpdater> wu)
     : learning_rate(learning_rate),
-    batch_size(batch_size),
+    batch_capacity(batch_capacity),
     gradient_computer(std::move(gc)),
     weights_updater(std::move(wu))
     {
         weights.resize(param_num, Float64{0.0});
-        cur_batch = 0;
+        batch_size = 0;
     }
-    void add(Float64 target, const IColumn ** columns, size_t row_num)
+    void add(const IColumn ** columns, size_t row_num)
     {
-        gradient_computer->compute(weights, bias, learning_rate, target, columns, row_num);
-        ++cur_batch;
-        if (cur_batch == batch_size)
+        /// first column stores target; features start from (columns + 1)
+        const auto & target = static_cast<const ColumnVector<Float64> &>(*columns[0]).getData()[row_num];
+        gradient_computer->compute(weights, bias, learning_rate, target, columns + 1, row_num);
+        ++batch_size;
+        if (batch_size == batch_capacity)
         {
             update_state();
         }
@@ -253,7 +262,7 @@ public:
         bias = bias * frac + rhs.bias * rhs_frac;
         iter_num += rhs.iter_num;
-        weights_updater->merge(rhs.weights_updater, frac, rhs_frac);
+        weights_updater->merge(*rhs.weights_updater, frac, rhs_frac);
     }
     void write(WriteBuffer & buf) const
@@ -261,7 +270,7 @@ public:
         writeBinary(bias, buf);
         writeBinary(weights, buf);
         writeBinary(iter_num, buf);
-        writeBinary(cur_batch, buf);
+        writeBinary(batch_size, buf);
         gradient_computer->write(buf);
     }
@@ -270,7 +279,7 @@ public:
         readBinary(bias, buf);
         readBinary(weights, buf);
         readBinary(iter_num, buf);
-        readBinary(cur_batch, buf);
+        readBinary(batch_size, buf);
         gradient_computer->read(buf);
     }
@@ -289,20 +298,20 @@ public:
 private:
     std::vector<Float64> weights;
     Float64 learning_rate;
-    UInt32 batch_size;
+    UInt32 batch_capacity;
     Float64 bias{0.0};
     UInt32 iter_num = 0;
-    UInt32 cur_batch;
+    UInt32 batch_size;
     std::shared_ptr<IGradientComputer> gradient_computer;
     std::shared_ptr<IWeightsUpdater> weights_updater;
     void update_state()
     {
-        if (cur_batch == 0)
+        if (batch_size == 0)
             return;
-        weights_updater->update(cur_batch, weights, bias, gradient_computer->get());
-        cur_batch = 0;
+        weights_updater->update(batch_size, weights, bias, gradient_computer->get());
+        batch_size = 0;
         ++iter_num;
         gradient_computer->reset();
     }
@@ -343,9 +352,7 @@ public:
     void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
     {
-        const auto & target = static_cast<const ColumnVector<Float64> &>(*columns[0]);
-        this->data(place).add(target.getData()[row_num], columns, row_num);
+        this->data(place).add(columns, row_num);
     }
     /// would like a non-const rhs here
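Note on the first file's changes: the main review fix is that LinearModelData::add now reads the target out of columns[0] itself and hands the gradient computer columns + 1, so feature columns are indexed from 0 inside compute, and the batch counter (batch_size) is tracked separately from the configured batch_capacity. Below is a minimal, self-contained C++ sketch of that flow under simplifying assumptions: plain std::vector<double> columns stand in for ClickHouse's IColumn/ColumnVector<Float64>, only the linear-regression gradient is shown, and the weights update is a plain SGD step. The names mirror the diff, but this is an illustration, not the actual implementation.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

using Column = std::vector<double>;   // stand-in for ColumnVector<Float64>

struct LinearModelData
{
    double learning_rate;
    size_t batch_capacity;
    std::vector<double> weights;
    std::vector<double> batch_gradient;   // last element accumulates the bias gradient
    double bias = 0.0;
    size_t batch_size = 0;                // rows accumulated in the current batch

    LinearModelData(double lr, size_t param_num, size_t capacity)
        : learning_rate(lr), batch_capacity(capacity),
          weights(param_num, 0.0), batch_gradient(param_num + 1, 0.0) {}

    /// columns[0] is the target; features start from (columns + 1)
    void add(const Column * columns, size_t row_num)
    {
        const double target = columns[0][row_num];
        compute(target, columns + 1, row_num);   // features are 0-indexed inside compute
        if (++batch_size == batch_capacity)
            update_state();
    }

    /// squared-error gradient for one row, mirroring LinearRegression::compute
    void compute(double target, const Column * features, size_t row_num)
    {
        double derivative = target - bias;
        for (size_t i = 0; i < weights.size(); ++i)
            derivative -= weights[i] * features[i][row_num];
        derivative *= 2 * learning_rate;

        batch_gradient[weights.size()] += derivative;
        for (size_t i = 0; i < weights.size(); ++i)
            batch_gradient[i] += derivative * features[i][row_num];
    }

    /// plain SGD step, averaging the accumulated gradient over the batch
    void update_state()
    {
        if (batch_size == 0)
            return;
        for (size_t i = 0; i < weights.size(); ++i)
            weights[i] += batch_gradient[i] / batch_size;
        bias += batch_gradient[weights.size()] / batch_size;
        batch_size = 0;
        std::fill(batch_gradient.begin(), batch_gradient.end(), 0.0);
    }
};

int main()
{
    // columns[0] = target (y = 2 * x), columns[1] = the single feature x
    Column columns[2] = {{2.0, 4.0, 6.0, 8.0}, {1.0, 2.0, 3.0, 4.0}};
    LinearModelData model(/*lr=*/0.05, /*param_num=*/1, /*capacity=*/2);

    for (int epoch = 0; epoch < 200; ++epoch)
        for (size_t row = 0; row < columns[0].size(); ++row)
            model.add(columns, row);

    // prints values close to weight = 2, bias = 0
    std::cout << "weight ~ " << model.weights[0] << ", bias ~ " << model.bias << "\n";
}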

View File

@@ -128,14 +128,16 @@ MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const Col
     if (ML_function_Linear)
     {
         size_t row_num = 0;
-        for (auto val : data) {
+        for (auto val : data)
+        {
             ML_function_Linear->predictResultInto(val, *res, block, row_num, arguments);
             ++row_num;
         }
     } else if (ML_function_Logistic)
     {
         size_t row_num = 0;
-        for (auto val : data) {
+        for (auto val : data)
+        {
             ML_function_Logistic->predictResultInto(val, *res, block, row_num, arguments);
             ++row_num;
         }
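The second file's hunk only reflows the range-for braces, but the surrounding code is the prediction path: for each aggregate state stored in the column, predictValues asks the matching ML aggregate function (ML_function_Linear or ML_function_Logistic) to append one predicted value via predictResultInto. A rough standalone sketch of that per-state loop follows; State and predictLogistic are invented stand-ins for illustration only, with a logistic model assumed.

#include <cmath>
#include <iostream>
#include <vector>

struct State { std::vector<double> weights; double bias; };   // one trained state per aggregated group

// Hypothetical stand-in for predictResultInto: evaluate one state on one feature row.
double predictLogistic(const State & st, const std::vector<double> & features)
{
    double arg = st.bias;
    for (size_t i = 0; i < st.weights.size(); ++i)
        arg += st.weights[i] * features[i];
    return 1.0 / (1.0 + std::exp(-arg));   // probability-like output
}

int main()
{
    std::vector<State> data = {{{2.0}, -1.0}, {{0.5}, 0.5}};        // stored aggregate states
    std::vector<std::vector<double>> feature_rows = {{3.0}, {3.0}}; // the rows being predicted
    std::vector<double> res;

    size_t row_num = 0;
    for (const auto & val : data)   // same shape as the loops in the diff
    {
        res.push_back(predictLogistic(val, feature_rows[row_num]));
        ++row_num;
    }

    for (double v : res)
        std::cout << v << "\n";
}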

File diff suppressed because one or more lines are too long