diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.h b/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.h
index 650665fc673..0c443a859f8 100644
--- a/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.h
+++ b/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.h
@@ -29,6 +29,115 @@ namespace ErrorCodes
     extern const int BAD_ARGUMENTS;
 }
 
+// Computes the gradient of the loss for a single row and accumulates it over the current batch.
+class IGradientComputer
+{
+public:
+    virtual ~IGradientComputer()
+    {}
+
+    void add(/* weights, */ Float64 target, const IColumn ** columns, size_t row_num)
+    {
+        ++cur_batch;
+        std::vector<Float64> cur_grad = compute_gradient(/* weights, */ target, columns, row_num);
+        if (batch_gradient.empty())
+            batch_gradient.resize(cur_grad.size(), 0.0);
+        for (size_t i = 0; i != batch_gradient.size(); ++i) {
+            batch_gradient[i] += cur_grad[i];
+        }
+    }
+
+    // Returns the mean gradient over the accumulated batch and resets the accumulator.
+    std::vector<Float64> get()
+    {
+        std::vector<Float64> result(batch_gradient.size());
+        for (size_t i = 0; i != batch_gradient.size(); ++i) {
+            result[i] = batch_gradient[i] / cur_batch;
+            batch_gradient[i] = 0.0;
+        }
+        cur_batch = 0;
+        return result;
+    }
+
+protected:
+    virtual std::vector<Float64> compute_gradient(/* weights, */ Float64 target, const IColumn ** columns, size_t row_num) = 0;
+
+private:
+    UInt32 cur_batch = 0;
+    std::vector<Float64> batch_gradient;  // gradient for bias lies in batch_gradient[batch_gradient.size() - 1]
+};
+
+class LinearRegression : public IGradientComputer
+{
+public:
+    virtual ~LinearRegression()
+    {}
+
+protected:
+    std::vector<Float64> compute_gradient(/* weights, */ Float64 target, const IColumn ** columns, size_t row_num) override
+    {
+        return {};  // TODO: gradient of squared error with respect to weights and bias
+    }
+};
+
+// Applies an averaged batch gradient to the model weights.
+class IWeightsUpdater
+{
+public:
+    virtual ~IWeightsUpdater()
+    {}
+
+    virtual void update(/* weights, */ const std::vector<Float64> & gradient) = 0;
+};
+
+class GradientDescent : public IWeightsUpdater
+{
+public:
+    virtual ~GradientDescent()
+    {}
+
+    void update(/* weights, */ const std::vector<Float64> & gradient) override
+    {
+        // TODO
+    }
+};
+
+struct LinearModelData
+{
+    LinearModelData()
+    {}
+
+    LinearModelData(Float64 learning_rate_, UInt32 param_num_, UInt32 batch_size_)
+        : learning_rate(learning_rate_), batch_size(batch_size_)
+    {
+        weights.resize(param_num_, Float64{0.0});
+        // The batch gradient (param_num_ + 1 values, bias in the last slot) is accumulated inside gradient_computer.
+    }
+
+    std::vector<Float64> weights;
+    Float64 bias{0.0};
+    Float64 learning_rate{0.0};
+    UInt32 batch_size{0};
+    UInt32 cur_batch{0};
+    std::shared_ptr<IGradientComputer> gradient_computer;
+    std::shared_ptr<IWeightsUpdater> weights_updater;
+
+    void add(Float64 target, const IColumn ** columns, size_t row_num)
+    {
+        gradient_computer->add(target, columns, row_num);
+        ++cur_batch;
+        if (cur_batch == batch_size)
+        {
+            cur_batch = 0;
+            weights_updater->update(/* weights, */ gradient_computer->get());
+        }
+    }
+
+    void merge(const LinearModelData & rhs)
+    {
+        // TODO
+    }
+};
 
 struct LinearRegressionData
 {