ClickHouse/src/AggregateFunctions/AggregateFunctionMLMethod.h

414 lines
12 KiB
C++
Raw Normal View History

2019-01-22 21:07:05 +00:00
#pragma once
2019-04-04 00:17:27 +00:00
#include <Columns/ColumnVector.h>
2019-01-22 21:07:05 +00:00
#include <Columns/ColumnsCommon.h>
#include <Columns/ColumnsNumber.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypesNumber.h>
#include <Common/typeid_cast.h>
#include "IAggregateFunction.h"
2019-01-22 21:07:05 +00:00
2019-01-23 11:58:05 +00:00
namespace DB
{
struct Settings;
2019-01-23 01:29:53 +00:00
namespace ErrorCodes
{
2020-02-25 18:02:41 +00:00
extern const int LOGICAL_ERROR;
2019-01-23 01:29:53 +00:00
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
2019-01-22 21:07:05 +00:00
2019-03-03 08:46:36 +00:00
/**
2019-04-08 21:01:10 +00:00
GradientComputer class computes gradient according to its loss function
2019-03-03 08:46:36 +00:00
*/
2019-01-24 14:22:35 +00:00
class IGradientComputer
{
public:
IGradientComputer() = default;
2019-01-24 14:22:35 +00:00
2019-03-03 08:46:36 +00:00
virtual ~IGradientComputer() = default;
2019-01-26 12:38:42 +00:00
2019-04-08 21:01:10 +00:00
/// Adds computed gradient in new point (weights, bias) to batch_gradient
virtual void compute(
std::vector<Float64> & batch_gradient,
const std::vector<Float64> & weights,
Float64 bias,
Float64 l2_reg_coef,
Float64 target,
const IColumn ** columns,
size_t row_num) = 0;
virtual void predict(
ColumnVector<Float64>::Container & container,
const ColumnsWithTypeAndName & arguments,
size_t offset,
size_t limit,
const std::vector<Float64> & weights,
Float64 bias,
ContextConstPtr context) const = 0;
2019-01-24 14:22:35 +00:00
};
2019-01-22 21:07:05 +00:00
2019-03-03 08:46:36 +00:00
2019-01-26 12:38:42 +00:00
class LinearRegression : public IGradientComputer
2019-01-23 01:29:53 +00:00
{
2019-01-26 12:38:42 +00:00
public:
LinearRegression() = default;
void compute(
std::vector<Float64> & batch_gradient,
const std::vector<Float64> & weights,
Float64 bias,
Float64 l2_reg_coef,
Float64 target,
const IColumn ** columns,
size_t row_num) override;
void predict(
ColumnVector<Float64>::Container & container,
const ColumnsWithTypeAndName & arguments,
size_t offset,
size_t limit,
const std::vector<Float64> & weights,
Float64 bias,
ContextConstPtr context) const override;
2019-01-28 10:39:57 +00:00
};
2019-02-10 22:07:47 +00:00
2019-03-03 08:46:36 +00:00
2019-01-28 11:54:55 +00:00
class LogisticRegression : public IGradientComputer
2019-01-28 10:39:57 +00:00
{
public:
LogisticRegression() = default;
void compute(
std::vector<Float64> & batch_gradient,
const std::vector<Float64> & weights,
Float64 bias,
Float64 l2_reg_coef,
Float64 target,
const IColumn ** columns,
2019-05-30 22:48:44 +00:00
size_t row_num) override;
void predict(
ColumnVector<Float64>::Container & container,
const ColumnsWithTypeAndName & arguments,
size_t offset,
size_t limit,
const std::vector<Float64> & weights,
Float64 bias,
ContextConstPtr context) const override;
2019-01-26 12:38:42 +00:00
};
2019-01-23 18:03:26 +00:00
2019-03-03 08:46:36 +00:00
/**
2019-04-08 21:01:10 +00:00
* IWeightsUpdater class defines the way to update current weights
2019-04-20 23:22:42 +00:00
* and uses GradientComputer class on each iteration
2019-03-03 08:46:36 +00:00
*/
2019-01-26 12:38:42 +00:00
class IWeightsUpdater
{
public:
virtual ~IWeightsUpdater() = default;
2019-01-23 18:03:26 +00:00
2019-04-20 23:22:42 +00:00
/// Calls GradientComputer to update current mini-batch
virtual void addToBatch(
std::vector<Float64> & batch_gradient,
IGradientComputer & gradient_computer,
const std::vector<Float64> & weights,
Float64 bias,
Float64 l2_reg_coef,
Float64 target,
const IColumn ** columns,
size_t row_num);
2019-03-03 08:46:36 +00:00
2019-04-20 23:22:42 +00:00
/// Updates current weights according to the gradient from the last mini-batch
2019-07-16 21:11:10 +00:00
virtual void update(
UInt64 batch_size,
std::vector<Float64> & weights,
Float64 & bias,
Float64 learning_rate,
const std::vector<Float64> & gradient) = 0;
2019-03-03 08:46:36 +00:00
2019-04-20 23:22:42 +00:00
/// Used during the merge of two states
virtual void merge(const IWeightsUpdater &, Float64, Float64) {}
2019-03-03 08:46:36 +00:00
2019-04-20 23:22:42 +00:00
/// Used for serialization when necessary
virtual void write(WriteBuffer &) const {}
2019-03-03 08:46:36 +00:00
2019-04-20 23:22:42 +00:00
/// Used for serialization when necessary
virtual void read(ReadBuffer &) {}
2019-01-26 12:38:42 +00:00
};
2019-03-03 08:46:36 +00:00
2019-01-26 12:38:42 +00:00
class StochasticGradientDescent : public IWeightsUpdater
{
public:
2019-07-16 21:11:10 +00:00
void update(UInt64 batch_size, std::vector<Float64> & weights, Float64 & bias, Float64 learning_rate, const std::vector<Float64> & batch_gradient) override;
2019-01-26 12:38:42 +00:00
};
2019-02-10 22:07:47 +00:00
2019-03-03 08:46:36 +00:00
2019-01-28 10:39:57 +00:00
class Momentum : public IWeightsUpdater
{
public:
Momentum() = default;
2019-03-03 08:46:36 +00:00
explicit Momentum(Float64 alpha_) : alpha(alpha_) {}
2019-02-10 22:07:47 +00:00
2019-07-16 21:11:10 +00:00
void update(UInt64 batch_size, std::vector<Float64> & weights, Float64 & bias, Float64 learning_rate, const std::vector<Float64> & batch_gradient) override;
2019-03-03 08:46:36 +00:00
virtual void merge(const IWeightsUpdater & rhs, Float64 frac, Float64 rhs_frac) override;
2019-01-23 18:03:26 +00:00
void write(WriteBuffer & buf) const override;
2019-03-03 08:46:36 +00:00
void read(ReadBuffer & buf) override;
2019-03-03 08:46:36 +00:00
2019-02-10 22:07:47 +00:00
private:
Float64 alpha{0.1};
2019-02-15 22:47:56 +00:00
std::vector<Float64> accumulated_gradient;
2019-01-28 10:39:57 +00:00
};
2019-03-03 08:46:36 +00:00
2019-02-26 08:12:16 +00:00
class Nesterov : public IWeightsUpdater
{
public:
Nesterov() = default;
2019-03-03 08:46:36 +00:00
explicit Nesterov(Float64 alpha_) : alpha(alpha_) {}
2019-03-03 08:46:36 +00:00
void addToBatch(
std::vector<Float64> & batch_gradient,
IGradientComputer & gradient_computer,
const std::vector<Float64> & weights,
Float64 bias,
Float64 l2_reg_coef,
Float64 target,
const IColumn ** columns,
size_t row_num) override;
2019-03-03 08:46:36 +00:00
2019-07-16 21:11:10 +00:00
void update(UInt64 batch_size, std::vector<Float64> & weights, Float64 & bias, Float64 learning_rate, const std::vector<Float64> & batch_gradient) override;
2019-03-03 08:46:36 +00:00
virtual void merge(const IWeightsUpdater & rhs, Float64 frac, Float64 rhs_frac) override;
2019-03-03 08:46:36 +00:00
void write(WriteBuffer & buf) const override;
2019-02-26 08:12:16 +00:00
void read(ReadBuffer & buf) override;
2019-03-03 08:46:36 +00:00
private:
const Float64 alpha = 0.9;
2019-03-03 08:46:36 +00:00
std::vector<Float64> accumulated_gradient;
2019-02-26 08:12:16 +00:00
};
2019-03-03 08:46:36 +00:00
2019-07-14 20:35:34 +00:00
class Adam : public IWeightsUpdater
{
public:
Adam()
{
beta1_powered = beta1;
beta2_powered = beta2;
2019-07-14 20:35:34 +00:00
}
void addToBatch(
2019-07-14 20:35:34 +00:00
std::vector<Float64> & batch_gradient,
IGradientComputer & gradient_computer,
const std::vector<Float64> & weights,
Float64 bias,
Float64 l2_reg_coef,
Float64 target,
const IColumn ** columns,
size_t row_num) override;
2019-07-16 21:11:10 +00:00
void update(UInt64 batch_size, std::vector<Float64> & weights, Float64 & bias, Float64 learning_rate, const std::vector<Float64> & batch_gradient) override;
2019-07-14 20:35:34 +00:00
virtual void merge(const IWeightsUpdater & rhs, Float64 frac, Float64 rhs_frac) override;
void write(WriteBuffer & buf) const override;
void read(ReadBuffer & buf) override;
private:
/// beta1 and beta2 hyperparameters have such recommended values
const Float64 beta1 = 0.9;
const Float64 beta2 = 0.999;
const Float64 eps = 0.000001;
Float64 beta1_powered;
Float64 beta2_powered;
2019-07-16 21:11:10 +00:00
2019-07-14 20:35:34 +00:00
std::vector<Float64> average_gradient;
std::vector<Float64> average_squared_gradient;
};
/** LinearModelData is a class which manages current state of learning
*/
2019-01-26 12:38:42 +00:00
class LinearModelData
{
public:
LinearModelData() = default;
2019-01-22 21:07:05 +00:00
LinearModelData(
2019-08-03 11:02:40 +00:00
Float64 learning_rate_,
Float64 l2_reg_coef_,
UInt64 param_num_,
UInt64 batch_capacity_,
std::shared_ptr<IGradientComputer> gradient_computer_,
std::shared_ptr<IWeightsUpdater> weights_updater_);
2019-01-22 21:07:05 +00:00
void add(const IColumn ** columns, size_t row_num);
2019-01-23 18:03:26 +00:00
void merge(const LinearModelData & rhs);
2019-01-22 21:07:05 +00:00
void write(WriteBuffer & buf) const;
2019-01-22 21:07:05 +00:00
void read(ReadBuffer & buf);
2019-01-22 21:07:05 +00:00
void predict(
ColumnVector<Float64>::Container & container,
const ColumnsWithTypeAndName & arguments,
size_t offset,
size_t limit,
ContextConstPtr context) const;
2019-01-26 12:38:42 +00:00
2019-05-25 18:41:58 +00:00
void returnWeights(IColumn & to) const;
2019-01-26 12:38:42 +00:00
private:
std::vector<Float64> weights;
2019-03-03 08:46:36 +00:00
Float64 bias{0.0};
2019-01-26 12:38:42 +00:00
Float64 learning_rate;
2019-04-08 21:01:10 +00:00
Float64 l2_reg_coef;
2019-07-11 20:56:58 +00:00
UInt64 batch_capacity;
2019-03-03 08:46:36 +00:00
2019-07-11 20:56:58 +00:00
UInt64 iter_num = 0;
2019-03-03 08:46:36 +00:00
std::vector<Float64> gradient_batch;
2019-07-11 20:56:58 +00:00
UInt64 batch_size;
2019-03-03 08:46:36 +00:00
2019-01-26 12:38:42 +00:00
std::shared_ptr<IGradientComputer> gradient_computer;
std::shared_ptr<IWeightsUpdater> weights_updater;
/** The function is called when we want to flush current batch and update our weights
*/
void updateState();
2019-01-22 21:07:05 +00:00
};
2019-03-03 08:46:36 +00:00
2019-01-22 21:07:05 +00:00
template <
2019-04-08 22:40:37 +00:00
/// Implemented Machine Learning method
typename Data,
/// Name of the method
typename Name>
2019-01-22 21:07:05 +00:00
class AggregateFunctionMLMethod final : public IAggregateFunctionDataHelper<Data, AggregateFunctionMLMethod<Data, Name>>
{
public:
String getName() const override { return Name::name; }
explicit AggregateFunctionMLMethod(
2019-08-03 11:02:40 +00:00
UInt32 param_num_,
std::unique_ptr<IGradientComputer> gradient_computer_,
std::string weights_updater_name_,
Float64 learning_rate_,
Float64 l2_reg_coef_,
UInt64 batch_size_,
const DataTypes & arguments_types,
const Array & params)
: IAggregateFunctionDataHelper<Data, AggregateFunctionMLMethod<Data, Name>>(arguments_types, params)
2019-08-03 11:02:40 +00:00
, param_num(param_num_)
, learning_rate(learning_rate_)
, l2_reg_coef(l2_reg_coef_)
, batch_size(batch_size_)
, gradient_computer(std::move(gradient_computer_))
, weights_updater_name(std::move(weights_updater_name_))
2019-01-22 21:07:05 +00:00
{
}
2019-05-30 22:48:44 +00:00
/// This function is called when SELECT linearRegression(...) is called
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeArray>(std::make_shared<DataTypeFloat64>());
}
bool allocatesMemoryInArena() const override { return false; }
/// This function is called from evalMLMethod function for correct predictValues call
DataTypePtr getReturnTypeToPredict() const override
{
return std::make_shared<DataTypeNumber<Float64>>();
}
void create(AggregateDataPtr __restrict place) const override
2019-01-23 01:29:53 +00:00
{
std::shared_ptr<IWeightsUpdater> new_weights_updater;
2019-07-11 20:56:58 +00:00
if (weights_updater_name == "SGD")
new_weights_updater = std::make_shared<StochasticGradientDescent>();
2019-07-11 20:56:58 +00:00
else if (weights_updater_name == "Momentum")
new_weights_updater = std::make_shared<Momentum>();
2019-07-11 20:56:58 +00:00
else if (weights_updater_name == "Nesterov")
new_weights_updater = std::make_shared<Nesterov>();
2019-07-14 20:35:34 +00:00
else if (weights_updater_name == "Adam")
new_weights_updater = std::make_shared<Adam>();
2019-07-11 20:56:58 +00:00
else
throw Exception("Illegal name of weights updater (should have been checked earlier)", ErrorCodes::LOGICAL_ERROR);
new (place) Data(learning_rate, l2_reg_coef, param_num, batch_size, gradient_computer, new_weights_updater);
2019-01-23 01:29:53 +00:00
}
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
2019-01-22 21:07:05 +00:00
{
2019-02-10 22:07:47 +00:00
this->data(place).add(columns, row_num);
2019-01-22 21:07:05 +00:00
}
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { this->data(place).merge(this->data(rhs)); }
2019-01-22 21:07:05 +00:00
2021-05-30 13:57:30 +00:00
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override { this->data(place).write(buf); }
2019-01-22 21:07:05 +00:00
2021-05-30 13:57:30 +00:00
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override { this->data(place).read(buf); }
2019-01-22 21:07:05 +00:00
void predictValues(
ConstAggregateDataPtr place,
IColumn & to,
const ColumnsWithTypeAndName & arguments,
size_t offset,
size_t limit,
ContextConstPtr context) const override
2019-01-22 21:07:05 +00:00
{
2019-01-23 01:29:53 +00:00
if (arguments.size() != param_num + 1)
throw Exception(
"Predict got incorrect number of arguments. Got: " + std::to_string(arguments.size())
+ ". Required: " + std::to_string(param_num + 1),
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
2019-01-23 01:29:53 +00:00
2019-05-30 22:48:44 +00:00
/// This cast might be correct because column type is based on getReturnTypeToPredict.
auto * column = typeid_cast<ColumnFloat64 *>(&to);
if (!column)
2019-05-30 22:48:44 +00:00
throw Exception("Cast of column of predictions is incorrect. getReturnTypeToPredict must return same value as it is casted to",
ErrorCodes::LOGICAL_ERROR);
2019-01-22 21:07:05 +00:00
2020-10-17 21:41:50 +00:00
this->data(place).predict(column->getData(), arguments, offset, limit, context);
2019-01-22 21:07:05 +00:00
}
/** This function is called if aggregate function without State modifier is selected in a query.
* Inserts all weights of the model into the column 'to', so user may use such information if needed
*/
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
2019-01-23 01:29:53 +00:00
{
2019-05-25 18:41:58 +00:00
this->data(place).returnWeights(to);
2019-01-22 21:07:05 +00:00
}
private:
2019-07-11 20:56:58 +00:00
UInt64 param_num;
2019-01-22 21:07:05 +00:00
Float64 learning_rate;
2019-04-08 21:01:10 +00:00
Float64 l2_reg_coef;
2019-07-11 20:56:58 +00:00
UInt64 batch_size;
2019-04-20 23:22:42 +00:00
std::shared_ptr<IGradientComputer> gradient_computer;
std::string weights_updater_name;
2019-01-22 21:07:05 +00:00
};
/// Name reported by AggregateFunctionMLMethod::getName() for the linear-regression variant.
struct NameLinearRegression
{
    static constexpr const char * name = "stochasticLinearRegression";
};
/// Name reported by AggregateFunctionMLMethod::getName() for the logistic-regression variant.
struct NameLogisticRegression
{
    static constexpr const char * name = "stochasticLogisticRegression";
};
2019-07-11 20:56:58 +00:00
2019-01-23 11:58:05 +00:00
}