Merge pull request #6 from Quid37/marikon1

logreg and momentum
Quid37 2019-01-28 14:57:41 +03:00 committed by GitHub
commit 5a0068ff9b
5 changed files with 211 additions and 15 deletions


@@ -11,7 +11,7 @@ namespace
 {
 using FuncLinearRegression = AggregateFunctionMLMethod<LinearModelData, NameLinearRegression>;
+using FuncLogisticRegression = AggregateFunctionMLMethod<LinearModelData, NameLogisticRegression>;
 template <class Method>
 AggregateFunctionPtr createAggregateFunctionMLMethod(
     const std::string & name, const DataTypes & argument_types, const Array & parameters)
@@ -46,6 +46,9 @@ AggregateFunctionPtr createAggregateFunctionMLMethod(
     if (std::is_same<Method, FuncLinearRegression>::value)
     {
         gc = std::make_shared<LinearRegression>(argument_types.size());
+    } else if (std::is_same<Method, FuncLogisticRegression>::value)
+    {
+        gc = std::make_shared<LogisticRegression>(argument_types.size());
     } else
     {
         throw Exception("Such gradient computer is not implemented yet", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
@@ -55,6 +58,9 @@ AggregateFunctionPtr createAggregateFunctionMLMethod(
     if (applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[2]) == Float64{1.0})
     {
         wu = std::make_shared<StochasticGradientDescent>();
+    } else if (applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[2]) == Float64{2.0})
+    {
+        wu = std::make_shared<Momentum>();
     } else
     {
         throw Exception("Such weights updater is not implemented yet", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
@@ -75,6 +81,7 @@ AggregateFunctionPtr createAggregateFunctionMLMethod(
 void registerAggregateFunctionMLMethod(AggregateFunctionFactory & factory)
 {
     factory.registerFunction("LinearRegression", createAggregateFunctionMLMethod<FuncLinearRegression>);
+    factory.registerFunction("LogisticRegression", createAggregateFunctionMLMethod<FuncLogisticRegression>);
 }
 }
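Side note: the updater is selected by comparing a UInt32-converted parameter against Float64 literals. A hypothetical equivalent sketch (not part of this commit; updater_id is an invented local name) that makes the integer dispatch explicit:

    UInt32 updater_id = applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[2]);
    std::shared_ptr<IWeightsUpdater> wu;
    switch (updater_id)
    {
        case 1:  /// StochasticGradientDescent, as in the branch above
            wu = std::make_shared<StochasticGradientDescent>();
            break;
        case 2:  /// Momentum
            wu = std::make_shared<Momentum>();
            break;
        default:
            throw Exception("Such weights updater is not implemented yet", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
    }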


@@ -59,6 +59,7 @@ public:
     {
         return batch_gradient;
     }
+    virtual Float64 predict(const std::vector<Float64> & predict_feature, const std::vector<Float64> & weights, Float64 bias) const = 0;
 protected:
     std::vector<Float64> batch_gradient;  // gradient for bias lies in batch_gradient[batch_gradient.size() - 1]
@@ -87,6 +88,67 @@ public:
             batch_gradient[i] += derivative * static_cast<const ColumnVector<Float64> &>(*columns[i + 1]).getData()[row_num];
         }
     }
+    Float64 predict(const std::vector<Float64> & predict_feature, const std::vector<Float64> & weights, Float64 bias) const override
+    {
+        /// We do not update the weights during predict(), since that could slow prediction down;
+        /// they could instead be updated on every merge, regardless of how many elements the batch holds.
+        // if (cur_batch)
+        // {
+        //     update_weights();
+        // }
+        Float64 res{0.0};
+        for (size_t i = 0; i < predict_feature.size(); ++i)
+        {
+            res += predict_feature[i] * weights[i];
+        }
+        res += bias;
+        return res;
+    }
+};
+class LogisticRegression : public IGradientComputer
+{
+public:
+    LogisticRegression(UInt32 sz)
+        : IGradientComputer(sz)
+    {}
+    void compute(const std::vector<Float64> & weights, Float64 bias, Float64 learning_rate,
+                 Float64 target, const IColumn ** columns, size_t row_num) override
+    {
+        Float64 derivative = bias;
+        for (size_t i = 0; i < weights.size(); ++i)
+        {
+            derivative += weights[i] * static_cast<const ColumnVector<Float64> &>(*columns[i + 1]).getData()[row_num];
+        }
+        derivative *= target;
+        /// Negated logistic-loss gradient, scaled by the learning rate:
+        /// learning_rate * target / (1 + exp(target * (w . x + bias))).
+        Float64 factor = learning_rate * target / (exp(derivative) + 1);
+        batch_gradient[weights.size()] += factor;
+        for (size_t i = 0; i < weights.size(); ++i)
+        {
+            batch_gradient[i] += factor * static_cast<const ColumnVector<Float64> &>(*columns[i + 1]).getData()[row_num];
+        }
+    }
+    Float64 predict(const std::vector<Float64> & predict_feature, const std::vector<Float64> & weights, Float64 bias) const override
+    {
+        /// We do not update the weights during predict(), since that could slow prediction down;
+        /// they could instead be updated on every merge, regardless of how many elements the batch holds.
+        // if (cur_batch)
+        // {
+        //     update_weights();
+        // }
+        Float64 res{0.0};
+        for (size_t i = 0; i < predict_feature.size(); ++i)
+        {
+            res += predict_feature[i] * weights[i];
+        }
+        res += bias;
+        res = 1 / (1 + exp(-res));
+        return res;
+    }
 };
 class IWeightsUpdater
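For reference, the objective LogisticRegression::compute appears to minimize, assuming labels t in {-1, +1} (suggested by the target-scaled margin). With z = w . x + b:

    L(w, b) = \ln\bigl(1 + e^{-t z}\bigr), \qquad
    \frac{\partial L}{\partial w_i} = -\frac{t\,x_i}{1 + e^{t z}}, \qquad
    \frac{\partial L}{\partial b} = -\frac{t}{1 + e^{t z}}

compute() accumulates the negated gradient scaled by the learning rate, so update() can add batch_gradient directly to the weights, giving one averaged descent step per batch.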
@@ -95,6 +157,7 @@ public:
     virtual ~IWeightsUpdater() = default;
     virtual void update(UInt32 cur_batch, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & gradient) = 0;
+    virtual void merge(const std::shared_ptr<IWeightsUpdater>, Float64, Float64) {}
 };
 class StochasticGradientDescent : public IWeightsUpdater
@@ -108,7 +171,37 @@ public:
         bias += batch_gradient[weights.size()] / cur_batch;
     }
 };
+class Momentum : public IWeightsUpdater
+{
+public:
+    Momentum() {}
+    Momentum(Float64 alpha) : alpha_(alpha) {}
+    void update(UInt32 cur_batch, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & batch_gradient) override
+    {
+        if (hk_.empty())
+        {
+            hk_.resize(batch_gradient.size(), Float64{0.0});
+        }
+        /// Exponentially decayed sum of past batch gradients: hk = alpha * hk + gradient.
+        for (size_t i = 0; i < batch_gradient.size(); ++i)
+        {
+            hk_[i] = hk_[i] * alpha_ + batch_gradient[i];
+        }
+        for (size_t i = 0; i < weights.size(); ++i)
+        {
+            weights[i] += hk_[i] / cur_batch;
+        }
+        bias += hk_[weights.size()] / cur_batch;
+    }
+    void merge(const std::shared_ptr<IWeightsUpdater> rhs, Float64 frac, Float64 rhs_frac) override
+    {
+        /// Combine the velocity vectors with the same weights used for the model weights.
+        auto momentum_rhs = std::dynamic_pointer_cast<Momentum>(rhs);
+        for (size_t i = 0; i < hk_.size(); ++i)
+        {
+            hk_[i] = hk_[i] * frac + momentum_rhs->hk_[i] * rhs_frac;
+        }
+    }
+    Float64 alpha_{0.1};
+    std::vector<Float64> hk_;
+};
 class LinearModelData
 {
 public:
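A minimal self-contained sketch of the recurrence Momentum implements (hypothetical free-function form; alpha corresponds to alpha_ above, h to hk_):

    #include <cstddef>
    #include <vector>

    /// Sketch: h := alpha * h + g, then parameters move by h / batch_size.
    /// With alpha == 0 this reduces to plain stochastic gradient descent.
    void momentumStep(std::vector<double> & h, std::vector<double> & weights, double & bias,
                      const std::vector<double> & gradient, double alpha, unsigned batch_size)
    {
        if (h.empty())
            h.resize(gradient.size(), 0.0);     // lazily sized on first use, like hk_
        for (std::size_t i = 0; i < gradient.size(); ++i)
            h[i] = h[i] * alpha + gradient[i];  // decayed sum of past batch gradients
        for (std::size_t i = 0; i < weights.size(); ++i)
            weights[i] += h[i] / batch_size;
        bias += h[weights.size()] / batch_size; // bias gradient is stored last
    }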
@@ -160,6 +253,7 @@ public:
         bias = bias * frac + rhs.bias * rhs_frac;
         iter_num += rhs.iter_num;
+        weights_updater->merge(rhs.weights_updater, frac, rhs_frac);
     }
     void write(WriteBuffer & buf) const
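Note on the added merge call: merging two partially trained states is a convex combination, and the new line makes the updater's internal state (e.g. Momentum's hk_) merge with the same weights as the model itself. With f + f' = 1 (presumably each side's share of the total iterations, judging by iter_num; the diff does not show how frac is computed):

    w \leftarrow f\,w + f'\,w_{rhs}, \qquad
    b \leftarrow f\,b + f'\,b_{rhs}, \qquad
    h \leftarrow f\,h + f'\,h_{rhs}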
@@ -189,14 +283,7 @@
         // update_weights();
         // }
-        Float64 res{0.0};
-        for (size_t i = 0; i < predict_feature.size(); ++i)
-        {
-            res += predict_feature[i] * weights[i];
-        }
-        res += bias;
-        return res;
+        return gradient_computer->predict(predict_feature, weights, bias);
     }
 private:
@@ -316,5 +403,5 @@ private:
 };
 struct NameLinearRegression { static constexpr auto name = "LinearRegression"; };
+struct NameLogisticRegression { static constexpr auto name = "LogisticRegression"; };
 }
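LogisticRegression::predict above pushes the linear score through the standard sigmoid, so the returned value can be read as a probability estimate:

    \hat p(y = 1 \mid x) = \sigma(w \cdot x + b) = \frac{1}{1 + e^{-(w \cdot x + b)}}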


@@ -123,19 +123,27 @@ MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const Col
     }
     // auto ML_function = typeid_cast<AggregateFunctionMLMethod<LinearRegressionData, NameLinearRegression> *>(func.get());
-    auto ML_function = typeid_cast<AggregateFunctionMLMethod<LinearModelData, NameLinearRegression> *>(func.get());
-    if (ML_function)
+    auto ML_function_Linear = typeid_cast<AggregateFunctionMLMethod<LinearModelData, NameLinearRegression> *>(func.get());
+    auto ML_function_Logistic = typeid_cast<AggregateFunctionMLMethod<LinearModelData, NameLogisticRegression> *>(func.get());
+    if (ML_function_Linear)
     {
         size_t row_num = 0;
         for (auto val : data) {
-            ML_function->predictResultInto(val, *res, block, row_num, arguments);
+            ML_function_Linear->predictResultInto(val, *res, block, row_num, arguments);
             ++row_num;
         }
-    } else {
+    } else if (ML_function_Logistic)
+    {
+        size_t row_num = 0;
+        for (auto val : data) {
+            ML_function_Logistic->predictResultInto(val, *res, block, row_num, arguments);
+            ++row_num;
+        }
+    } else
+    {
         throw Exception("Illegal aggregate function is passed",
             ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
     }
     return res;
 }
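The two branches differ only in the Name tag passed to typeid_cast; a hypothetical refactor (not in this commit) could factor the prediction loop into one helper. ColumnNumbers is assumed from the truncated signature above:

    /// Hypothetical helper: run prediction for one concrete specialization,
    /// returning false when func is not of that type.
    template <typename Name, typename Container>
    static bool tryPredictInto(IAggregateFunction * func, const Container & data, IColumn & res,
                               Block & block, const ColumnNumbers & arguments)
    {
        auto * ml = typeid_cast<AggregateFunctionMLMethod<LinearModelData, Name> *>(func);
        if (!ml)
            return false;
        size_t row_num = 0;
        for (auto val : data)
        {
            ml->predictResultInto(val, res, block, row_num, arguments);
            ++row_num;
        }
        return true;
    }

Each call site would then reduce to tryPredictInto<NameLinearRegression>(...) || tryPredictInto<NameLogisticRegression>(...), followed by the throw.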

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long