some review fixes

parent a948f223bc
commit fc4c721fa5
@@ -78,14 +78,14 @@ public:
         Float64 derivative = (target - bias);
         for (size_t i = 0; i < weights.size(); ++i)
         {
-            derivative -= weights[i] * static_cast<const ColumnVector<Float64> &>(*columns[i + 1]).getData()[row_num];
+            derivative -= weights[i] * static_cast<const ColumnVector<Float64> &>(*columns[i]).getData()[row_num];
         }
         derivative *= (2 * learning_rate);

         batch_gradient[weights.size()] += derivative;
         for (size_t i = 0; i < weights.size(); ++i)
         {
-            batch_gradient[i] += derivative * static_cast<const ColumnVector<Float64> &>(*columns[i + 1]).getData()[row_num];;
+            batch_gradient[i] += derivative * static_cast<const ColumnVector<Float64> &>(*columns[i]).getData()[row_num];;
         }
     }
     Float64 predict(const std::vector<Float64> & predict_feature, const std::vector<Float64> & weights, Float64 bias) const override
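The index shift from columns[i + 1] to columns[i] works because the caller now strips the target column before invoking compute (see the add() hunk further down). For orientation, a minimal standalone sketch of the squared-error gradient this hunk accumulates; names are hypothetical, the commit keeps this logic inside LinearRegression::compute:

    #include <cstddef>
    #include <vector>

    using Float64 = double;

    /// Sketch: residual = target - (w . x + bias), scaled by 2 * learning_rate,
    /// then accumulated per slot. The last slot of batch_gradient is the bias term.
    void computeLinearGradient(std::vector<Float64> & batch_gradient,
                               const std::vector<Float64> & weights, Float64 bias,
                               Float64 learning_rate, Float64 target,
                               const std::vector<Float64> & features)
    {
        Float64 derivative = target - bias;
        for (size_t i = 0; i < weights.size(); ++i)
            derivative -= weights[i] * features[i];
        derivative *= 2 * learning_rate;

        batch_gradient[weights.size()] += derivative;      /// bias slot
        for (size_t i = 0; i < weights.size(); ++i)
            batch_gradient[i] += derivative * features[i]; /// feature slots
    }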
@@ -107,6 +107,7 @@ public:
         return res;
     }
 };
+
 class LogisticRegression : public IGradientComputer
 {
 public:
@@ -120,7 +121,7 @@ public:
         Float64 derivative = bias;
         for (size_t i = 0; i < weights.size(); ++i)
         {
-            derivative += weights[i] * static_cast<const ColumnVector<Float64> &>(*columns[i + 1]).getData()[row_num];;
+            derivative += weights[i] * static_cast<const ColumnVector<Float64> &>(*columns[i]).getData()[row_num];;
         }
         derivative *= target;
         derivative = learning_rate * exp(derivative);
@@ -128,7 +129,7 @@ public:
         batch_gradient[weights.size()] += target / (derivative + 1);;
         for (size_t i = 0; i < weights.size(); ++i)
         {
-            batch_gradient[i] += target / (derivative + 1) * static_cast<const ColumnVector<Float64> &>(*columns[i + 1]).getData()[row_num];
+            batch_gradient[i] += target / (derivative + 1) * static_cast<const ColumnVector<Float64> &>(*columns[i]).getData()[row_num];
         }
     }
     Float64 predict(const std::vector<Float64> & predict_feature, const std::vector<Float64> & weights, Float64 bias) const override
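These two hunks accumulate the log-loss gradient for labels in {-1, +1}: with margin m = target * (w . x + bias), the per-slot step is target / (1 + e^m) times the feature value. A hedged textbook-form sketch with hypothetical names; note the commit folds learning_rate into the exp() factor instead of applying it as a plain multiplier as done here:

    #include <cmath>
    #include <cstddef>
    #include <vector>

    using Float64 = double;

    /// Textbook negative-gradient step of -log(sigmoid(margin)) for
    /// target in {-1, +1}; bias lives in the last batch_gradient slot.
    void computeLogisticGradient(std::vector<Float64> & batch_gradient,
                                 const std::vector<Float64> & weights, Float64 bias,
                                 Float64 learning_rate, Float64 target,
                                 const std::vector<Float64> & features)
    {
        Float64 margin = bias;
        for (size_t i = 0; i < weights.size(); ++i)
            margin += weights[i] * features[i];
        margin *= target;

        Float64 step = learning_rate * target / (1.0 + std::exp(margin));
        batch_gradient[weights.size()] += step;        /// bias slot
        for (size_t i = 0; i < weights.size(); ++i)
            batch_gradient[i] += step * features[i];
    }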
@@ -156,27 +157,32 @@ class IWeightsUpdater
 public:
     virtual ~IWeightsUpdater() = default;
-    virtual void update(UInt32 cur_batch, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & gradient) = 0;
-    virtual void merge(const std::shared_ptr<IWeightsUpdater>, Float64, Float64) {}
+
+    virtual void update(UInt32 batch_size, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & gradient) = 0;
+    virtual void merge(const IWeightsUpdater &, Float64, Float64) {}
 };

 class StochasticGradientDescent : public IWeightsUpdater
 {
 public:
-    void update(UInt32 cur_batch, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & batch_gradient) override {
+    void update(UInt32 batch_size, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & batch_gradient) override {
+        /// batch_size is already checked to be greater than 0
+
         for (size_t i = 0; i < weights.size(); ++i)
         {
-            weights[i] += batch_gradient[i] / cur_batch;
+            weights[i] += batch_gradient[i] / batch_size;
         }
-        bias += batch_gradient[weights.size()] / cur_batch;
+        bias += batch_gradient[weights.size()] / batch_size;
     }
 };

 class Momentum : public IWeightsUpdater
 {
 public:
     Momentum() {}
     Momentum (Float64 alpha) : alpha_(alpha) {}
-    void update(UInt32 cur_batch, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & batch_gradient) override {
+    void update(UInt32 batch_size, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & batch_gradient) override {
+        /// batch_size is already checked to be greater than 0
+
         if (hk_.size() == 0)
         {
             hk_.resize(batch_gradient.size(), Float64{0.0});
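The hunk stops just before the momentum recurrence itself, which sits in the unchanged lines between this hunk and the next. As a hedged reconstruction: a classical-momentum updater of this shape typically keeps hk = alpha * hk + batch_gradient on each call. The recurrence here is an assumption, since the diff does not show those lines:

    #include <cstddef>
    #include <vector>

    using Float64 = double;

    /// Hedged sketch of a classical-momentum updater with the same shape
    /// as the Momentum class above; the hk recurrence is assumed.
    struct MomentumSketch
    {
        Float64 alpha = 0.1;
        std::vector<Float64> hk;   /// velocity: one slot per weight, plus bias

        void update(unsigned batch_size, std::vector<Float64> & weights,
                    Float64 & bias, const std::vector<Float64> & batch_gradient)
        {
            if (hk.empty())
                hk.resize(batch_gradient.size(), 0.0);
            for (size_t i = 0; i < batch_gradient.size(); ++i)
                hk[i] = hk[i] * alpha + batch_gradient[i];  /// assumed recurrence
            for (size_t i = 0; i < weights.size(); ++i)
                weights[i] += hk[i] / batch_size;
            bias += hk[weights.size()] / batch_size;
        }
    };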
@@ -187,21 +193,23 @@ public:
         }
         for (size_t i = 0; i < weights.size(); ++i)
         {
-            weights[i] += hk_[i] / cur_batch;
+            weights[i] += hk_[i] / batch_size;
         }
-        bias += hk_[weights.size()] / cur_batch;
+        bias += hk_[weights.size()] / batch_size;
     }
-    virtual void merge(const std::shared_ptr<IWeightsUpdater> rhs, Float64 frac, Float64 rhs_frac) override {
-        auto momentum_rhs = std::dynamic_pointer_cast<Momentum>(rhs);
+    virtual void merge(const IWeightsUpdater & rhs, Float64 frac, Float64 rhs_frac) override {
+        const auto & momentum_rhs = dynamic_cast<const Momentum &>(rhs);
         for (size_t i = 0; i < hk_.size(); ++i)
         {
-            hk_[i] = hk_[i] * frac + momentum_rhs->hk_[i] * rhs_frac;
+            hk_[i] = hk_[i] * frac + momentum_rhs.hk_[i] * rhs_frac;
         }
     }
-    Float64 alpha_{0.1};
-    std::vector<Float64> hk_;
+
+private:
+    Float64 alpha_{0.1};
+    std::vector<Float64> hk_;
 };

 class LinearModelData
 {
 public:
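Passing rhs by const reference instead of shared_ptr removes an unneeded ownership transfer, and dynamic_cast to a reference turns a silently-null downcast into an exception on mismatch. The frac / rhs_frac blend is the usual convex combination for merging two partial aggregation states; in isolation, with illustrative names:

    #include <cstddef>
    #include <vector>

    using Float64 = double;

    /// Mirror of the loop in Momentum::merge above: each velocity entry
    /// becomes a blend weighted by each side's share of processed rows.
    void mergeVelocity(std::vector<Float64> & hk, Float64 frac,
                       const std::vector<Float64> & rhs_hk, Float64 rhs_frac)
    {
        for (size_t i = 0; i < hk.size(); ++i)
            hk[i] = hk[i] * frac + rhs_hk[i] * rhs_frac;
    }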
@@ -210,25 +218,26 @@ public:

     LinearModelData(Float64 learning_rate,
                     UInt32 param_num,
-                    UInt32 batch_size,
+                    UInt32 batch_capacity,
                     std::shared_ptr<IGradientComputer> gc,
                     std::shared_ptr<IWeightsUpdater> wu)
         : learning_rate(learning_rate),
-          batch_size(batch_size),
+          batch_capacity(batch_capacity),
           gradient_computer(std::move(gc)),
           weights_updater(std::move(wu))
     {
         weights.resize(param_num, Float64{0.0});
-        cur_batch = 0;
+        batch_size = 0;
     }

-
-
-    void add(Float64 target, const IColumn ** columns, size_t row_num)
+    void add(const IColumn ** columns, size_t row_num)
     {
-        gradient_computer->compute(weights, bias, learning_rate, target, columns, row_num);
-        ++cur_batch;
-        if (cur_batch == batch_size)
+        /// first column stores target; features start from (columns + 1)
+        const auto & target = static_cast<const ColumnVector<Float64> &>(*columns[0]).getData()[row_num];
+
+        gradient_computer->compute(weights, bias, learning_rate, target, columns + 1, row_num);
+        ++batch_size;
+        if (batch_size == batch_capacity)
         {
             update_state();
         }
@@ -253,7 +262,7 @@ public:

         bias = bias * frac + rhs.bias * rhs_frac;
         iter_num += rhs.iter_num;
-        weights_updater->merge(rhs.weights_updater, frac, rhs_frac);
+        weights_updater->merge(*rhs.weights_updater, frac, rhs_frac);
     }

     void write(WriteBuffer & buf) const
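Context for this merge hunk: merging two partially trained models takes a convex combination of their parameters before delegating to the updater's merge. A hedged standalone sketch; the frac computation from iteration counts is an assumption, since the diff shows only its use:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    using Float64 = double;

    /// Illustrative partial state for the merge shown above.
    struct ModelState
    {
        std::vector<Float64> weights;
        Float64 bias = 0.0;
        uint32_t iter_num = 0;
    };

    void mergeStates(ModelState & lhs, const ModelState & rhs)
    {
        if (lhs.iter_num + rhs.iter_num == 0)
            return;
        /// assumed: each side's share of the combined iteration count
        Float64 frac = Float64(lhs.iter_num) / (lhs.iter_num + rhs.iter_num);
        Float64 rhs_frac = 1.0 - frac;
        for (size_t i = 0; i < lhs.weights.size(); ++i)
            lhs.weights[i] = lhs.weights[i] * frac + rhs.weights[i] * rhs_frac;
        lhs.bias = lhs.bias * frac + rhs.bias * rhs_frac;
        lhs.iter_num += rhs.iter_num;
    }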
@@ -261,7 +270,7 @@ public:
         writeBinary(bias, buf);
         writeBinary(weights, buf);
         writeBinary(iter_num, buf);
-        writeBinary(cur_batch, buf);
+        writeBinary(batch_size, buf);
         gradient_computer->write(buf);
     }

@@ -270,7 +279,7 @@ public:
         readBinary(bias, buf);
         readBinary(weights, buf);
         readBinary(iter_num, buf);
-        readBinary(cur_batch, buf);
+        readBinary(batch_size, buf);
         gradient_computer->read(buf);
     }

@@ -289,20 +298,20 @@ public:
 private:
     std::vector<Float64> weights;
     Float64 learning_rate;
-    UInt32 batch_size;
+    UInt32 batch_capacity;
     Float64 bias{0.0};
     UInt32 iter_num = 0;
-    UInt32 cur_batch;
+    UInt32 batch_size;
     std::shared_ptr<IGradientComputer> gradient_computer;
     std::shared_ptr<IWeightsUpdater> weights_updater;

     void update_state()
     {
-        if (cur_batch == 0)
+        if (batch_size == 0)
             return;

-        weights_updater->update(cur_batch, weights, bias, gradient_computer->get());
-        cur_batch = 0;
+        weights_updater->update(batch_size, weights, bias, gradient_computer->get());
+        batch_size = 0;
         ++iter_num;
         gradient_computer->reset();
     }
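The renames make the mini-batch flow read straightforwardly: batch_size counts rows accumulated since the last flush, batch_capacity is the flush threshold, and update_state() applies the averaged step and resets. A minimal mirror of that control flow as a hypothetical standalone type:

    #include <cstdint>

    /// batch_size counts rows since the last flush; batch_capacity triggers it.
    /// Hypothetical shape; the commit keeps this inside LinearModelData.
    struct BatchCounter
    {
        uint32_t batch_size = 0;        /// rows accumulated so far
        const uint32_t batch_capacity;  /// flush threshold

        explicit BatchCounter(uint32_t capacity) : batch_capacity(capacity) {}

        /// Returns true when the caller should run update_state().
        bool addRow() { return ++batch_size == batch_capacity; }
        void reset() { batch_size = 0; }
    };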
@@ -343,9 +352,7 @@ public:

     void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
     {
-        const auto & target = static_cast<const ColumnVector<Float64> &>(*columns[0]);
-
-        this->data(place).add(target.getData()[row_num], columns, row_num);
+        this->data(place).add(columns, row_num);
     }

     /// would like a non-const rhs here
@@ -128,14 +128,16 @@ MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const Col
     if (ML_function_Linear)
     {
         size_t row_num = 0;
-        for (auto val : data) {
+        for (auto val : data)
+        {
             ML_function_Linear->predictResultInto(val, *res, block, row_num, arguments);
             ++row_num;
         }
     } else if (ML_function_Logistic)
     {
         size_t row_num = 0;
-        for (auto val : data) {
+        for (auto val : data)
+        {
             ML_function_Logistic->predictResultInto(val, *res, block, row_num, arguments);
             ++row_num;
         }
File diff suppressed because one or more lines are too long