mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-10-08 17:40:49 +00:00
commit
5a0068ff9b
@ -11,7 +11,7 @@ namespace
|
||||
{
|
||||
|
||||
using FuncLinearRegression = AggregateFunctionMLMethod<LinearModelData, NameLinearRegression>;
|
||||
|
||||
using FuncLogisticRegression = AggregateFunctionMLMethod<LinearModelData, NameLogisticRegression>;
|
||||
template <class Method>
|
||||
AggregateFunctionPtr createAggregateFunctionMLMethod(
|
||||
const std::string & name, const DataTypes & argument_types, const Array & parameters)
|
||||
@ -46,6 +46,9 @@ AggregateFunctionPtr createAggregateFunctionMLMethod(
|
||||
if (std::is_same<Method, FuncLinearRegression>::value)
|
||||
{
|
||||
gc = std::make_shared<LinearRegression>(argument_types.size());
|
||||
} else if (std::is_same<Method, FuncLogisticRegression>::value)
|
||||
{
|
||||
gc = std::make_shared<LogisticRegression>(argument_types.size());
|
||||
} else
|
||||
{
|
||||
throw Exception("Such gradient computer is not implemented yet", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
@ -55,6 +58,9 @@ AggregateFunctionPtr createAggregateFunctionMLMethod(
|
||||
if (applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[2]) == Float64{1.0})
|
||||
{
|
||||
wu = std::make_shared<StochasticGradientDescent>();
|
||||
} else if (applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[2]) == Float64{2.0})
|
||||
{
|
||||
wu = std::make_shared<Momentum>();
|
||||
} else
|
||||
{
|
||||
throw Exception("Such weights updater is not implemented yet", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
@ -75,6 +81,7 @@ AggregateFunctionPtr createAggregateFunctionMLMethod(
|
||||
void registerAggregateFunctionMLMethod(AggregateFunctionFactory & factory)
|
||||
{
|
||||
factory.registerFunction("LinearRegression", createAggregateFunctionMLMethod<FuncLinearRegression>);
|
||||
factory.registerFunction("LogisticRegression", createAggregateFunctionMLMethod<FuncLogisticRegression>);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -59,6 +59,7 @@ public:
|
||||
{
|
||||
return batch_gradient;
|
||||
}
|
||||
virtual Float64 predict(const std::vector<Float64> & predict_feature, const std::vector<Float64> & weights, Float64 bias) const = 0;
|
||||
|
||||
protected:
|
||||
std::vector<Float64> batch_gradient; // gradient for bias lies in batch_gradient[batch_gradient.size() - 1]
|
||||
@ -87,6 +88,67 @@ public:
|
||||
batch_gradient[i] += derivative * static_cast<const ColumnVector<Float64> &>(*columns[i + 1]).getData()[row_num];;
|
||||
}
|
||||
}
|
||||
Float64 predict(const std::vector<Float64> & predict_feature, const std::vector<Float64> & weights, Float64 bias) const override
|
||||
{
|
||||
/// не обновляем веса при предикте, т.к. это может замедлить предсказание
|
||||
/// однако можно например обновлять их при каждом мердже не зависимо от того, сколько элементнов в батче
|
||||
// if (cur_batch)
|
||||
// {
|
||||
// update_weights();
|
||||
// }
|
||||
|
||||
Float64 res{0.0};
|
||||
for (size_t i = 0; i < predict_feature.size(); ++i)
|
||||
{
|
||||
res += predict_feature[i] * weights[i];
|
||||
}
|
||||
res += bias;
|
||||
|
||||
return res;
|
||||
}
|
||||
};
|
||||
class LogisticRegression : public IGradientComputer
|
||||
{
|
||||
public:
|
||||
LogisticRegression(UInt32 sz)
|
||||
: IGradientComputer(sz)
|
||||
{}
|
||||
|
||||
void compute(const std::vector<Float64> & weights, Float64 bias, Float64 learning_rate,
|
||||
Float64 target, const IColumn ** columns, size_t row_num) override
|
||||
{
|
||||
Float64 derivative = bias;
|
||||
for (size_t i = 0; i < weights.size(); ++i)
|
||||
{
|
||||
derivative += weights[i] * static_cast<const ColumnVector<Float64> &>(*columns[i + 1]).getData()[row_num];;
|
||||
}
|
||||
derivative *= target;
|
||||
derivative = learning_rate * exp(derivative);
|
||||
|
||||
batch_gradient[weights.size()] += target / (derivative + 1);;
|
||||
for (size_t i = 0; i < weights.size(); ++i)
|
||||
{
|
||||
batch_gradient[i] += target / (derivative + 1) * static_cast<const ColumnVector<Float64> &>(*columns[i + 1]).getData()[row_num];
|
||||
}
|
||||
}
|
||||
Float64 predict(const std::vector<Float64> & predict_feature, const std::vector<Float64> & weights, Float64 bias) const override
|
||||
{
|
||||
/// не обновляем веса при предикте, т.к. это может замедлить предсказание
|
||||
/// однако можно например обновлять их при каждом мердже не зависимо от того, сколько элементнов в батче
|
||||
// if (cur_batch)
|
||||
// {
|
||||
// update_weights();
|
||||
// }
|
||||
|
||||
Float64 res{0.0};
|
||||
for (size_t i = 0; i < predict_feature.size(); ++i)
|
||||
{
|
||||
res += predict_feature[i] * weights[i];
|
||||
}
|
||||
res += bias;
|
||||
res = 1 / (1 + exp(-res));
|
||||
return res;
|
||||
}
|
||||
};
|
||||
|
||||
class IWeightsUpdater
|
||||
@ -95,6 +157,7 @@ public:
|
||||
virtual ~IWeightsUpdater() = default;
|
||||
|
||||
virtual void update(UInt32 cur_batch, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & gradient) = 0;
|
||||
virtual void merge(const std::shared_ptr<IWeightsUpdater>, Float64, Float64) {}
|
||||
};
|
||||
|
||||
class StochasticGradientDescent : public IWeightsUpdater
|
||||
@ -108,7 +171,37 @@ public:
|
||||
bias += batch_gradient[weights.size()] / cur_batch;
|
||||
}
|
||||
};
|
||||
class Momentum : public IWeightsUpdater
|
||||
{
|
||||
public:
|
||||
Momentum() {}
|
||||
Momentum (Float64 alpha) : alpha_(alpha) {}
|
||||
void update(UInt32 cur_batch, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & batch_gradient) override {
|
||||
if (hk_.size() == 0)
|
||||
{
|
||||
hk_.resize(batch_gradient.size(), Float64{0.0});
|
||||
}
|
||||
for (size_t i = 0; i < batch_gradient.size(); ++i)
|
||||
{
|
||||
hk_[i] = hk_[i] * alpha_ + batch_gradient[i];
|
||||
}
|
||||
for (size_t i = 0; i < weights.size(); ++i)
|
||||
{
|
||||
weights[i] += hk_[i] / cur_batch;
|
||||
}
|
||||
bias += hk_[weights.size()] / cur_batch;
|
||||
}
|
||||
virtual void merge(const std::shared_ptr<IWeightsUpdater> rhs, Float64 frac, Float64 rhs_frac) override {
|
||||
auto momentum_rhs = std::dynamic_pointer_cast<Momentum>(rhs);
|
||||
for (size_t i = 0; i < hk_.size(); ++i)
|
||||
{
|
||||
hk_[i] = hk_[i] * frac + momentum_rhs->hk_[i] * rhs_frac;
|
||||
}
|
||||
}
|
||||
|
||||
Float64 alpha_{0.1};
|
||||
std::vector<Float64> hk_;
|
||||
};
|
||||
class LinearModelData
|
||||
{
|
||||
public:
|
||||
@ -160,6 +253,7 @@ public:
|
||||
|
||||
bias = bias * frac + rhs.bias * rhs_frac;
|
||||
iter_num += rhs.iter_num;
|
||||
weights_updater->merge(rhs.weights_updater, frac, rhs_frac);
|
||||
}
|
||||
|
||||
void write(WriteBuffer & buf) const
|
||||
@ -189,14 +283,7 @@ public:
|
||||
// update_weights();
|
||||
// }
|
||||
|
||||
Float64 res{0.0};
|
||||
for (size_t i = 0; i < predict_feature.size(); ++i)
|
||||
{
|
||||
res += predict_feature[i] * weights[i];
|
||||
}
|
||||
res += bias;
|
||||
|
||||
return res;
|
||||
return gradient_computer->predict(predict_feature, weights, bias);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -316,5 +403,5 @@ private:
|
||||
};
|
||||
|
||||
struct NameLinearRegression { static constexpr auto name = "LinearRegression"; };
|
||||
|
||||
struct NameLogisticRegression { static constexpr auto name = "LogisticRegression"; };
|
||||
}
|
||||
|
@ -123,19 +123,27 @@ MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const Col
|
||||
}
|
||||
|
||||
// auto ML_function = typeid_cast<AggregateFunctionMLMethod<LinearRegressionData, NameLinearRegression> *>(func.get());
|
||||
auto ML_function = typeid_cast<AggregateFunctionMLMethod<LinearModelData, NameLinearRegression> *>(func.get());
|
||||
if (ML_function)
|
||||
auto ML_function_Linear = typeid_cast<AggregateFunctionMLMethod<LinearModelData, NameLinearRegression> *>(func.get());
|
||||
auto ML_function_Logistic = typeid_cast<AggregateFunctionMLMethod<LinearModelData, NameLogisticRegression> *>(func.get());
|
||||
if (ML_function_Linear)
|
||||
{
|
||||
size_t row_num = 0;
|
||||
for (auto val : data) {
|
||||
ML_function->predictResultInto(val, *res, block, row_num, arguments);
|
||||
ML_function_Linear->predictResultInto(val, *res, block, row_num, arguments);
|
||||
++row_num;
|
||||
}
|
||||
} else {
|
||||
} else if (ML_function_Logistic)
|
||||
{
|
||||
size_t row_num = 0;
|
||||
for (auto val : data) {
|
||||
ML_function_Logistic->predictResultInto(val, *res, block, row_num, arguments);
|
||||
++row_num;
|
||||
}
|
||||
} else
|
||||
{
|
||||
throw Exception("Illegal aggregate function is passed",
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
|
48
dbms/tests/queries/0_stateless/00900_stest.sql
Normal file
48
dbms/tests/queries/0_stateless/00900_stest.sql
Normal file
File diff suppressed because one or more lines are too long
46
dbms/tests/queries/0_stateless/00901_mytest.sql
Normal file
46
dbms/tests/queries/0_stateless/00901_mytest.sql
Normal file
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user