2019-01-22 21:07:05 +00:00
|
|
|
#pragma once
|
|
|
|
|
2019-04-04 00:17:27 +00:00
|
|
|
#include <Columns/ColumnVector.h>
|
2019-01-22 21:07:05 +00:00
|
|
|
#include <Columns/ColumnsCommon.h>
|
|
|
|
#include <Columns/ColumnsNumber.h>
|
2019-06-03 05:25:48 +00:00
|
|
|
#include <Common/typeid_cast.h>
|
2019-05-14 19:52:29 +00:00
|
|
|
#include <DataTypes/DataTypesNumber.h>
|
2019-05-27 20:14:23 +00:00
|
|
|
#include <DataTypes/DataTypeTuple.h>
|
|
|
|
#include <DataTypes/DataTypeArray.h>
|
2019-05-14 19:52:29 +00:00
|
|
|
#include "IAggregateFunction.h"
|
2019-01-22 21:07:05 +00:00
|
|
|
|
2019-01-23 11:58:05 +00:00
|
|
|
namespace DB
|
|
|
|
{
|
2019-01-23 01:29:53 +00:00
|
|
|
namespace ErrorCodes
{
    /// Declared here, defined in a common errors translation unit.
    extern const int LOGICAL_ERROR;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
|
2019-01-22 21:07:05 +00:00
|
|
|
|
2019-03-03 08:46:36 +00:00
|
|
|
/**
|
2019-04-08 21:01:10 +00:00
|
|
|
GradientComputer class computes gradient according to its loss function
|
2019-03-03 08:46:36 +00:00
|
|
|
*/
|
2019-01-24 14:22:35 +00:00
|
|
|
class IGradientComputer
{
public:
    IGradientComputer() = default;

    virtual ~IGradientComputer() = default;

    /// Adds computed gradient in new point (weights, bias) to batch_gradient.
    /// `target` is the label of the training example taken from row `row_num`
    /// of `columns`; `l2_reg_coef` is the L2 regularization coefficient.
    virtual void compute(
        std::vector<Float64> & batch_gradient,
        const std::vector<Float64> & weights,
        Float64 bias,
        Float64 l2_reg_coef,
        Float64 target,
        const IColumn ** columns,
        size_t row_num) = 0;

    /// Fills `container` with model predictions for rows [offset, offset + limit)
    /// of the argument columns, using the trained (weights, bias).
    /// Implementations are defined in the corresponding .cpp.
    virtual void predict(
        ColumnVector<Float64>::Container & container,
        const ColumnsWithTypeAndName & arguments,
        size_t offset,
        size_t limit,
        const std::vector<Float64> & weights,
        Float64 bias,
        const Context & context) const = 0;
};
|
2019-01-22 21:07:05 +00:00
|
|
|
|
2019-03-03 08:46:36 +00:00
|
|
|
|
2019-01-26 12:38:42 +00:00
|
|
|
/// Gradient computer for linear regression (presumably squared-error loss —
/// the definitions live in the .cpp; confirm there).
class LinearRegression : public IGradientComputer
{
public:
    LinearRegression() = default;

    /// See IGradientComputer::compute.
    void compute(
        std::vector<Float64> & batch_gradient,
        const std::vector<Float64> & weights,
        Float64 bias,
        Float64 l2_reg_coef,
        Float64 target,
        const IColumn ** columns,
        size_t row_num) override;

    /// See IGradientComputer::predict.
    void predict(
        ColumnVector<Float64>::Container & container,
        const ColumnsWithTypeAndName & arguments,
        size_t offset,
        size_t limit,
        const std::vector<Float64> & weights,
        Float64 bias,
        const Context & context) const override;
};
|
2019-02-10 22:07:47 +00:00
|
|
|
|
2019-03-03 08:46:36 +00:00
|
|
|
|
2019-01-28 11:54:55 +00:00
|
|
|
/// Gradient computer for logistic regression (binary classification;
/// loss details are in the .cpp implementation).
class LogisticRegression : public IGradientComputer
{
public:
    LogisticRegression() = default;

    /// See IGradientComputer::compute.
    void compute(
        std::vector<Float64> & batch_gradient,
        const std::vector<Float64> & weights,
        Float64 bias,
        Float64 l2_reg_coef,
        Float64 target,
        const IColumn ** columns,
        size_t row_num) override;

    /// See IGradientComputer::predict.
    void predict(
        ColumnVector<Float64>::Container & container,
        const ColumnsWithTypeAndName & arguments,
        size_t offset,
        size_t limit,
        const std::vector<Float64> & weights,
        Float64 bias,
        const Context & context) const override;
};
|
2019-01-23 18:03:26 +00:00
|
|
|
|
2019-03-03 08:46:36 +00:00
|
|
|
|
|
|
|
/**
|
2019-04-08 21:01:10 +00:00
|
|
|
* IWeightsUpdater class defines the way to update current weights
|
2019-04-20 23:22:42 +00:00
|
|
|
* and uses GradientComputer class on each iteration
|
2019-03-03 08:46:36 +00:00
|
|
|
*/
|
2019-01-26 12:38:42 +00:00
|
|
|
class IWeightsUpdater
{
public:
    virtual ~IWeightsUpdater() = default;

    /// Calls GradientComputer to update current mini-batch.
    /// Non-pure: the base implementation (in the .cpp) forwards to
    /// gradient_computer.compute; updaters like Nesterov override it.
    virtual void addToBatch(
        std::vector<Float64> & batch_gradient,
        IGradientComputer & gradient_computer,
        const std::vector<Float64> & weights,
        Float64 bias,
        Float64 l2_reg_coef,
        Float64 target,
        const IColumn ** columns,
        size_t row_num);

    /// Updates current weights according to the gradient from the last mini-batch.
    virtual void update(
        UInt64 batch_size,
        std::vector<Float64> & weights,
        Float64 & bias,
        Float64 learning_rate,
        const std::vector<Float64> & gradient) = 0;

    /// Used during the merge of two states. Default is a no-op for
    /// stateless updaters (e.g. plain SGD).
    virtual void merge(const IWeightsUpdater &, Float64, Float64) {}

    /// Used for serialization when necessary. Default is a no-op.
    virtual void write(WriteBuffer &) const {}

    /// Used for deserialization when necessary. Default is a no-op.
    virtual void read(ReadBuffer &) {}
};
|
|
|
|
|
2019-03-03 08:46:36 +00:00
|
|
|
|
2019-01-26 12:38:42 +00:00
|
|
|
/// Plain SGD: stateless updater, so it relies on the base-class no-op
/// merge/write/read and only overrides the weight update step.
class StochasticGradientDescent : public IWeightsUpdater
{
public:
    void update(UInt64 batch_size, std::vector<Float64> & weights, Float64 & bias, Float64 learning_rate, const std::vector<Float64> & batch_gradient) override;
};
|
2019-02-10 22:07:47 +00:00
|
|
|
|
2019-03-03 08:46:36 +00:00
|
|
|
|
2019-01-28 10:39:57 +00:00
|
|
|
class Momentum : public IWeightsUpdater
|
|
|
|
{
|
|
|
|
public:
|
2020-11-17 13:24:45 +00:00
|
|
|
Momentum() = default;
|
2019-03-03 08:46:36 +00:00
|
|
|
|
2020-11-17 13:24:45 +00:00
|
|
|
explicit Momentum(Float64 alpha_) : alpha(alpha_) {}
|
2019-02-10 22:07:47 +00:00
|
|
|
|
2019-07-16 21:11:10 +00:00
|
|
|
void update(UInt64 batch_size, std::vector<Float64> & weights, Float64 & bias, Float64 learning_rate, const std::vector<Float64> & batch_gradient) override;
|
2019-03-03 08:46:36 +00:00
|
|
|
|
2019-05-14 19:52:29 +00:00
|
|
|
virtual void merge(const IWeightsUpdater & rhs, Float64 frac, Float64 rhs_frac) override;
|
2019-01-23 18:03:26 +00:00
|
|
|
|
2019-05-14 19:52:29 +00:00
|
|
|
void write(WriteBuffer & buf) const override;
|
2019-03-03 08:46:36 +00:00
|
|
|
|
2019-05-14 19:52:29 +00:00
|
|
|
void read(ReadBuffer & buf) override;
|
2019-03-03 08:46:36 +00:00
|
|
|
|
2019-02-10 22:07:47 +00:00
|
|
|
private:
|
2020-11-17 13:24:45 +00:00
|
|
|
Float64 alpha{0.1};
|
2019-02-15 22:47:56 +00:00
|
|
|
std::vector<Float64> accumulated_gradient;
|
2019-01-28 10:39:57 +00:00
|
|
|
};
|
2019-03-03 08:46:36 +00:00
|
|
|
|
|
|
|
|
2019-02-26 08:12:16 +00:00
|
|
|
class Nesterov : public IWeightsUpdater
|
|
|
|
{
|
|
|
|
public:
|
2020-11-17 13:24:45 +00:00
|
|
|
Nesterov() = default;
|
2019-03-03 08:46:36 +00:00
|
|
|
|
2020-11-17 13:24:45 +00:00
|
|
|
explicit Nesterov(Float64 alpha_) : alpha(alpha_) {}
|
2019-03-03 08:46:36 +00:00
|
|
|
|
2020-03-23 02:12:31 +00:00
|
|
|
void addToBatch(
|
2019-05-14 19:52:29 +00:00
|
|
|
std::vector<Float64> & batch_gradient,
|
|
|
|
IGradientComputer & gradient_computer,
|
|
|
|
const std::vector<Float64> & weights,
|
|
|
|
Float64 bias,
|
|
|
|
Float64 l2_reg_coef,
|
|
|
|
Float64 target,
|
|
|
|
const IColumn ** columns,
|
|
|
|
size_t row_num) override;
|
2019-03-03 08:46:36 +00:00
|
|
|
|
2019-07-16 21:11:10 +00:00
|
|
|
void update(UInt64 batch_size, std::vector<Float64> & weights, Float64 & bias, Float64 learning_rate, const std::vector<Float64> & batch_gradient) override;
|
2019-03-03 08:46:36 +00:00
|
|
|
|
2019-05-14 19:52:29 +00:00
|
|
|
virtual void merge(const IWeightsUpdater & rhs, Float64 frac, Float64 rhs_frac) override;
|
2019-03-03 08:46:36 +00:00
|
|
|
|
2019-05-14 19:52:29 +00:00
|
|
|
void write(WriteBuffer & buf) const override;
|
2019-02-26 08:12:16 +00:00
|
|
|
|
2019-05-14 19:52:29 +00:00
|
|
|
void read(ReadBuffer & buf) override;
|
2019-03-03 08:46:36 +00:00
|
|
|
|
|
|
|
private:
|
2020-11-17 13:24:45 +00:00
|
|
|
const Float64 alpha = 0.9;
|
2019-03-03 08:46:36 +00:00
|
|
|
std::vector<Float64> accumulated_gradient;
|
2019-02-26 08:12:16 +00:00
|
|
|
};
|
2019-03-03 08:46:36 +00:00
|
|
|
|
|
|
|
|
2019-07-14 20:35:34 +00:00
|
|
|
class Adam : public IWeightsUpdater
|
|
|
|
{
|
|
|
|
public:
|
|
|
|
Adam()
|
|
|
|
{
|
2020-11-17 13:24:45 +00:00
|
|
|
beta1_powered = beta1;
|
|
|
|
beta2_powered = beta2;
|
2019-07-14 20:35:34 +00:00
|
|
|
}
|
|
|
|
|
2020-03-23 02:12:31 +00:00
|
|
|
void addToBatch(
|
2019-07-14 20:35:34 +00:00
|
|
|
std::vector<Float64> & batch_gradient,
|
|
|
|
IGradientComputer & gradient_computer,
|
|
|
|
const std::vector<Float64> & weights,
|
|
|
|
Float64 bias,
|
|
|
|
Float64 l2_reg_coef,
|
|
|
|
Float64 target,
|
|
|
|
const IColumn ** columns,
|
|
|
|
size_t row_num) override;
|
|
|
|
|
2019-07-16 21:11:10 +00:00
|
|
|
void update(UInt64 batch_size, std::vector<Float64> & weights, Float64 & bias, Float64 learning_rate, const std::vector<Float64> & batch_gradient) override;
|
2019-07-14 20:35:34 +00:00
|
|
|
|
|
|
|
virtual void merge(const IWeightsUpdater & rhs, Float64 frac, Float64 rhs_frac) override;
|
|
|
|
|
|
|
|
void write(WriteBuffer & buf) const override;
|
|
|
|
|
|
|
|
void read(ReadBuffer & buf) override;
|
|
|
|
|
|
|
|
private:
|
|
|
|
/// beta1 and beta2 hyperparameters have such recommended values
|
2020-11-17 13:24:45 +00:00
|
|
|
const Float64 beta1 = 0.9;
|
|
|
|
const Float64 beta2 = 0.999;
|
|
|
|
const Float64 eps = 0.000001;
|
|
|
|
Float64 beta1_powered;
|
|
|
|
Float64 beta2_powered;
|
2019-07-16 21:11:10 +00:00
|
|
|
|
2019-07-14 20:35:34 +00:00
|
|
|
std::vector<Float64> average_gradient;
|
|
|
|
std::vector<Float64> average_squared_gradient;
|
|
|
|
};
|
|
|
|
|
|
|
|
|
2019-06-14 12:33:29 +00:00
|
|
|
/** LinearModelData is a class which manages current state of learning
|
|
|
|
*/
|
2019-01-26 12:38:42 +00:00
|
|
|
class LinearModelData
|
|
|
|
{
|
|
|
|
public:
|
2020-11-17 13:24:45 +00:00
|
|
|
LinearModelData() = default;
|
2019-01-22 21:07:05 +00:00
|
|
|
|
2019-05-14 19:52:29 +00:00
|
|
|
LinearModelData(
|
2019-08-03 11:02:40 +00:00
|
|
|
Float64 learning_rate_,
|
|
|
|
Float64 l2_reg_coef_,
|
|
|
|
UInt64 param_num_,
|
|
|
|
UInt64 batch_capacity_,
|
|
|
|
std::shared_ptr<IGradientComputer> gradient_computer_,
|
|
|
|
std::shared_ptr<IWeightsUpdater> weights_updater_);
|
2019-01-22 21:07:05 +00:00
|
|
|
|
2019-05-14 19:52:29 +00:00
|
|
|
void add(const IColumn ** columns, size_t row_num);
|
2019-01-23 18:03:26 +00:00
|
|
|
|
2019-05-14 19:52:29 +00:00
|
|
|
void merge(const LinearModelData & rhs);
|
2019-01-22 21:07:05 +00:00
|
|
|
|
2019-05-14 19:52:29 +00:00
|
|
|
void write(WriteBuffer & buf) const;
|
2019-01-22 21:07:05 +00:00
|
|
|
|
2019-05-14 19:52:29 +00:00
|
|
|
void read(ReadBuffer & buf);
|
2019-01-22 21:07:05 +00:00
|
|
|
|
2019-06-03 05:11:15 +00:00
|
|
|
void predict(
|
|
|
|
ColumnVector<Float64>::Container & container,
|
2020-11-17 13:24:45 +00:00
|
|
|
const ColumnsWithTypeAndName & arguments,
|
2019-06-03 05:11:15 +00:00
|
|
|
size_t offset,
|
|
|
|
size_t limit,
|
|
|
|
const Context & context) const;
|
2019-01-26 12:38:42 +00:00
|
|
|
|
2019-05-25 18:41:58 +00:00
|
|
|
void returnWeights(IColumn & to) const;
|
2019-01-26 12:38:42 +00:00
|
|
|
private:
|
|
|
|
std::vector<Float64> weights;
|
2019-03-03 08:46:36 +00:00
|
|
|
Float64 bias{0.0};
|
|
|
|
|
2019-01-26 12:38:42 +00:00
|
|
|
Float64 learning_rate;
|
2019-04-08 21:01:10 +00:00
|
|
|
Float64 l2_reg_coef;
|
2019-07-11 20:56:58 +00:00
|
|
|
UInt64 batch_capacity;
|
2019-03-03 08:46:36 +00:00
|
|
|
|
2019-07-11 20:56:58 +00:00
|
|
|
UInt64 iter_num = 0;
|
2019-03-03 08:46:36 +00:00
|
|
|
std::vector<Float64> gradient_batch;
|
2019-07-11 20:56:58 +00:00
|
|
|
UInt64 batch_size;
|
2019-03-03 08:46:36 +00:00
|
|
|
|
2019-01-26 12:38:42 +00:00
|
|
|
std::shared_ptr<IGradientComputer> gradient_computer;
|
|
|
|
std::shared_ptr<IWeightsUpdater> weights_updater;
|
|
|
|
|
2019-06-14 12:33:29 +00:00
|
|
|
/** The function is called when we want to flush current batch and update our weights
|
|
|
|
*/
|
2020-03-23 02:12:31 +00:00
|
|
|
void updateState();
|
2019-01-22 21:07:05 +00:00
|
|
|
};
|
|
|
|
|
2019-03-03 08:46:36 +00:00
|
|
|
|
2019-01-22 21:07:05 +00:00
|
|
|
/** Aggregate function wrapper over LinearModelData-style states.
  * `Data` is the learning-state type; `Name` provides the SQL-visible name.
  */
template <
    /// Implemented Machine Learning method
    typename Data,
    /// Name of the method
    typename Name>
class AggregateFunctionMLMethod final : public IAggregateFunctionDataHelper<Data, AggregateFunctionMLMethod<Data, Name>>
{
public:
    String getName() const override { return Name::name; }

    explicit AggregateFunctionMLMethod(
        UInt32 param_num_,
        std::unique_ptr<IGradientComputer> gradient_computer_,
        std::string weights_updater_name_,
        Float64 learning_rate_,
        Float64 l2_reg_coef_,
        UInt64 batch_size_,
        const DataTypes & arguments_types,
        const Array & params)
        : IAggregateFunctionDataHelper<Data, AggregateFunctionMLMethod<Data, Name>>(arguments_types, params)
        , param_num(param_num_)
        , learning_rate(learning_rate_)
        , l2_reg_coef(l2_reg_coef_)
        , batch_size(batch_size_)
        , gradient_computer(std::move(gradient_computer_))
        , weights_updater_name(std::move(weights_updater_name_))
    {
    }

    /// This function is called when SELECT linearRegression(...) is called
    DataTypePtr getReturnType() const override
    {
        return std::make_shared<DataTypeArray>(std::make_shared<DataTypeFloat64>());
    }

    /// This function is called from evalMLMethod function for correct predictValues call
    DataTypePtr getReturnTypeToPredict() const override
    {
        return std::make_shared<DataTypeNumber<Float64>>();
    }

    /// Constructs the learning state in `place` via placement-new, picking the
    /// weights updater by name (validity was checked at function creation time).
    void create(AggregateDataPtr __restrict place) const override
    {
        std::shared_ptr<IWeightsUpdater> new_weights_updater;
        if (weights_updater_name == "SGD")
            new_weights_updater = std::make_shared<StochasticGradientDescent>();
        else if (weights_updater_name == "Momentum")
            new_weights_updater = std::make_shared<Momentum>();
        else if (weights_updater_name == "Nesterov")
            new_weights_updater = std::make_shared<Nesterov>();
        else if (weights_updater_name == "Adam")
            new_weights_updater = std::make_shared<Adam>();
        else
            throw Exception("Illegal name of weights updater (should have been checked earlier)", ErrorCodes::LOGICAL_ERROR);

        new (place) Data(learning_rate, l2_reg_coef, param_num, batch_size, gradient_computer, new_weights_updater);
    }

    void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
    {
        this->data(place).add(columns, row_num);
    }

    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { this->data(place).merge(this->data(rhs)); }

    void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf) const override { this->data(place).write(buf); }

    void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, Arena *) const override { this->data(place).read(buf); }

    /// Fills `to` with predictions; expects param_num feature columns plus the
    /// state column among `arguments`.
    void predictValues(
        ConstAggregateDataPtr place,
        IColumn & to,
        const ColumnsWithTypeAndName & arguments,
        size_t offset,
        size_t limit,
        const Context & context) const override
    {
        if (arguments.size() != param_num + 1)
            throw Exception(
                "Predict got incorrect number of arguments. Got: " + std::to_string(arguments.size())
                    + ". Required: " + std::to_string(param_num + 1),
                ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

        /// This cast might be correct because column type is based on getReturnTypeToPredict.
        auto * column = typeid_cast<ColumnFloat64 *>(&to);
        if (!column)
            throw Exception("Cast of column of predictions is incorrect. getReturnTypeToPredict must return same value as it is casted to",
                ErrorCodes::LOGICAL_ERROR);

        this->data(place).predict(column->getData(), arguments, offset, limit, context);
    }

    /** This function is called if aggregate function without State modifier is selected in a query.
      * Inserts all weights of the model into the column 'to', so user may use such information if needed
      */
    void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
    {
        this->data(place).returnWeights(to);
    }

private:
    UInt64 param_num;  /// number of feature columns (ctor takes UInt32; widened here)
    Float64 learning_rate;
    Float64 l2_reg_coef;
    UInt64 batch_size;
    std::shared_ptr<IGradientComputer> gradient_computer;
    std::string weights_updater_name;
};
|
|
|
|
|
2019-05-14 19:52:29 +00:00
|
|
|
/// Tag type carrying the SQL-visible name of the linear-regression function.
struct NameLinearRegression
{
    static constexpr auto name = "stochasticLinearRegression";
};
|
|
|
|
/// Tag type carrying the SQL-visible name of the logistic-regression function.
struct NameLogisticRegression
{
    static constexpr auto name = "stochasticLogisticRegression";
};
|
2019-07-11 20:56:58 +00:00
|
|
|
|
2019-01-23 11:58:05 +00:00
|
|
|
}
|