mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-19 22:22:00 +00:00
commit
dc51843065
@ -29,6 +29,104 @@ namespace ErrorCodes
|
|||||||
extern const int BAD_ARGUMENTS;
|
extern const int BAD_ARGUMENTS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class IGradientComputer
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
virtual ~IGradientComputer()
|
||||||
|
{}
|
||||||
|
|
||||||
|
void add(/* weights, */Float64 target, const IColumn ** columns, size_t row_num) final
|
||||||
|
{
|
||||||
|
++cur_batch;
|
||||||
|
std::vector<Float64> cur_grad = compute_gradient(/*...*/);
|
||||||
|
for (size_t i = 0; i != batch_gradient.size(); ++i) {
|
||||||
|
batch_gradient[i] += cur_grad[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Float64> get() final
|
||||||
|
{
|
||||||
|
std::vector<Float64> result(batch_gradient.size());
|
||||||
|
for (size_t i = 0; i != batch_gradient.size(); ++i) {
|
||||||
|
result[i] = batch_gradient[i] / cur_batch;
|
||||||
|
batch_gradient[i] = 0.0;
|
||||||
|
}
|
||||||
|
cur_batch = 0;
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
virtual std::vector<Float64> compute_gradient(/* weights, */Float64 target, const IColumn ** columns, size_t row_num) = 0;
|
||||||
|
|
||||||
|
private:
|
||||||
|
UInt32 cur_batch = 0;
|
||||||
|
std::vector<Float64> batch_gradient; // gradient for bias lies in batch_gradient[batch_gradient.size() - 1]
|
||||||
|
};
|
||||||
|
|
||||||
|
class LinearRegression : public IGradientComputer
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
virtual ~LinearRegression()
|
||||||
|
{}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
virtual std::vector<Float64> compute_gradient(/* weights, */Float64 target, const IColumn ** columns, size_t row_num)
|
||||||
|
{
|
||||||
|
// TODO
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class IWeightsUpdater
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
virtual ~IWeightsUpdater()
|
||||||
|
{}
|
||||||
|
|
||||||
|
virtual void update(/* weights, gradient */) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
class GradientDescent : public IWeightsUpdater
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
virtual ~GradientDescent()
|
||||||
|
{}
|
||||||
|
|
||||||
|
virtual void update(/* weights, gradient */) = 0 {
|
||||||
|
// TODO
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct LinearModelData
|
||||||
|
{
|
||||||
|
LinearModelData()
|
||||||
|
{}
|
||||||
|
|
||||||
|
LinearModelData(Float64 learning_rate_, UInt32 param_num_, UInt32 batch_size_)
|
||||||
|
: learning_rate(learning_rate_), batch_size(batch_size_) {
|
||||||
|
weights.resize(param_num_, Float64{0.0});
|
||||||
|
batch_gradient.resize(param_num_ + 1, Float64{0.0});
|
||||||
|
cur_batch = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Float64> weights;
|
||||||
|
Float64 bias{0.0};
|
||||||
|
std::shared_ptr<IGradientComputer> gradient_computer;
|
||||||
|
std::shared_ptr<IWeightsUpdater> weights_updater;
|
||||||
|
|
||||||
|
void add(Float64 target, const IColumn ** columns, size_t row_num)
|
||||||
|
{
|
||||||
|
gradient_cumputer->add(target, columns, row_num);
|
||||||
|
if (cur_batch == batch_size)
|
||||||
|
{
|
||||||
|
cur_batch = 0;
|
||||||
|
weights_updater->update(/* weights */, gradient_computer->get());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void merge(const LinearModelData & rhs)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
struct LinearRegressionData
|
struct LinearRegressionData
|
||||||
{
|
{
|
||||||
|
Loading…
Reference in New Issue
Block a user