Merge pull request #4 from Quid37/makuzn

Added LinearModelData
This commit is contained in:
Quid37 2019-01-26 12:57:32 +03:00 committed by GitHub
commit dc51843065
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -29,6 +29,104 @@ namespace ErrorCodes
extern const int BAD_ARGUMENTS;
}
class IGradientComputer
{
public:
virtual ~IGradientComputer()
{}
void add(/* weights, */Float64 target, const IColumn ** columns, size_t row_num) final
{
++cur_batch;
std::vector<Float64> cur_grad = compute_gradient(/*...*/);
for (size_t i = 0; i != batch_gradient.size(); ++i) {
batch_gradient[i] += cur_grad[i];
}
}
std::vector<Float64> get() final
{
std::vector<Float64> result(batch_gradient.size());
for (size_t i = 0; i != batch_gradient.size(); ++i) {
result[i] = batch_gradient[i] / cur_batch;
batch_gradient[i] = 0.0;
}
cur_batch = 0;
return result;
}
protected:
virtual std::vector<Float64> compute_gradient(/* weights, */Float64 target, const IColumn ** columns, size_t row_num) = 0;
private:
UInt32 cur_batch = 0;
std::vector<Float64> batch_gradient; // gradient for bias lies in batch_gradient[batch_gradient.size() - 1]
};
class LinearRegression : public IGradientComputer
{
public:
virtual ~LinearRegression()
{}
protected:
virtual std::vector<Float64> compute_gradient(/* weights, */Float64 target, const IColumn ** columns, size_t row_num)
{
// TODO
}
};
/// Interface for a strategy that applies an update step to model weights.
/// The weights and gradient are not part of the signature yet (see the placeholder below).
class IWeightsUpdater
{
public:
    /// Performs one update step; concrete updaters must implement this.
    virtual void update(/* weights, gradient */) = 0;

    virtual ~IWeightsUpdater() = default;
};
class GradientDescent : public IWeightsUpdater
{
public:
virtual ~GradientDescent()
{}
virtual void update(/* weights, gradient */) = 0 {
// TODO
}
};
struct LinearModelData
{
LinearModelData()
{}
LinearModelData(Float64 learning_rate_, UInt32 param_num_, UInt32 batch_size_)
: learning_rate(learning_rate_), batch_size(batch_size_) {
weights.resize(param_num_, Float64{0.0});
batch_gradient.resize(param_num_ + 1, Float64{0.0});
cur_batch = 0;
}
std::vector<Float64> weights;
Float64 bias{0.0};
std::shared_ptr<IGradientComputer> gradient_computer;
std::shared_ptr<IWeightsUpdater> weights_updater;
void add(Float64 target, const IColumn ** columns, size_t row_num)
{
gradient_cumputer->add(target, columns, row_num);
if (cur_batch == batch_size)
{
cur_batch = 0;
weights_updater->update(/* weights */, gradient_computer->get());
}
}
void merge(const LinearModelData & rhs)
{
}
};
struct LinearRegressionData
{