mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-24 16:42:05 +00:00
simple linear regression
This commit is contained in:
parent
c70e8cc5f0
commit
61bb3b8ade
40
dbms/src/AggregateFunctions/AggregateFunctionMLMethod.cpp
Normal file
40
dbms/src/AggregateFunctions/AggregateFunctionMLMethod.cpp
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||||
|
#include <AggregateFunctions/AggregateFunctionMLMethod.h>
|
||||||
|
#include <AggregateFunctions/Helpers.h>
|
||||||
|
#include <AggregateFunctions/FactoryHelpers.h>
|
||||||
|
|
||||||
|
|
||||||
|
namespace DB
|
||||||
|
{
|
||||||
|
|
||||||
|
namespace
|
||||||
|
{
|
||||||
|
|
||||||
|
using FuncLinearRegression = AggregateFunctionMLMethod<LinearRegressionData, NameLinearRegression>;
|
||||||
|
|
||||||
|
template <class Method>
|
||||||
|
AggregateFunctionPtr createAggregateFunctionMLMethod(
|
||||||
|
const std::string & name, const DataTypes & arguments, const Array & parameters)
|
||||||
|
{
|
||||||
|
if (parameters.size() > 1)
|
||||||
|
throw Exception("Aggregate function " + name + " requires at most one parameter", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||||
|
|
||||||
|
Float64 lr;
|
||||||
|
if (parameters.empty())
|
||||||
|
lr = Float64(0.01);
|
||||||
|
else
|
||||||
|
lr = static_cast<const Float64>(parameters[0].template get<Float64>());
|
||||||
|
|
||||||
|
if (arguments.size() < 2)
|
||||||
|
throw Exception("Aggregate function " + name + " requires at least two arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||||
|
|
||||||
|
return std::make_shared<Method>(arguments.size() - 1, lr);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
void registerAggregateFunctionMLMethod(AggregateFunctionFactory & factory) {
|
||||||
|
factory.registerFunction("LinearRegression", createAggregateFunctionMLMethod<FuncLinearRegression>);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
180
dbms/src/AggregateFunctions/AggregateFunctionMLMethod.h
Normal file
180
dbms/src/AggregateFunctions/AggregateFunctionMLMethod.h
Normal file
@ -0,0 +1,180 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
|
#include <IO/WriteHelpers.h>
|
||||||
|
#include <IO/ReadHelpers.h>
|
||||||
|
|
||||||
|
#include <DataTypes/DataTypesNumber.h>
|
||||||
|
#include <DataTypes/DataTypesDecimal.h>
|
||||||
|
#include <Columns/ColumnVector.h>
|
||||||
|
|
||||||
|
#include <AggregateFunctions/IAggregateFunction.h>
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
#include <exception>
|
||||||
|
|
||||||
|
#include <Columns/ColumnsCommon.h>
|
||||||
|
#include <Columns/ColumnsNumber.h>
|
||||||
|
#include <Functions/FunctionHelpers.h>
|
||||||
|
#include <Common/FieldVisitors.h>
|
||||||
|
|
||||||
|
|
||||||
|
namespace DB {
|
||||||
|
|
||||||
|
struct LinearRegressionData {
|
||||||
|
Float64 bias{0.0};
|
||||||
|
std::vector<Float64> w1;
|
||||||
|
Float64 learning_rate{0.01};
|
||||||
|
UInt32 iter_num = 0;
|
||||||
|
UInt32 param_num = 0;
|
||||||
|
|
||||||
|
|
||||||
|
void add(Float64 target, std::vector<Float64>& feature, Float64 learning_rate_, UInt32 param_num_) {
|
||||||
|
if (w1.empty()) {
|
||||||
|
learning_rate = learning_rate_;
|
||||||
|
param_num = param_num_;
|
||||||
|
w1.resize(param_num);
|
||||||
|
}
|
||||||
|
|
||||||
|
Float64 derivative = (target - bias);
|
||||||
|
for (size_t i = 0; i < param_num; ++i)
|
||||||
|
{
|
||||||
|
derivative -= w1[i] * feature[i];
|
||||||
|
}
|
||||||
|
derivative *= (2 * learning_rate);
|
||||||
|
|
||||||
|
bias += derivative;
|
||||||
|
for (size_t i = 0; i < param_num; ++i)
|
||||||
|
{
|
||||||
|
w1[i] += derivative * feature[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
++iter_num;
|
||||||
|
}
|
||||||
|
|
||||||
|
void merge(const LinearRegressionData & rhs) {
|
||||||
|
if (iter_num == 0 && rhs.iter_num == 0)
|
||||||
|
throw std::runtime_error("Strange...");
|
||||||
|
|
||||||
|
if (param_num == 0) {
|
||||||
|
param_num = rhs.param_num;
|
||||||
|
w1.resize(param_num);
|
||||||
|
}
|
||||||
|
|
||||||
|
Float64 frac = static_cast<Float64>(iter_num) / (iter_num + rhs.iter_num);
|
||||||
|
Float64 rhs_frac = static_cast<Float64>(rhs.iter_num) / (iter_num + rhs.iter_num);
|
||||||
|
|
||||||
|
for (size_t i = 0; i < param_num; ++i)
|
||||||
|
{
|
||||||
|
w1[i] = w1[i] * frac + rhs.w1[i] * rhs_frac;
|
||||||
|
}
|
||||||
|
|
||||||
|
bias = bias * frac + rhs.bias * rhs_frac;
|
||||||
|
iter_num += rhs.iter_num;
|
||||||
|
}
|
||||||
|
|
||||||
|
void write(WriteBuffer & buf) const {
|
||||||
|
writeBinary(bias, buf);
|
||||||
|
writeBinary(w1, buf);
|
||||||
|
writeBinary(iter_num, buf);
|
||||||
|
}
|
||||||
|
|
||||||
|
void read(ReadBuffer & buf) {
|
||||||
|
readBinary(bias, buf);
|
||||||
|
readBinary(w1, buf);
|
||||||
|
readBinary(iter_num, buf);
|
||||||
|
}
|
||||||
|
Float64 predict(std::vector<Float64>& predict_feature) const {
|
||||||
|
Float64 res{0.0};
|
||||||
|
for (size_t i = 0; i < static_cast<size_t>(param_num); ++i)
|
||||||
|
{
|
||||||
|
res += predict_feature[i] * w1[i];
|
||||||
|
}
|
||||||
|
res += bias;
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <
|
||||||
|
/// Implemented Machine Learning method
|
||||||
|
typename Data,
|
||||||
|
/// Name of the method
|
||||||
|
typename Name
|
||||||
|
>
|
||||||
|
class AggregateFunctionMLMethod final : public IAggregateFunctionDataHelper<Data, AggregateFunctionMLMethod<Data, Name>>
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
String getName() const override { return Name::name; }
|
||||||
|
|
||||||
|
explicit AggregateFunctionMLMethod(UInt32 param_num, Float64 learning_rate)
|
||||||
|
: param_num(param_num), learning_rate(learning_rate)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
DataTypePtr getReturnType() const override
|
||||||
|
{
|
||||||
|
return std::make_shared<DataTypeNumber<Float64>>();
|
||||||
|
}
|
||||||
|
|
||||||
|
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||||
|
{
|
||||||
|
const auto & target = static_cast<const ColumnVector<Float64> &>(*columns[0]);
|
||||||
|
|
||||||
|
std::vector<Float64> x(param_num);
|
||||||
|
for (size_t i = 0; i < param_num; ++i)
|
||||||
|
{
|
||||||
|
x[i] = static_cast<const ColumnVector<Float64> &>(*columns[i + 1]).getData()[row_num];
|
||||||
|
}
|
||||||
|
|
||||||
|
this->data(place).add(target.getData()[row_num], x, learning_rate, param_num);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||||
|
{
|
||||||
|
this->data(place).merge(this->data(rhs));
|
||||||
|
}
|
||||||
|
|
||||||
|
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
|
||||||
|
{
|
||||||
|
this->data(place).write(buf);
|
||||||
|
}
|
||||||
|
|
||||||
|
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override
|
||||||
|
{
|
||||||
|
this->data(place).read(buf);
|
||||||
|
}
|
||||||
|
|
||||||
|
void predictResultInto(ConstAggregateDataPtr place, IColumn & to, Block & block, size_t row_num, const ColumnNumbers & arguments) const
|
||||||
|
{
|
||||||
|
auto &column = dynamic_cast<ColumnVector<Float64> &>(to);
|
||||||
|
|
||||||
|
std::vector<Float64> predict_features(arguments.size() - 1);
|
||||||
|
// for (size_t row_num = 0, rows = block.rows(); row_num < rows; ++row_num) {
|
||||||
|
for (size_t i = 1; i < arguments.size(); ++i) {
|
||||||
|
// predict_features[i] = array_elements[i].get<Float64>();
|
||||||
|
predict_features[i - 1] = applyVisitor(FieldVisitorConvertToNumber<Float64>(), (*block.getByPosition(arguments[i]).column)[row_num]);
|
||||||
|
}
|
||||||
|
// column.getData().push_back(this->data(place).predict(predict_features));
|
||||||
|
column.getData().push_back(this->data(place).predict(predict_features));
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
|
||||||
|
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override {
|
||||||
|
auto &column = dynamic_cast<ColumnVector<Float64> &>(to);
|
||||||
|
std::ignore = column;
|
||||||
|
std::ignore = place;
|
||||||
|
}
|
||||||
|
|
||||||
|
const char * getHeaderFilePath() const override { return __FILE__; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
UInt32 param_num;
|
||||||
|
Float64 learning_rate;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct NameLinearRegression { static constexpr auto name = "LinearRegression"; };
|
||||||
|
|
||||||
|
}
|
@ -27,6 +27,7 @@ void registerAggregateFunctionUniqUpTo(AggregateFunctionFactory &);
|
|||||||
void registerAggregateFunctionTopK(AggregateFunctionFactory &);
|
void registerAggregateFunctionTopK(AggregateFunctionFactory &);
|
||||||
void registerAggregateFunctionsBitwise(AggregateFunctionFactory &);
|
void registerAggregateFunctionsBitwise(AggregateFunctionFactory &);
|
||||||
void registerAggregateFunctionsMaxIntersections(AggregateFunctionFactory &);
|
void registerAggregateFunctionsMaxIntersections(AggregateFunctionFactory &);
|
||||||
|
void registerAggregateFunctionMLMethod(AggregateFunctionFactory &);
|
||||||
|
|
||||||
void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory &);
|
void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory &);
|
||||||
void registerAggregateFunctionCombinatorArray(AggregateFunctionCombinatorFactory &);
|
void registerAggregateFunctionCombinatorArray(AggregateFunctionCombinatorFactory &);
|
||||||
@ -65,6 +66,7 @@ void registerAggregateFunctions()
|
|||||||
registerAggregateFunctionsMaxIntersections(factory);
|
registerAggregateFunctionsMaxIntersections(factory);
|
||||||
registerAggregateFunctionHistogram(factory);
|
registerAggregateFunctionHistogram(factory);
|
||||||
registerAggregateFunctionRetention(factory);
|
registerAggregateFunctionRetention(factory);
|
||||||
|
registerAggregateFunctionMLMethod(factory);
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -8,6 +8,7 @@
|
|||||||
#include <Common/Arena.h>
|
#include <Common/Arena.h>
|
||||||
#include <Columns/ColumnsCommon.h>
|
#include <Columns/ColumnsCommon.h>
|
||||||
|
|
||||||
|
#include <AggregateFinctions/AggregateFunctionMLMethod.h>
|
||||||
|
|
||||||
namespace DB
|
namespace DB
|
||||||
{
|
{
|
||||||
@ -80,6 +81,35 @@ MutableColumnPtr ColumnAggregateFunction::convertToValues() const
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//MutableColumnPtr ColumnAggregateFunction::predictValues(std::vector<Float64> predict_feature) const
|
||||||
|
MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const ColumnNumbers & arguments) const
|
||||||
|
{
|
||||||
|
if (const AggregateFunctionState * function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
|
||||||
|
{
|
||||||
|
auto res = createView();
|
||||||
|
res->set(function_state->getNestedFunction());
|
||||||
|
res->data.assign(data.begin(), data.end());
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
MutableColumnPtr res = func->getReturnType()->createColumn();
|
||||||
|
res->reserve(data.size());
|
||||||
|
|
||||||
|
// const AggregateFunctionMLMethod * ML_function = typeid_cast<const AggregateFunctionMLMethod *>(func.get());
|
||||||
|
auto ML_function = typeid_cast<const AggregateFunctionMLMethod<LinearRegressionData, NameLinearRegression> *>(func.get());
|
||||||
|
if (ML_function)
|
||||||
|
{
|
||||||
|
size_t row_num = 0;
|
||||||
|
for (auto val : data) {
|
||||||
|
ML_function->predictResultInto(val, *res, block, row_num, arguments);
|
||||||
|
++row_num;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
void ColumnAggregateFunction::ensureOwnership()
|
void ColumnAggregateFunction::ensureOwnership()
|
||||||
{
|
{
|
||||||
|
@ -10,6 +10,7 @@
|
|||||||
#include <IO/WriteBuffer.h>
|
#include <IO/WriteBuffer.h>
|
||||||
#include <IO/WriteHelpers.h>
|
#include <IO/WriteHelpers.h>
|
||||||
|
|
||||||
|
#include <Functions/FunctionHelpers.h>
|
||||||
|
|
||||||
namespace DB
|
namespace DB
|
||||||
{
|
{
|
||||||
@ -115,6 +116,8 @@ public:
|
|||||||
std::string getName() const override { return "AggregateFunction(" + func->getName() + ")"; }
|
std::string getName() const override { return "AggregateFunction(" + func->getName() + ")"; }
|
||||||
const char * getFamilyName() const override { return "AggregateFunction"; }
|
const char * getFamilyName() const override { return "AggregateFunction"; }
|
||||||
|
|
||||||
|
MutableColumnPtr predictValues(Block & block, const ColumnNumbers & arguments) const;
|
||||||
|
|
||||||
size_t size() const override
|
size_t size() const override
|
||||||
{
|
{
|
||||||
return getData().size();
|
return getData().size();
|
||||||
|
95
dbms/src/Functions/evalMLMethod.cpp
Normal file
95
dbms/src/Functions/evalMLMethod.cpp
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
#include <Functions/IFunction.h>
|
||||||
|
#include <Functions/FunctionFactory.h>
|
||||||
|
#include <Functions/FunctionHelpers.h>
|
||||||
|
#include <DataTypes/DataTypeAggregateFunction.h>
|
||||||
|
#include <Columns/ColumnAggregateFunction.h>
|
||||||
|
#include <Common/typeid_cast.h>
|
||||||
|
|
||||||
|
#include <Columns/ColumnVector.h>
|
||||||
|
#include <Columns/ColumnsNumber.h>
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
#include <Common/PODArray.h>
|
||||||
|
#include <Columns/ColumnArray.h>
|
||||||
|
|
||||||
|
namespace DB
|
||||||
|
{
|
||||||
|
|
||||||
|
namespace ErrorCodes
|
||||||
|
{
|
||||||
|
extern const int ILLEGAL_COLUMN;
|
||||||
|
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/** finalizeAggregation(agg_state) - get the result from the aggregation state.
|
||||||
|
* Takes state of aggregate function. Returns result of aggregation (finalized state).
|
||||||
|
*/
|
||||||
|
class FunctionEvalMLMethod : public IFunction
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
static constexpr auto name = "evalMLMethod";
|
||||||
|
static FunctionPtr create(const Context &)
|
||||||
|
{
|
||||||
|
return std::make_shared<FunctionEvalMLMethod>();
|
||||||
|
}
|
||||||
|
|
||||||
|
String getName() const override
|
||||||
|
{
|
||||||
|
return name;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isVariadic() const override {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
size_t getNumberOfArguments() const override
|
||||||
|
{
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
|
||||||
|
{
|
||||||
|
const DataTypeAggregateFunction * type = checkAndGetDataType<DataTypeAggregateFunction>(arguments[0].get());
|
||||||
|
if (!type)
|
||||||
|
throw Exception("Argument for function " + getName() + " must have type AggregateFunction - state of aggregate function.",
|
||||||
|
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||||
|
|
||||||
|
return type->getReturnType();
|
||||||
|
}
|
||||||
|
|
||||||
|
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override
|
||||||
|
{
|
||||||
|
const ColumnAggregateFunction * column_with_states
|
||||||
|
= typeid_cast<const ColumnAggregateFunction *>(&*block.getByPosition(arguments.at(0)).column);
|
||||||
|
if (!column_with_states)
|
||||||
|
throw Exception("Illegal column " + block.getByPosition(arguments.at(0)).column->getName()
|
||||||
|
+ " of first argument of function "
|
||||||
|
+ getName(),
|
||||||
|
ErrorCodes::ILLEGAL_COLUMN);
|
||||||
|
|
||||||
|
// const ColumnArray * col_array = checkAndGetColumnConstData<ColumnArray>(block.getByPosition(arguments[1]).column.get());
|
||||||
|
// if (!col_array)
|
||||||
|
// throw std::runtime_error("wtf");
|
||||||
|
|
||||||
|
// const IColumn & array_elements = col_array->getData();
|
||||||
|
|
||||||
|
/*
|
||||||
|
std::vector<Float64> predict_features(arguments.size());
|
||||||
|
for (size_t i = 1; i < arguments.size(); ++i)
|
||||||
|
{
|
||||||
|
// predict_features[i] = array_elements[i].get<Float64>();
|
||||||
|
predict_features[i - 1] = typeid_cast<const ColumnConst *>(block.getByPosition(arguments[i]).column.get())->getValue<Float64>();
|
||||||
|
}
|
||||||
|
block.getByPosition(result).column = column_with_states->predictValues(predict_features);
|
||||||
|
*/
|
||||||
|
block.getByPosition(result).column = column_with_states->predictValues(block, arguments);
|
||||||
|
}
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
void registerFunctionEvalMLMethod(FunctionFactory & factory)
|
||||||
|
{
|
||||||
|
factory.registerFunction<FunctionEvalMLMethod>();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -42,6 +42,7 @@ void registerFunctionLowCardinalityKeys(FunctionFactory &);
|
|||||||
void registerFunctionsIn(FunctionFactory &);
|
void registerFunctionsIn(FunctionFactory &);
|
||||||
void registerFunctionJoinGet(FunctionFactory &);
|
void registerFunctionJoinGet(FunctionFactory &);
|
||||||
void registerFunctionFilesystem(FunctionFactory &);
|
void registerFunctionFilesystem(FunctionFactory &);
|
||||||
|
void registerFunctionEvalMLMethod(FunctionFactory &);
|
||||||
|
|
||||||
void registerFunctionsMiscellaneous(FunctionFactory & factory)
|
void registerFunctionsMiscellaneous(FunctionFactory & factory)
|
||||||
{
|
{
|
||||||
@ -84,6 +85,7 @@ void registerFunctionsMiscellaneous(FunctionFactory & factory)
|
|||||||
registerFunctionsIn(factory);
|
registerFunctionsIn(factory);
|
||||||
registerFunctionJoinGet(factory);
|
registerFunctionJoinGet(factory);
|
||||||
registerFunctionFilesystem(factory);
|
registerFunctionFilesystem(factory);
|
||||||
|
registerFunctionEvalMLMethod(factory);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
1
dbms/tests/queries/0_stateless/00900_mytest.reference
Normal file
1
dbms/tests/queries/0_stateless/00900_mytest.reference
Normal file
@ -0,0 +1 @@
|
|||||||
|
66.80107268499746
|
48
dbms/tests/queries/0_stateless/00900_mytest.sql
Normal file
48
dbms/tests/queries/0_stateless/00900_mytest.sql
Normal file
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user