simple linear regression

alexander kozhikhov 2019-01-23 00:07:05 +03:00
parent c70e8cc5f0
commit 61bb3b8ade
9 changed files with 401 additions and 0 deletions

View File

@@ -0,0 +1,40 @@
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/AggregateFunctionMLMethod.h>
#include <AggregateFunctions/Helpers.h>
#include <AggregateFunctions/FactoryHelpers.h>
namespace DB
{
namespace
{
using FuncLinearRegression = AggregateFunctionMLMethod<LinearRegressionData, NameLinearRegression>;
template <class Method>
AggregateFunctionPtr createAggregateFunctionMLMethod(
const std::string & name, const DataTypes & arguments, const Array & parameters)
{
if (parameters.size() > 1)
throw Exception("Aggregate function " + name + " requires at most one parameter", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
/// The optional single parameter is the learning rate; default to 0.01.
Float64 lr = 0.01;
if (!parameters.empty())
lr = parameters[0].get<Float64>();
if (arguments.size() < 2)
throw Exception("Aggregate function " + name + " requires at least two arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
return std::make_shared<Method>(arguments.size() - 1, lr);
}
}
void registerAggregateFunctionMLMethod(AggregateFunctionFactory & factory) {
factory.registerFunction("LinearRegression", createAggregateFunctionMLMethod<FuncLinearRegression>);
}
}
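
For orientation (this note is not part of the commit), the convention the factory encodes is: the single optional parameter is the learning rate, defaulting to 0.01, and of the argument columns the first is the regression target while the remaining arguments.size() - 1 become feature weights. A minimal standalone sketch of that input handling, with plain standard-library types standing in for Field and the DataTypes and with purely illustrative names:

#include <cstddef>
#include <stdexcept>
#include <vector>

/// Hypothetical stand-in for the factory's input handling: one optional
/// parameter (the learning rate, default 0.01) and at least two argument
/// columns, where column 0 is the target and the rest are features.
struct MLMethodConfig
{
    size_t weight_count;   /// arguments.size() - 1
    double learning_rate;  /// parameters[0], or 0.01 if no parameter is given
};

MLMethodConfig makeMLMethodConfig(const std::vector<double> & parameters, size_t argument_count)
{
    if (parameters.size() > 1)
        throw std::invalid_argument("at most one parameter (the learning rate) is allowed");
    if (argument_count < 2)
        throw std::invalid_argument("need a target column and at least one feature column");
    return {argument_count - 1, parameters.empty() ? 0.01 : parameters[0]};
}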

View File

@@ -0,0 +1,180 @@
#pragma once
#include <type_traits>
#include <IO/WriteHelpers.h>
#include <IO/ReadHelpers.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypesDecimal.h>
#include <Columns/ColumnVector.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <cmath>
#include <exception>
#include <Columns/ColumnsCommon.h>
#include <Columns/ColumnsNumber.h>
#include <Functions/FunctionHelpers.h>
#include <Common/FieldVisitors.h>
namespace DB {
struct LinearRegressionData {
Float64 bias{0.0};
std::vector<Float64> w1;
Float64 learning_rate{0.01};
UInt32 iter_num = 0;
UInt32 param_num = 0;
void add(Float64 target, const std::vector<Float64> & feature, Float64 learning_rate_, UInt32 param_num_) {
if (w1.empty()) {
learning_rate = learning_rate_;
param_num = param_num_;
w1.resize(param_num);
}
Float64 derivative = (target - bias);
for (size_t i = 0; i < param_num; ++i)
{
derivative -= w1[i] * feature[i];
}
derivative *= (2 * learning_rate);
bias += derivative;
for (size_t i = 0; i < param_num; ++i)
{
w1[i] += derivative * feature[i];
}
++iter_num;
}
void merge(const LinearRegressionData & rhs) {
/// Merging an empty state is a no-op; this also guards the division below.
if (rhs.iter_num == 0)
return;
if (param_num == 0) {
param_num = rhs.param_num;
learning_rate = rhs.learning_rate;
w1.resize(param_num);
}
Float64 frac = static_cast<Float64>(iter_num) / (iter_num + rhs.iter_num);
Float64 rhs_frac = static_cast<Float64>(rhs.iter_num) / (iter_num + rhs.iter_num);
for (size_t i = 0; i < param_num; ++i)
{
w1[i] = w1[i] * frac + rhs.w1[i] * rhs_frac;
}
bias = bias * frac + rhs.bias * rhs_frac;
iter_num += rhs.iter_num;
}
void write(WriteBuffer & buf) const {
writeBinary(bias, buf);
writeBinary(w1, buf);
writeBinary(iter_num, buf);
/// learning_rate and param_num are needed by merge() and predict() on deserialized states.
writeBinary(learning_rate, buf);
writeBinary(param_num, buf);
}
void read(ReadBuffer & buf) {
readBinary(bias, buf);
readBinary(w1, buf);
readBinary(iter_num, buf);
readBinary(learning_rate, buf);
readBinary(param_num, buf);
}
Float64 predict(const std::vector<Float64> & predict_feature) const {
Float64 res{0.0};
for (size_t i = 0; i < static_cast<size_t>(param_num); ++i)
{
res += predict_feature[i] * w1[i];
}
res += bias;
return res;
}
};
template <
/// Implemented Machine Learning method
typename Data,
/// Name of the method
typename Name
>
class AggregateFunctionMLMethod final : public IAggregateFunctionDataHelper<Data, AggregateFunctionMLMethod<Data, Name>>
{
public:
String getName() const override { return Name::name; }
explicit AggregateFunctionMLMethod(UInt32 param_num, Float64 learning_rate)
: param_num(param_num), learning_rate(learning_rate)
{
}
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeNumber<Float64>>();
}
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
{
/// For now all argument columns are assumed to be Float64.
const auto & target = static_cast<const ColumnVector<Float64> &>(*columns[0]);
std::vector<Float64> x(param_num);
for (size_t i = 0; i < param_num; ++i)
{
x[i] = static_cast<const ColumnVector<Float64> &>(*columns[i + 1]).getData()[row_num];
}
this->data(place).add(target.getData()[row_num], x, learning_rate, param_num);
}
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
{
this->data(place).merge(this->data(rhs));
}
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{
this->data(place).write(buf);
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override
{
this->data(place).read(buf);
}
void predictResultInto(ConstAggregateDataPtr place, IColumn & to, Block & block, size_t row_num, const ColumnNumbers & arguments) const
{
auto & column = dynamic_cast<ColumnVector<Float64> &>(to);
/// arguments[0] is the state column itself; the remaining arguments are the feature columns for this row.
std::vector<Float64> predict_features(arguments.size() - 1);
for (size_t i = 1; i < arguments.size(); ++i) {
predict_features[i - 1] = applyVisitor(FieldVisitorConvertToNumber<Float64>(), (*block.getByPosition(arguments[i]).column)[row_num]);
}
column.getData().push_back(this->data(place).predict(predict_features));
}
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override {
/// Not implemented yet: results are produced through predictResultInto() (via evalMLMethod).
std::ignore = place;
std::ignore = to;
}
const char * getHeaderFilePath() const override { return __FILE__; }
private:
UInt32 param_num;
Float64 learning_rate;
};
struct NameLinearRegression { static constexpr auto name = "LinearRegression"; };
}
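
Written out (this annotation is not part of the diff), the update LinearRegressionData::add performs for a row with target \(y\) and features \(x_1, \dots, x_p\) is a single stochastic gradient step on the squared error, with \(\eta\) = learning_rate and \(b\) = bias:

\[
\hat{y} = b + \sum_{i=1}^{p} w_i x_i, \qquad
\delta = 2\eta\,(y - \hat{y}), \qquad
b \leftarrow b + \delta, \qquad
w_i \leftarrow w_i + \delta\, x_i .
\]

merge() then combines two partially trained states by weighting each with the number of rows it has seen (\(n\) = iter_num, \(n'\) = rhs.iter_num):

\[
w_i \leftarrow \frac{n}{n + n'}\, w_i + \frac{n'}{n + n'}\, w_i', \qquad
b \leftarrow \frac{n}{n + n'}\, b + \frac{n'}{n + n'}\, b', \qquad
n \leftarrow n + n',
\]

and predict() evaluates \(b + \sum_i w_i x_i\) on the supplied feature vector.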

View File

@@ -27,6 +27,7 @@ void registerAggregateFunctionUniqUpTo(AggregateFunctionFactory &);
void registerAggregateFunctionTopK(AggregateFunctionFactory &);
void registerAggregateFunctionsBitwise(AggregateFunctionFactory &);
void registerAggregateFunctionsMaxIntersections(AggregateFunctionFactory &);
void registerAggregateFunctionMLMethod(AggregateFunctionFactory &);
void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory &);
void registerAggregateFunctionCombinatorArray(AggregateFunctionCombinatorFactory &);
@@ -65,6 +66,7 @@ void registerAggregateFunctions()
registerAggregateFunctionsMaxIntersections(factory);
registerAggregateFunctionHistogram(factory);
registerAggregateFunctionRetention(factory);
registerAggregateFunctionMLMethod(factory);
}
{

View File

@@ -8,6 +8,7 @@
#include <Common/Arena.h>
#include <Columns/ColumnsCommon.h>
#include <AggregateFunctions/AggregateFunctionMLMethod.h>
namespace DB
{
@@ -80,6 +81,35 @@ MutableColumnPtr ColumnAggregateFunction::convertToValues() const
return res;
}
MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const ColumnNumbers & arguments) const
{
if (const AggregateFunctionState * function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
{
auto res = createView();
res->set(function_state->getNestedFunction());
res->data.assign(data.begin(), data.end());
return res;
}
MutableColumnPtr res = func->getReturnType()->createColumn();
res->reserve(data.size());
/// Prediction is currently supported only for the linear regression aggregate function.
auto ML_function = typeid_cast<const AggregateFunctionMLMethod<LinearRegressionData, NameLinearRegression> *>(func.get());
if (ML_function)
{
size_t row_num = 0;
for (auto val : data) {
ML_function->predictResultInto(val, *res, block, row_num, arguments);
++row_num;
}
}
/// Other aggregate function types are not supported for prediction yet.
return res;
}
void ColumnAggregateFunction::ensureOwnership()
{

View File

@@ -10,6 +10,7 @@
#include <IO/WriteBuffer.h>
#include <IO/WriteHelpers.h>
#include <Functions/FunctionHelpers.h>
namespace DB
{
@@ -115,6 +116,8 @@ public:
std::string getName() const override { return "AggregateFunction(" + func->getName() + ")"; }
const char * getFamilyName() const override { return "AggregateFunction"; }
MutableColumnPtr predictValues(Block & block, const ColumnNumbers & arguments) const;
size_t size() const override
{
return getData().size();

View File

@@ -0,0 +1,95 @@
#include <Functions/IFunction.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <DataTypes/DataTypeAggregateFunction.h>
#include <Columns/ColumnAggregateFunction.h>
#include <Common/typeid_cast.h>
#include <Columns/ColumnVector.h>
#include <Columns/ColumnsNumber.h>
#include <iostream>
#include <Common/PODArray.h>
#include <Columns/ColumnArray.h>
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_COLUMN;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
/** evalMLMethod(state, feature_1, ..., feature_N) - apply a trained ML model to new data.
* Takes the state of a machine-learning aggregate function and the feature columns,
* and returns the model's prediction for each row.
*/
class FunctionEvalMLMethod : public IFunction
{
public:
static constexpr auto name = "evalMLMethod";
static FunctionPtr create(const Context &)
{
return std::make_shared<FunctionEvalMLMethod>();
}
String getName() const override
{
return name;
}
bool isVariadic() const override
{
return true;
}
size_t getNumberOfArguments() const override
{
return 0;
}
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
const DataTypeAggregateFunction * type = checkAndGetDataType<DataTypeAggregateFunction>(arguments[0].get());
if (!type)
throw Exception("Argument for function " + getName() + " must have type AggregateFunction - state of aggregate function.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return type->getReturnType();
}
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override
{
const ColumnAggregateFunction * column_with_states
= typeid_cast<const ColumnAggregateFunction *>(&*block.getByPosition(arguments.at(0)).column);
if (!column_with_states)
throw Exception("Illegal column " + block.getByPosition(arguments.at(0)).column->getName()
+ " of first argument of function "
+ getName(),
ErrorCodes::ILLEGAL_COLUMN);
/// Per-row feature extraction happens inside ColumnAggregateFunction::predictValues,
/// which reads the remaining argument columns directly from the block.
block.getByPosition(result).column = column_with_states->predictValues(block, arguments);
}
};
void registerFunctionEvalMLMethod(FunctionFactory & factory)
{
factory.registerFunction<FunctionEvalMLMethod>();
}
}
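
To see the whole pipeline this commit wires together outside of ClickHouse (train a state by repeated add() calls, then evaluate it per row the way evalMLMethod and predictResultInto do), here is a minimal standalone sketch; plain double and std::vector replace Float64 and the column classes, and the class name and toy data are illustrative only:

#include <cstdio>
#include <vector>

/// Mirrors LinearRegressionData: a bias plus one weight per feature,
/// updated by one stochastic gradient step per observed row.
struct ToyLinearRegression
{
    double bias = 0.0;
    std::vector<double> w;
    double learning_rate;

    ToyLinearRegression(size_t features, double lr) : w(features, 0.0), learning_rate(lr) {}

    /// Same formula predictResultInto ends up evaluating: w . x + bias.
    double predict(const std::vector<double> & x) const
    {
        double res = bias;
        for (size_t i = 0; i < w.size(); ++i)
            res += w[i] * x[i];
        return res;
    }

    /// Same update as LinearRegressionData::add: one step on the squared error.
    void add(double target, const std::vector<double> & x)
    {
        double derivative = 2 * learning_rate * (target - predict(x));
        bias += derivative;
        for (size_t i = 0; i < w.size(); ++i)
            w[i] += derivative * x[i];
    }
};

int main()
{
    /// Fit y = 3 * x1 + 2 * x2 + 1 on a few synthetic rows (toy data).
    ToyLinearRegression model(2, 0.05);
    std::vector<std::vector<double>> rows = {{1, 0}, {0, 1}, {1, 1}, {2, 1}, {1, 2}};
    for (int epoch = 0; epoch < 200; ++epoch)
        for (const auto & x : rows)
            model.add(3 * x[0] + 2 * x[1] + 1, x);

    /// Evaluation step, one prediction per row of new features.
    std::printf("prediction for (2, 3): %f\n", model.predict({2, 3}));
    return 0;
}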

View File

@@ -42,6 +42,7 @@ void registerFunctionLowCardinalityKeys(FunctionFactory &);
void registerFunctionsIn(FunctionFactory &);
void registerFunctionJoinGet(FunctionFactory &);
void registerFunctionFilesystem(FunctionFactory &);
void registerFunctionEvalMLMethod(FunctionFactory &);
void registerFunctionsMiscellaneous(FunctionFactory & factory)
{
@@ -84,6 +85,7 @@ void registerFunctionsMiscellaneous(FunctionFactory & factory)
registerFunctionsIn(factory);
registerFunctionJoinGet(factory);
registerFunctionFilesystem(factory);
registerFunctionEvalMLMethod(factory);
}
}

View File

@@ -0,0 +1 @@
66.80107268499746

File diff suppressed because one or more lines are too long