Merge pull request #6 from Quid37/marikon1

logreg and momentum
This commit is contained in:
Quid37 2019-01-28 14:57:41 +03:00 committed by GitHub
commit 5a0068ff9b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 211 additions and 15 deletions

View File

@ -11,7 +11,7 @@ namespace
{
using FuncLinearRegression = AggregateFunctionMLMethod<LinearModelData, NameLinearRegression>;
using FuncLogisticRegression = AggregateFunctionMLMethod<LinearModelData, NameLogisticRegression>;
template <class Method>
AggregateFunctionPtr createAggregateFunctionMLMethod(
const std::string & name, const DataTypes & argument_types, const Array & parameters)
@ -46,6 +46,9 @@ AggregateFunctionPtr createAggregateFunctionMLMethod(
if (std::is_same<Method, FuncLinearRegression>::value)
{
gc = std::make_shared<LinearRegression>(argument_types.size());
} else if (std::is_same<Method, FuncLogisticRegression>::value)
{
gc = std::make_shared<LogisticRegression>(argument_types.size());
} else
{
throw Exception("Such gradient computer is not implemented yet", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
@ -55,6 +58,9 @@ AggregateFunctionPtr createAggregateFunctionMLMethod(
if (applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[2]) == Float64{1.0})
{
wu = std::make_shared<StochasticGradientDescent>();
} else if (applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[2]) == Float64{2.0})
{
wu = std::make_shared<Momentum>();
} else
{
throw Exception("Such weights updater is not implemented yet", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
@ -75,6 +81,7 @@ AggregateFunctionPtr createAggregateFunctionMLMethod(
void registerAggregateFunctionMLMethod(AggregateFunctionFactory & factory)
{
factory.registerFunction("LinearRegression", createAggregateFunctionMLMethod<FuncLinearRegression>);
factory.registerFunction("LogisticRegression", createAggregateFunctionMLMethod<FuncLogisticRegression>);
}
}

View File

@ -59,6 +59,7 @@ public:
{
return batch_gradient;
}
virtual Float64 predict(const std::vector<Float64> & predict_feature, const std::vector<Float64> & weights, Float64 bias) const = 0;
protected:
std::vector<Float64> batch_gradient; // gradient for bias lies in batch_gradient[batch_gradient.size() - 1]
@ -87,6 +88,67 @@ public:
batch_gradient[i] += derivative * static_cast<const ColumnVector<Float64> &>(*columns[i + 1]).getData()[row_num];;
}
}
Float64 predict(const std::vector<Float64> & predict_feature, const std::vector<Float64> & weights, Float64 bias) const override
{
/// не обновляем веса при предикте, т.к. это может замедлить предсказание
/// однако можно например обновлять их при каждом мердже не зависимо от того, сколько элементнов в батче
// if (cur_batch)
// {
// update_weights();
// }
Float64 res{0.0};
for (size_t i = 0; i < predict_feature.size(); ++i)
{
res += predict_feature[i] * weights[i];
}
res += bias;
return res;
}
};
class LogisticRegression : public IGradientComputer
{
public:
LogisticRegression(UInt32 sz)
: IGradientComputer(sz)
{}
void compute(const std::vector<Float64> & weights, Float64 bias, Float64 learning_rate,
Float64 target, const IColumn ** columns, size_t row_num) override
{
Float64 derivative = bias;
for (size_t i = 0; i < weights.size(); ++i)
{
derivative += weights[i] * static_cast<const ColumnVector<Float64> &>(*columns[i + 1]).getData()[row_num];;
}
derivative *= target;
derivative = learning_rate * exp(derivative);
batch_gradient[weights.size()] += target / (derivative + 1);;
for (size_t i = 0; i < weights.size(); ++i)
{
batch_gradient[i] += target / (derivative + 1) * static_cast<const ColumnVector<Float64> &>(*columns[i + 1]).getData()[row_num];
}
}
Float64 predict(const std::vector<Float64> & predict_feature, const std::vector<Float64> & weights, Float64 bias) const override
{
/// не обновляем веса при предикте, т.к. это может замедлить предсказание
/// однако можно например обновлять их при каждом мердже не зависимо от того, сколько элементнов в батче
// if (cur_batch)
// {
// update_weights();
// }
Float64 res{0.0};
for (size_t i = 0; i < predict_feature.size(); ++i)
{
res += predict_feature[i] * weights[i];
}
res += bias;
res = 1 / (1 + exp(-res));
return res;
}
};
class IWeightsUpdater
@ -95,6 +157,7 @@ public:
virtual ~IWeightsUpdater() = default;
virtual void update(UInt32 cur_batch, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & gradient) = 0;
virtual void merge(const std::shared_ptr<IWeightsUpdater>, Float64, Float64) {}
};
class StochasticGradientDescent : public IWeightsUpdater
@ -108,7 +171,37 @@ public:
bias += batch_gradient[weights.size()] / cur_batch;
}
};
class Momentum : public IWeightsUpdater
{
public:
Momentum() {}
Momentum (Float64 alpha) : alpha_(alpha) {}
void update(UInt32 cur_batch, std::vector<Float64> & weights, Float64 & bias, const std::vector<Float64> & batch_gradient) override {
if (hk_.size() == 0)
{
hk_.resize(batch_gradient.size(), Float64{0.0});
}
for (size_t i = 0; i < batch_gradient.size(); ++i)
{
hk_[i] = hk_[i] * alpha_ + batch_gradient[i];
}
for (size_t i = 0; i < weights.size(); ++i)
{
weights[i] += hk_[i] / cur_batch;
}
bias += hk_[weights.size()] / cur_batch;
}
virtual void merge(const std::shared_ptr<IWeightsUpdater> rhs, Float64 frac, Float64 rhs_frac) override {
auto momentum_rhs = std::dynamic_pointer_cast<Momentum>(rhs);
for (size_t i = 0; i < hk_.size(); ++i)
{
hk_[i] = hk_[i] * frac + momentum_rhs->hk_[i] * rhs_frac;
}
}
Float64 alpha_{0.1};
std::vector<Float64> hk_;
};
class LinearModelData
{
public:
@ -160,6 +253,7 @@ public:
bias = bias * frac + rhs.bias * rhs_frac;
iter_num += rhs.iter_num;
weights_updater->merge(rhs.weights_updater, frac, rhs_frac);
}
void write(WriteBuffer & buf) const
@ -189,14 +283,7 @@ public:
// update_weights();
// }
Float64 res{0.0};
for (size_t i = 0; i < predict_feature.size(); ++i)
{
res += predict_feature[i] * weights[i];
}
res += bias;
return res;
return gradient_computer->predict(predict_feature, weights, bias);
}
private:
@ -316,5 +403,5 @@ private:
};
struct NameLinearRegression { static constexpr auto name = "LinearRegression"; };
struct NameLogisticRegression { static constexpr auto name = "LogisticRegression"; };
}

View File

@ -123,19 +123,27 @@ MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const Col
}
// auto ML_function = typeid_cast<AggregateFunctionMLMethod<LinearRegressionData, NameLinearRegression> *>(func.get());
auto ML_function = typeid_cast<AggregateFunctionMLMethod<LinearModelData, NameLinearRegression> *>(func.get());
if (ML_function)
auto ML_function_Linear = typeid_cast<AggregateFunctionMLMethod<LinearModelData, NameLinearRegression> *>(func.get());
auto ML_function_Logistic = typeid_cast<AggregateFunctionMLMethod<LinearModelData, NameLogisticRegression> *>(func.get());
if (ML_function_Linear)
{
size_t row_num = 0;
for (auto val : data) {
ML_function->predictResultInto(val, *res, block, row_num, arguments);
ML_function_Linear->predictResultInto(val, *res, block, row_num, arguments);
++row_num;
}
} else {
} else if (ML_function_Logistic)
{
size_t row_num = 0;
for (auto val : data) {
ML_function_Logistic->predictResultInto(val, *res, block, row_num, arguments);
++row_num;
}
} else
{
throw Exception("Illegal aggregate function is passed",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
return res;
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long