Mirror of https://github.com/ClickHouse/ClickHouse.git (synced 2024-11-19 14:11:58 +00:00)
Commit e45e1c3ddc
@@ -16,8 +16,8 @@ template <class Method>
 AggregateFunctionPtr createAggregateFunctionMLMethod(
     const std::string & name, const DataTypes & argument_types, const Array & parameters)
 {
-    if (parameters.size() > 1)
-        throw Exception("Aggregate function " + name + " requires at most one parameter", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
+    if (parameters.size() > 2)
+        throw Exception("Aggregate function " + name + " requires at most two parameters", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
 
     for (size_t i = 0; i < argument_types.size(); ++i)
     {
@@ -26,18 +26,23 @@ AggregateFunctionPtr createAggregateFunctionMLMethod(
                 + std::to_string(i) + " for aggregate function " + name,
                 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
     }
-    Float64 learning_rate;
-    if (parameters.empty())
-    {
-        learning_rate = Float64(0.01);
-    } else
+
+    Float64 learning_rate = Float64(0.01);
+    UInt32 batch_size = 1;
+    if (!parameters.empty())
     {
         learning_rate = applyVisitor(FieldVisitorConvertToNumber<Float64>(), parameters[0]);
     }
+    if (parameters.size() > 1)
+    {
+        batch_size = applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[1]);
+    }
 
     if (argument_types.size() < 2)
         throw Exception("Aggregate function " + name + " requires at least two arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
 
-    return std::make_shared<Method>(argument_types.size() - 1, learning_rate);
+    return std::make_shared<Method>(argument_types.size() - 1, learning_rate, batch_size);
 }
 
 }
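Note (not part of the patch): after this change the factory accepts up to two optional parameters, with the learning rate defaulting to 0.01 and the batch size to 1. The following standalone sketch of that defaulting logic uses plain standard-library types and a hypothetical parseParams helper, not ClickHouse's Array/FieldVisitorConvertToNumber machinery:

    #include <stdexcept>
    #include <vector>

    // Hypothetical standalone illustration of the parameter defaults above;
    // the real factory reads ClickHouse Fields, not doubles.
    struct MLMethodParams
    {
        double learning_rate = 0.01;  // used when no parameters are given
        unsigned batch_size = 1;      // plain SGD: apply the gradient after every row
    };

    MLMethodParams parseParams(const std::vector<double> & parameters)
    {
        if (parameters.size() > 2)
            throw std::invalid_argument("at most two parameters are allowed");

        MLMethodParams result;
        if (!parameters.empty())
            result.learning_rate = parameters[0];
        if (parameters.size() > 1)
            result.batch_size = static_cast<unsigned>(parameters[1]);
        return result;
    }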
@@ -34,33 +34,63 @@ struct LinearRegressionData
 {
     LinearRegressionData()
     {}
-    LinearRegressionData(Float64 learning_rate_, UInt32 param_num_)
-    : learning_rate(learning_rate_) {
-        weights.resize(param_num_);
+    LinearRegressionData(Float64 learning_rate_, UInt32 param_num_, UInt32 batch_size_)
+    : learning_rate(learning_rate_), batch_size(batch_size_) {
+        weights.resize(param_num_, Float64{0.0});
+        batch_gradient.resize(param_num_ + 1, Float64{0.0});
+        cur_batch = 0;
     }
 
     Float64 bias{0.0};
     std::vector<Float64> weights;
     Float64 learning_rate;
     UInt32 iter_num = 0;
+    std::vector<Float64> batch_gradient;
+    UInt32 cur_batch;
+    UInt32 batch_size;
 
-    void add(Float64 target, const IColumn ** columns, size_t row_num)
+    void update_gradient(Float64 target, const IColumn ** columns, size_t row_num)
     {
         Float64 derivative = (target - bias);
         for (size_t i = 0; i < weights.size(); ++i)
         {
             derivative -= weights[i] * static_cast<const ColumnVector<Float64> &>(*columns[i + 1]).getData()[row_num];
         }
         derivative *= (2 * learning_rate);
 
-        bias += derivative;
+        // bias += derivative;
+        batch_gradient[weights.size()] += derivative;
         for (size_t i = 0; i < weights.size(); ++i)
         {
-            weights[i] += derivative * static_cast<const ColumnVector<Float64> &>(*columns[i + 1]).getData()[row_num];
+            batch_gradient[i] += derivative * static_cast<const ColumnVector<Float64> &>(*columns[i + 1]).getData()[row_num];
         }
     }
 
+    void update_weights()
+    {
+        if (!cur_batch)
+            return;
+
+        for (size_t i = 0; i < weights.size(); ++i)
+        {
+            weights[i] += batch_gradient[i] / cur_batch;
+        }
+        bias += batch_gradient[weights.size()] / cur_batch;
+
+        batch_gradient.assign(batch_gradient.size(), Float64{0.0});
+
+        ++iter_num;
+        cur_batch = 0;
+    }
+
+    void add(Float64 target, const IColumn ** columns, size_t row_num)
+    {
+        update_gradient(target, columns, row_num);
+        ++cur_batch;
+        if (cur_batch == batch_size)
+        {
+            update_weights();
+        }
+    }
 
     void merge(const LinearRegressionData & rhs)
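Note (not part of the patch): writing \(\eta\) for learning_rate, \(b\) for bias, \(w\) for weights, \(g\) for batch_gradient and \(B\) for the number of rows accumulated in cur_batch, update_gradient stores for each row \((x, y)\) the scaled residual, and update_weights then applies the batch mean:

\[ d = 2\eta\,(y - b - w^\top x), \qquad g_i \mathrel{+}= d\,x_i, \qquad g_{\mathrm{bias}} \mathrel{+}= d \]
\[ w_i \mathrel{+}= g_i / B, \qquad b \mathrel{+}= g_{\mathrm{bias}} / B \]

i.e. mini-batch stochastic gradient descent on the squared error, with the learning rate already folded into \(d\) before accumulation.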
@@ -68,6 +98,10 @@ struct LinearRegressionData
         if (iter_num == 0 && rhs.iter_num == 0)
             return;
 
+        update_weights();
+        /// rhs cannot be updated here because it is const
+        // rhs.update_weights();
+
         Float64 frac = static_cast<Float64>(iter_num) / (iter_num + rhs.iter_num);
         Float64 rhs_frac = static_cast<Float64>(rhs.iter_num) / (iter_num + rhs.iter_num);
 
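The rest of merge() lies outside this hunk. For illustration only, the sketch below (plain standard-library types, with the combination rule assumed from frac / rhs_frac) shows the kind of weighted averaging these two fractions suggest, where each side contributes in proportion to the number of batch updates it has performed:

    #include <cstdint>
    #include <vector>

    // Standalone illustration of merging two partially trained linear models.
    // Not the ClickHouse struct; the real merge body is not shown in the hunk above.
    struct PartialModel
    {
        double bias = 0.0;
        std::vector<double> weights;
        uint32_t iter_num = 0;
    };

    void merge(PartialModel & lhs, const PartialModel & rhs)
    {
        if (lhs.iter_num == 0 && rhs.iter_num == 0)
            return;

        double frac = static_cast<double>(lhs.iter_num) / (lhs.iter_num + rhs.iter_num);
        double rhs_frac = static_cast<double>(rhs.iter_num) / (lhs.iter_num + rhs.iter_num);

        // Blend each parameter in proportion to how many updates each side has seen.
        for (std::size_t i = 0; i < lhs.weights.size(); ++i)
            lhs.weights[i] = frac * lhs.weights[i] + rhs_frac * rhs.weights[i];
        lhs.bias = frac * lhs.bias + rhs_frac * rhs.bias;

        lhs.iter_num += rhs.iter_num;
    }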
@@ -85,6 +119,8 @@ struct LinearRegressionData
         writeBinary(bias, buf);
         writeBinary(weights, buf);
         writeBinary(iter_num, buf);
+        writeBinary(batch_gradient, buf);
+        writeBinary(cur_batch, buf);
     }
 
     void read(ReadBuffer & buf)
@@ -92,10 +128,19 @@ struct LinearRegressionData
         readBinary(bias, buf);
         readBinary(weights, buf);
         readBinary(iter_num, buf);
+        readBinary(batch_gradient, buf);
+        readBinary(cur_batch, buf);
     }
 
     Float64 predict(const std::vector<Float64>& predict_feature) const
     {
+        /// Do not update the weights during predict, since that could slow prediction down.
+        /// They could instead be updated on every merge, for example, regardless of how many elements are in the batch.
+        // if (cur_batch)
+        // {
+        //     update_weights();
+        // }
+
         Float64 res{0.0};
         for (size_t i = 0; i < predict_feature.size(); ++i)
         {
@@ -118,8 +163,8 @@ class AggregateFunctionMLMethod final : public IAggregateFunctionDataHelper<Data
 public:
     String getName() const override { return Name::name; }
 
-    explicit AggregateFunctionMLMethod(UInt32 param_num, Float64 learning_rate)
-    : param_num(param_num), learning_rate(learning_rate)
+    explicit AggregateFunctionMLMethod(UInt32 param_num, Float64 learning_rate, UInt32 batch_size)
+    : param_num(param_num), learning_rate(learning_rate), batch_size(batch_size)
     {
     }
 
@@ -130,7 +175,7 @@ public:
 
     void create(AggregateDataPtr place) const override
     {
-        new (place) Data(learning_rate, param_num);
+        new (place) Data(learning_rate, param_num, batch_size);
     }
 
     void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
@@ -140,6 +185,7 @@ public:
         this->data(place).add(target.getData()[row_num], columns, row_num);
     }
 
+    /// A non-const rhs would be preferable here.
     void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
     {
         this->data(place).merge(this->data(rhs));
@@ -188,6 +234,7 @@ public:
 private:
     UInt32 param_num;
     Float64 learning_rate;
+    UInt32 batch_size;
 };
 
 struct NameLinearRegression { static constexpr auto name = "LinearRegression"; };
 
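To make the batching behaviour concrete, here is a small self-contained sketch of the add / update_weights / predict cycle introduced above, rewritten with plain standard-library types (illustrative only, not the ClickHouse classes):

    #include <cstddef>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Illustrative re-implementation of the batched SGD cycle from the patch,
    // using plain std types so it compiles outside ClickHouse.
    struct BatchedLinearRegression
    {
        BatchedLinearRegression(double learning_rate_, uint32_t param_num_, uint32_t batch_size_)
            : learning_rate(learning_rate_), batch_size(batch_size_),
              weights(param_num_, 0.0), batch_gradient(param_num_ + 1, 0.0) {}

        void add(double target, const std::vector<double> & features)
        {
            // Accumulate the gradient of the squared error, scaled by the learning rate.
            double derivative = target - bias;
            for (std::size_t i = 0; i < weights.size(); ++i)
                derivative -= weights[i] * features[i];
            derivative *= 2 * learning_rate;

            batch_gradient[weights.size()] += derivative;
            for (std::size_t i = 0; i < weights.size(); ++i)
                batch_gradient[i] += derivative * features[i];

            if (++cur_batch == batch_size)
                flush();
        }

        void flush()
        {
            // Apply the averaged gradient of the current batch, then reset it.
            if (!cur_batch)
                return;
            for (std::size_t i = 0; i < weights.size(); ++i)
                weights[i] += batch_gradient[i] / cur_batch;
            bias += batch_gradient[weights.size()] / cur_batch;
            batch_gradient.assign(batch_gradient.size(), 0.0);
            cur_batch = 0;
        }

        double predict(const std::vector<double> & features) const
        {
            double res = bias;
            for (std::size_t i = 0; i < features.size(); ++i)
                res += weights[i] * features[i];
            return res;
        }

        double learning_rate;
        uint32_t batch_size;
        uint32_t cur_batch = 0;
        double bias = 0.0;
        std::vector<double> weights;
        std::vector<double> batch_gradient;
    };

    int main()
    {
        // Fit y = 2x + 1 on a few samples with batch_size = 4.
        BatchedLinearRegression model(0.05, 1, 4);
        std::vector<double> xs = {0.0, 1.0, 2.0, 3.0};
        for (int epoch = 0; epoch < 200; ++epoch)
            for (double x : xs)
                model.add(2 * x + 1, {x});
        model.flush();
        std::cout << model.predict({4.0}) << '\n'; // approaches 9
    }

With batch_size = 1 this degenerates to the previous per-row update; larger batches average the accumulated gradient before it is applied.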
@@ -122,7 +122,7 @@ MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const Col
         return res;
     }
 
-    auto ML_function = typeid_cast<const AggregateFunctionMLMethod<LinearRegressionData, NameLinearRegression> *>(func.get());
+    auto ML_function = typeid_cast<AggregateFunctionMLMethod<LinearRegressionData, NameLinearRegression> *>(func.get());
     if (ML_function)
     {
         size_t row_num = 0;
 
@@ -1 +1,300 @@
66.80107268499746
-66.99407003753556
(The single old reference value 66.80107268499746 is replaced by 300 lines of -66.99407003753556; the remaining 299 identical lines are omitted here.)
File diff suppressed because one or more lines are too long