mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-24 16:42:05 +00:00
some review fixes
This commit is contained in:
parent
a948f223bc
commit
fc4c721fa5
@ -78,14 +78,14 @@ public:
|
|||||||
Float64 derivative = (target - bias);
|
Float64 derivative = (target - bias);
|
||||||
for (size_t i = 0; i < weights.size(); ++i)
|
for (size_t i = 0; i < weights.size(); ++i)
|
||||||
{
|
{
|
||||||
derivative -= weights[i] * static_cast<const ColumnVector<Float64> &>(*columns[i + 1]).getData()[row_num];
|
derivative -= weights[i] * static_cast<const ColumnVector<Float64> &>(*columns[i]).getData()[row_num];
|
||||||
}
|
}
|
||||||
derivative *= (2 * learning_rate);
|
derivative *= (2 * learning_rate);
|
||||||
|
|
||||||
batch_gradient[weights.size()] += derivative;
|
batch_gradient[weights.size()] += derivative;
|
||||||
for (size_t i = 0; i < weights.size(); ++i)
|
for (size_t i = 0; i < weights.size(); ++i)
|
||||||
{
|
{
|
||||||
batch_gradient[i] += derivative * static_cast<const ColumnVector<Float64> &>(*columns[i + 1]).getData()[row_num];;
|
batch_gradient[i] += derivative * static_cast<const ColumnVector<Float64> &>(*columns[i]).getData()[row_num];;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Float64 predict(const std::vector<Float64> & predict_feature, const std::vector<Float64> & weights, Float64 bias) const override
|
Float64 predict(const std::vector<Float64> & predict_feature, const std::vector<Float64> & weights, Float64 bias) const override
|
||||||
@ -107,6 +107,7 @@ public:
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class LogisticRegression : public IGradientComputer
|
class LogisticRegression : public IGradientComputer
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
@ -120,7 +121,7 @@ public:
|
|||||||
Float64 derivative = bias;
|
Float64 derivative = bias;
|
||||||
for (size_t i = 0; i < weights.size(); ++i)
|
for (size_t i = 0; i < weights.size(); ++i)
|
||||||
{
|
{
|
||||||
derivative += weights[i] * static_cast<const ColumnVector<Float64> &>(*columns[i + 1]).getData()[row_num];;
|
derivative += weights[i] * static_cast<const ColumnVector<Float64> &>(*columns[i]).getData()[row_num];;
|
||||||
}
|
}
|
||||||
derivative *= target;
|
derivative *= target;
|
||||||
derivative = learning_rate * exp(derivative);
|
derivative = learning_rate * exp(derivative);
|
||||||
@ -128,7 +129,7 @@ public:
|
|||||||
batch_gradient[weights.size()] += target / (derivative + 1);;
|
batch_gradient[weights.size()] += target / (derivative + 1);;
|
||||||
for (size_t i = 0; i < weights.size(); ++i)
|
for (size_t i = 0; i < weights.size(); ++i)
|
||||||
{
|
{
|
||||||
batch_gradient[i] += target / (derivative + 1) * static_cast<const ColumnVector<Float64> &>(*columns[i + 1]).getData()[row_num];
|
batch_gradient[i] += target / (derivative + 1) * static_cast<const ColumnVector<Float64> &>(*columns[i]).getData()[row_num];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Float64 predict(const std::vector<Float64> & predict_feature, const std::vector<Float64> & weights, Float64 bias) const override
|
Float64 predict(const std::vector<Float64> & predict_feature, const std::vector<Float64> & weights, Float64 bias) const override
|
||||||
@ -156,27 +157,32 @@ class IWeightsUpdater
|
|||||||
public:
|
public:
|
||||||
virtual ~IWeightsUpdater() = default;
|
virtual ~IWeightsUpdater() = default;
|
||||||
|
|
||||||
virtual void update(UInt32 cur_batch, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & gradient) = 0;
|
virtual void update(UInt32 batch_size, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & gradient) = 0;
|
||||||
virtual void merge(const std::shared_ptr<IWeightsUpdater>, Float64, Float64) {}
|
virtual void merge(const IWeightsUpdater &, Float64, Float64) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
class StochasticGradientDescent : public IWeightsUpdater
|
class StochasticGradientDescent : public IWeightsUpdater
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
void update(UInt32 cur_batch, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & batch_gradient) override {
|
void update(UInt32 batch_size, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & batch_gradient) override {
|
||||||
|
/// batch_size is already checked to be greater than 0
|
||||||
|
|
||||||
for (size_t i = 0; i < weights.size(); ++i)
|
for (size_t i = 0; i < weights.size(); ++i)
|
||||||
{
|
{
|
||||||
weights[i] += batch_gradient[i] / cur_batch;
|
weights[i] += batch_gradient[i] / batch_size;
|
||||||
}
|
}
|
||||||
bias += batch_gradient[weights.size()] / cur_batch;
|
bias += batch_gradient[weights.size()] / batch_size;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class Momentum : public IWeightsUpdater
|
class Momentum : public IWeightsUpdater
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
Momentum() {}
|
Momentum() {}
|
||||||
Momentum (Float64 alpha) : alpha_(alpha) {}
|
Momentum (Float64 alpha) : alpha_(alpha) {}
|
||||||
void update(UInt32 cur_batch, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & batch_gradient) override {
|
void update(UInt32 batch_size, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & batch_gradient) override {
|
||||||
|
/// batch_size is already checked to be greater than 0
|
||||||
|
|
||||||
if (hk_.size() == 0)
|
if (hk_.size() == 0)
|
||||||
{
|
{
|
||||||
hk_.resize(batch_gradient.size(), Float64{0.0});
|
hk_.resize(batch_gradient.size(), Float64{0.0});
|
||||||
@ -187,21 +193,23 @@ public:
|
|||||||
}
|
}
|
||||||
for (size_t i = 0; i < weights.size(); ++i)
|
for (size_t i = 0; i < weights.size(); ++i)
|
||||||
{
|
{
|
||||||
weights[i] += hk_[i] / cur_batch;
|
weights[i] += hk_[i] / batch_size;
|
||||||
}
|
}
|
||||||
bias += hk_[weights.size()] / cur_batch;
|
bias += hk_[weights.size()] / batch_size;
|
||||||
}
|
}
|
||||||
virtual void merge(const std::shared_ptr<IWeightsUpdater> rhs, Float64 frac, Float64 rhs_frac) override {
|
virtual void merge(const IWeightsUpdater & rhs, Float64 frac, Float64 rhs_frac) override {
|
||||||
auto momentum_rhs = std::dynamic_pointer_cast<Momentum>(rhs);
|
const auto & momentum_rhs = dynamic_cast<const Momentum &>(rhs);
|
||||||
for (size_t i = 0; i < hk_.size(); ++i)
|
for (size_t i = 0; i < hk_.size(); ++i)
|
||||||
{
|
{
|
||||||
hk_[i] = hk_[i] * frac + momentum_rhs->hk_[i] * rhs_frac;
|
hk_[i] = hk_[i] * frac + momentum_rhs.hk_[i] * rhs_frac;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Float64 alpha_{0.1};
|
private:
|
||||||
std::vector<Float64> hk_;
|
Float64 alpha_{0.1};
|
||||||
|
std::vector<Float64> hk_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class LinearModelData
|
class LinearModelData
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
@ -210,25 +218,26 @@ public:
|
|||||||
|
|
||||||
LinearModelData(Float64 learning_rate,
|
LinearModelData(Float64 learning_rate,
|
||||||
UInt32 param_num,
|
UInt32 param_num,
|
||||||
UInt32 batch_size,
|
UInt32 batch_capacity,
|
||||||
std::shared_ptr<IGradientComputer> gc,
|
std::shared_ptr<IGradientComputer> gc,
|
||||||
std::shared_ptr<IWeightsUpdater> wu)
|
std::shared_ptr<IWeightsUpdater> wu)
|
||||||
: learning_rate(learning_rate),
|
: learning_rate(learning_rate),
|
||||||
batch_size(batch_size),
|
batch_capacity(batch_capacity),
|
||||||
gradient_computer(std::move(gc)),
|
gradient_computer(std::move(gc)),
|
||||||
weights_updater(std::move(wu))
|
weights_updater(std::move(wu))
|
||||||
{
|
{
|
||||||
weights.resize(param_num, Float64{0.0});
|
weights.resize(param_num, Float64{0.0});
|
||||||
cur_batch = 0;
|
batch_size = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void add(const IColumn ** columns, size_t row_num)
|
||||||
|
|
||||||
void add(Float64 target, const IColumn ** columns, size_t row_num)
|
|
||||||
{
|
{
|
||||||
gradient_computer->compute(weights, bias, learning_rate, target, columns, row_num);
|
/// first column stores target; features start from (columns + 1)
|
||||||
++cur_batch;
|
const auto & target = static_cast<const ColumnVector<Float64> &>(*columns[0]).getData()[row_num];
|
||||||
if (cur_batch == batch_size)
|
|
||||||
|
gradient_computer->compute(weights, bias, learning_rate, target, columns + 1, row_num);
|
||||||
|
++batch_size;
|
||||||
|
if (batch_size == batch_capacity)
|
||||||
{
|
{
|
||||||
update_state();
|
update_state();
|
||||||
}
|
}
|
||||||
@ -253,7 +262,7 @@ public:
|
|||||||
|
|
||||||
bias = bias * frac + rhs.bias * rhs_frac;
|
bias = bias * frac + rhs.bias * rhs_frac;
|
||||||
iter_num += rhs.iter_num;
|
iter_num += rhs.iter_num;
|
||||||
weights_updater->merge(rhs.weights_updater, frac, rhs_frac);
|
weights_updater->merge(*rhs.weights_updater, frac, rhs_frac);
|
||||||
}
|
}
|
||||||
|
|
||||||
void write(WriteBuffer & buf) const
|
void write(WriteBuffer & buf) const
|
||||||
@ -261,7 +270,7 @@ public:
|
|||||||
writeBinary(bias, buf);
|
writeBinary(bias, buf);
|
||||||
writeBinary(weights, buf);
|
writeBinary(weights, buf);
|
||||||
writeBinary(iter_num, buf);
|
writeBinary(iter_num, buf);
|
||||||
writeBinary(cur_batch, buf);
|
writeBinary(batch_size, buf);
|
||||||
gradient_computer->write(buf);
|
gradient_computer->write(buf);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -270,7 +279,7 @@ public:
|
|||||||
readBinary(bias, buf);
|
readBinary(bias, buf);
|
||||||
readBinary(weights, buf);
|
readBinary(weights, buf);
|
||||||
readBinary(iter_num, buf);
|
readBinary(iter_num, buf);
|
||||||
readBinary(cur_batch, buf);
|
readBinary(batch_size, buf);
|
||||||
gradient_computer->read(buf);
|
gradient_computer->read(buf);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -289,20 +298,20 @@ public:
|
|||||||
private:
|
private:
|
||||||
std::vector<Float64> weights;
|
std::vector<Float64> weights;
|
||||||
Float64 learning_rate;
|
Float64 learning_rate;
|
||||||
UInt32 batch_size;
|
UInt32 batch_capacity;
|
||||||
Float64 bias{0.0};
|
Float64 bias{0.0};
|
||||||
UInt32 iter_num = 0;
|
UInt32 iter_num = 0;
|
||||||
UInt32 cur_batch;
|
UInt32 batch_size;
|
||||||
std::shared_ptr<IGradientComputer> gradient_computer;
|
std::shared_ptr<IGradientComputer> gradient_computer;
|
||||||
std::shared_ptr<IWeightsUpdater> weights_updater;
|
std::shared_ptr<IWeightsUpdater> weights_updater;
|
||||||
|
|
||||||
void update_state()
|
void update_state()
|
||||||
{
|
{
|
||||||
if (cur_batch == 0)
|
if (batch_size == 0)
|
||||||
return;
|
return;
|
||||||
|
|
||||||
weights_updater->update(cur_batch, weights, bias, gradient_computer->get());
|
weights_updater->update(batch_size, weights, bias, gradient_computer->get());
|
||||||
cur_batch = 0;
|
batch_size = 0;
|
||||||
++iter_num;
|
++iter_num;
|
||||||
gradient_computer->reset();
|
gradient_computer->reset();
|
||||||
}
|
}
|
||||||
@ -343,9 +352,7 @@ public:
|
|||||||
|
|
||||||
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
|
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||||
{
|
{
|
||||||
const auto & target = static_cast<const ColumnVector<Float64> &>(*columns[0]);
|
this->data(place).add(columns, row_num);
|
||||||
|
|
||||||
this->data(place).add(target.getData()[row_num], columns, row_num);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// хочется не константный rhs
|
/// хочется не константный rhs
|
||||||
|
@ -128,14 +128,16 @@ MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const Col
|
|||||||
if (ML_function_Linear)
|
if (ML_function_Linear)
|
||||||
{
|
{
|
||||||
size_t row_num = 0;
|
size_t row_num = 0;
|
||||||
for (auto val : data) {
|
for (auto val : data)
|
||||||
|
{
|
||||||
ML_function_Linear->predictResultInto(val, *res, block, row_num, arguments);
|
ML_function_Linear->predictResultInto(val, *res, block, row_num, arguments);
|
||||||
++row_num;
|
++row_num;
|
||||||
}
|
}
|
||||||
} else if (ML_function_Logistic)
|
} else if (ML_function_Logistic)
|
||||||
{
|
{
|
||||||
size_t row_num = 0;
|
size_t row_num = 0;
|
||||||
for (auto val : data) {
|
for (auto val : data)
|
||||||
|
{
|
||||||
ML_function_Logistic->predictResultInto(val, *res, block, row_num, arguments);
|
ML_function_Logistic->predictResultInto(val, *res, block, row_num, arguments);
|
||||||
++row_num;
|
++row_num;
|
||||||
}
|
}
|
||||||
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user