mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-19 14:11:58 +00:00
commit
dc51843065
@ -29,6 +29,104 @@ namespace ErrorCodes
|
||||
extern const int BAD_ARGUMENTS;
|
||||
}
|
||||
|
||||
class IGradientComputer
|
||||
{
|
||||
public:
|
||||
virtual ~IGradientComputer()
|
||||
{}
|
||||
|
||||
void add(/* weights, */Float64 target, const IColumn ** columns, size_t row_num) final
|
||||
{
|
||||
++cur_batch;
|
||||
std::vector<Float64> cur_grad = compute_gradient(/*...*/);
|
||||
for (size_t i = 0; i != batch_gradient.size(); ++i) {
|
||||
batch_gradient[i] += cur_grad[i];
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<Float64> get() final
|
||||
{
|
||||
std::vector<Float64> result(batch_gradient.size());
|
||||
for (size_t i = 0; i != batch_gradient.size(); ++i) {
|
||||
result[i] = batch_gradient[i] / cur_batch;
|
||||
batch_gradient[i] = 0.0;
|
||||
}
|
||||
cur_batch = 0;
|
||||
return result;
|
||||
}
|
||||
|
||||
protected:
|
||||
virtual std::vector<Float64> compute_gradient(/* weights, */Float64 target, const IColumn ** columns, size_t row_num) = 0;
|
||||
|
||||
private:
|
||||
UInt32 cur_batch = 0;
|
||||
std::vector<Float64> batch_gradient; // gradient for bias lies in batch_gradient[batch_gradient.size() - 1]
|
||||
};
|
||||
|
||||
class LinearRegression : public IGradientComputer
|
||||
{
|
||||
public:
|
||||
virtual ~LinearRegression()
|
||||
{}
|
||||
|
||||
protected:
|
||||
virtual std::vector<Float64> compute_gradient(/* weights, */Float64 target, const IColumn ** columns, size_t row_num)
|
||||
{
|
||||
// TODO
|
||||
}
|
||||
};
|
||||
|
||||
/// Interface of a strategy that applies one optimisation step to the model.
/// The parameters of update() are still placeholders in the signature comment.
class IWeightsUpdater
{
public:
    virtual ~IWeightsUpdater() = default;

    /// Performs a single update step; concrete updaters define what that means.
    virtual void update(/* weights, gradient */) = 0;
};
|
||||
|
||||
class GradientDescent : public IWeightsUpdater
|
||||
{
|
||||
public:
|
||||
virtual ~GradientDescent()
|
||||
{}
|
||||
|
||||
virtual void update(/* weights, gradient */) = 0 {
|
||||
// TODO
|
||||
}
|
||||
};
|
||||
|
||||
struct LinearModelData
|
||||
{
|
||||
LinearModelData()
|
||||
{}
|
||||
|
||||
LinearModelData(Float64 learning_rate_, UInt32 param_num_, UInt32 batch_size_)
|
||||
: learning_rate(learning_rate_), batch_size(batch_size_) {
|
||||
weights.resize(param_num_, Float64{0.0});
|
||||
batch_gradient.resize(param_num_ + 1, Float64{0.0});
|
||||
cur_batch = 0;
|
||||
}
|
||||
|
||||
std::vector<Float64> weights;
|
||||
Float64 bias{0.0};
|
||||
std::shared_ptr<IGradientComputer> gradient_computer;
|
||||
std::shared_ptr<IWeightsUpdater> weights_updater;
|
||||
|
||||
void add(Float64 target, const IColumn ** columns, size_t row_num)
|
||||
{
|
||||
gradient_cumputer->add(target, columns, row_num);
|
||||
if (cur_batch == batch_size)
|
||||
{
|
||||
cur_batch = 0;
|
||||
weights_updater->update(/* weights */, gradient_computer->get());
|
||||
}
|
||||
}
|
||||
|
||||
void merge(const LinearModelData & rhs)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
struct LinearRegressionData
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user