Merge pull request #3 from Quid37/alexkoja_ML

mini-batches
This commit is contained in:
Quid37 2019-01-23 21:11:37 +03:00 committed by GitHub
commit e45e1c3ddc
5 changed files with 372 additions and 21 deletions
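
This merge adds optional mini-batch gradient descent to the LinearRegression aggregate function: a second optional parameter sets the batch size, per-row gradients are accumulated in batch_gradient, and the averaged update is applied once per full batch. A batch size of 1 keeps the previous per-row SGD behaviour.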

View File

@@ -16,8 +16,8 @@ template <class Method>
AggregateFunctionPtr createAggregateFunctionMLMethod(
const std::string & name, const DataTypes & argument_types, const Array & parameters)
{
if (parameters.size() > 1)
throw Exception("Aggregate function " + name + " requires at most one parameter", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
if (parameters.size() > 2)
throw Exception("Aggregate function " + name + " requires at most two parameters", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
for (size_t i = 0; i < argument_types.size(); ++i)
{
@@ -26,18 +26,23 @@ AggregateFunctionPtr createAggregateFunctionMLMethod(
+ std::to_string(i) + "for aggregate function " + name,
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
Float64 learning_rate;
if (parameters.empty())
{
learning_rate = Float64(0.01);
} else
Float64 learning_rate = Float64(0.01);
UInt32 batch_size = 1;
if (!parameters.empty())
{
learning_rate = applyVisitor(FieldVisitorConvertToNumber<Float64>(), parameters[0]);
}
if (parameters.size() > 1)
{
batch_size = applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[1]);
}
if (argument_types.size() < 2)
throw Exception("Aggregate function " + name + " requires at least two arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
return std::make_shared<Method>(argument_types.size() - 1, learning_rate);
return std::make_shared<Method>(argument_types.size() - 1, learning_rate, batch_size);
}
}
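
With this change the aggregate function accepts up to two optional parameters: the first is the learning rate (default 0.01), the second is the mini-batch size (default 1, i.e. a weight update after every row). For example, passing the parameters (0.1, 5) trains with learning rate 0.1 and accumulates gradients over 5 rows before each weight update.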

View File

@@ -34,33 +34,63 @@ struct LinearRegressionData
{
LinearRegressionData()
{}
LinearRegressionData(Float64 learning_rate_, UInt32 param_num_)
: learning_rate(learning_rate_) {
weights.resize(param_num_);
LinearRegressionData(Float64 learning_rate_, UInt32 param_num_, UInt32 batch_size_)
: learning_rate(learning_rate_), batch_size(batch_size_) {
weights.resize(param_num_, Float64{0.0});
batch_gradient.resize(param_num_ + 1, Float64{0.0});
cur_batch = 0;
}
Float64 bias{0.0};
std::vector<Float64> weights;
Float64 learning_rate;
UInt32 iter_num = 0;
std::vector<Float64> batch_gradient;
UInt32 cur_batch;
UInt32 batch_size;
void add(Float64 target, const IColumn ** columns, size_t row_num)
void update_gradient(Float64 target, const IColumn ** columns, size_t row_num)
{
Float64 derivative = (target - bias);
for (size_t i = 0; i < weights.size(); ++i)
{
derivative -= weights[i] * static_cast<const ColumnVector<Float64> &>(*columns[i + 1]).getData()[row_num];
}
derivative *= (2 * learning_rate);
bias += derivative;
// bias += derivative;
batch_gradient[weights.size()] += derivative;
for (size_t i = 0; i < weights.size(); ++i)
{
weights[i] += derivative * static_cast<const ColumnVector<Float64> &>(*columns[i + 1]).getData()[row_num];;
batch_gradient[i] += derivative * static_cast<const ColumnVector<Float64> &>(*columns[i + 1]).getData()[row_num];
}
}
void update_weights()
{
if (!cur_batch)
return;
for (size_t i = 0; i < weights.size(); ++i)
{
weights[i] += batch_gradient[i] / cur_batch;
}
bias += batch_gradient[weights.size()] / cur_batch;
batch_gradient.assign(batch_gradient.size(), Float64{0.0});
++iter_num;
cur_batch = 0;
}
void add(Float64 target, const IColumn ** columns, size_t row_num)
{
update_gradient(target, columns, row_num);
++cur_batch;
if (cur_batch == batch_size)
{
update_weights();
}
}
void merge(const LinearRegressionData & rhs)
@@ -68,6 +98,10 @@ struct LinearRegressionData
if (iter_num == 0 && rhs.iter_num == 0)
return;
update_weights();
/// cannot update rhs here because it is const
// rhs.update_weights();
Float64 frac = static_cast<Float64>(iter_num) / (iter_num + rhs.iter_num);
Float64 rhs_frac = static_cast<Float64>(rhs.iter_num) / (iter_num + rhs.iter_num);
@@ -85,6 +119,8 @@ struct LinearRegressionData
writeBinary(bias, buf);
writeBinary(weights, buf);
writeBinary(iter_num, buf);
writeBinary(batch_gradient, buf);
writeBinary(cur_batch, buf);
}
void read(ReadBuffer & buf)
@@ -92,10 +128,19 @@ struct LinearRegressionData
readBinary(bias, buf);
readBinary(weights, buf);
readBinary(iter_num, buf);
readBinary(batch_gradient, buf);
readBinary(cur_batch, buf);
}
Float64 predict(const std::vector<Float64>& predict_feature) const
{
/// do not update the weights during predict, since that could slow prediction down
/// they could, however, be updated on every merge, regardless of how many elements are in the batch
// if (cur_batch)
// {
// update_weights();
// }
Float64 res{0.0};
for (size_t i = 0; i < predict_feature.size(); ++i)
{
@@ -118,8 +163,8 @@ class AggregateFunctionMLMethod final : public IAggregateFunctionDataHelper<Data
public:
String getName() const override { return Name::name; }
explicit AggregateFunctionMLMethod(UInt32 param_num, Float64 learning_rate)
: param_num(param_num), learning_rate(learning_rate)
explicit AggregateFunctionMLMethod(UInt32 param_num, Float64 learning_rate, UInt32 batch_size)
: param_num(param_num), learning_rate(learning_rate), batch_size(batch_size)
{
}
@@ -130,7 +175,7 @@ public:
void create(AggregateDataPtr place) const override
{
new (place) Data(learning_rate, param_num);
new (place) Data(learning_rate, param_num, batch_size);
}
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
@@ -140,6 +185,7 @@ public:
this->data(place).add(target.getData()[row_num], columns, row_num);
}
/// a non-const rhs would be preferable here
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
{
this->data(place).merge(this->data(rhs));
@@ -188,6 +234,7 @@ public:
private:
UInt32 param_num;
Float64 learning_rate;
UInt32 batch_size;
};
struct NameLinearRegression { static constexpr auto name = "LinearRegression"; };
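
The struct above mixes the mini-batch bookkeeping with ClickHouse column access, so here is a minimal standalone sketch of just the accumulation rule, assuming plain std::vector inputs instead of IColumn; the class name MiniBatchLinearRegression and the toy training loop are illustrative only, not part of the patch. Per-row gradients are summed into batch_gradient, and the averaged step is applied to the weights and bias once cur_batch reaches batch_size.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Standalone model of LinearRegressionData above: ClickHouse column types are
// replaced with plain vectors; this illustrates the update rule, not the real class.
struct MiniBatchLinearRegression
{
    MiniBatchLinearRegression(double learning_rate_, size_t param_num_, uint32_t batch_size_)
        : learning_rate(learning_rate_), batch_size(batch_size_),
          weights(param_num_, 0.0), batch_gradient(param_num_ + 1, 0.0) {}

    void add(double target, const std::vector<double> & features)
    {
        // Accumulate this row's gradient (same form as update_gradient in the diff).
        double derivative = target - bias;
        for (size_t i = 0; i < weights.size(); ++i)
            derivative -= weights[i] * features[i];
        derivative *= 2 * learning_rate;

        batch_gradient[weights.size()] += derivative;            // bias component
        for (size_t i = 0; i < weights.size(); ++i)
            batch_gradient[i] += derivative * features[i];

        if (++cur_batch == batch_size)
            update_weights();
    }

    void update_weights()
    {
        if (!cur_batch)
            return;
        // Apply the averaged gradient of the current (possibly partial) batch.
        for (size_t i = 0; i < weights.size(); ++i)
            weights[i] += batch_gradient[i] / cur_batch;
        bias += batch_gradient[weights.size()] / cur_batch;
        batch_gradient.assign(batch_gradient.size(), 0.0);
        ++iter_num;
        cur_batch = 0;
    }

    double predict(const std::vector<double> & features) const
    {
        double res = bias;
        for (size_t i = 0; i < features.size(); ++i)
            res += weights[i] * features[i];
        return res;
    }

    double learning_rate;
    uint32_t batch_size;
    uint32_t cur_batch = 0;
    uint32_t iter_num = 0;
    double bias = 0.0;
    std::vector<double> weights;
    std::vector<double> batch_gradient;
};

int main()
{
    // Learn y = 2 * x with learning rate 0.05 and mini-batches of 4 rows.
    MiniBatchLinearRegression model(0.05, 1, 4);
    for (int epoch = 0; epoch < 500; ++epoch)
        for (double x = 0.0; x < 4.0; x += 1.0)
            model.add(2.0 * x, {x});
    std::cout << model.predict({3.0}) << '\n';    // approaches 6
}

As in the patch, a partially filled batch is only flushed by an explicit update_weights() call (for example during merge), not by predict(), so predictions reflect the weights as of the last completed batch.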

View File

@@ -122,7 +122,7 @@ MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const Col
return res;
}
auto ML_function = typeid_cast<const AggregateFunctionMLMethod<LinearRegressionData, NameLinearRegression> *>(func.get());
auto ML_function = typeid_cast<AggregateFunctionMLMethod<LinearRegressionData, NameLinearRegression> *>(func.get());
if (ML_function)
{
size_t row_num = 0;
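
Dropping const from the typeid_cast presumably leaves room for the predictor to mutate the aggregate state later, e.g. to flush a partially filled batch before prediction (see the commented-out update_weights call in predict above); with that call still commented out, predictions use the weights as of the last completed batch.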

View File

@@ -1 +1,300 @@
66.80107268499746
-66.99407003753556
(this line is repeated, identically, for all 300 rows of the new reference file; the remaining 299 copies are omitted here)

File diff suppressed because one or more lines are too long