Remane regression methods.

This commit is contained in:
Nikolai Kochetov 2019-05-23 14:51:25 +03:00
parent 6fa907c089
commit 54a52853e8
11 changed files with 42 additions and 48 deletions

View File

@ -110,8 +110,8 @@ namespace
void registerAggregateFunctionMLMethod(AggregateFunctionFactory & factory)
{
factory.registerFunction("LinearRegression", createAggregateFunctionMLMethod<FuncLinearRegression>);
factory.registerFunction("LogisticRegression", createAggregateFunctionMLMethod<FuncLogisticRegression>);
factory.registerFunction("linearRegression", createAggregateFunctionMLMethod<FuncLinearRegression>);
factory.registerFunction("logisticRegression", createAggregateFunctionMLMethod<FuncLogisticRegression>);
}
LinearModelData::LinearModelData(

View File

@ -321,10 +321,10 @@ private:
struct NameLinearRegression
{
static constexpr auto name = "LinearRegression";
static constexpr auto name = "linearRegression";
};
struct NameLogisticRegression
{
static constexpr auto name = "LogisticRegression";
static constexpr auto name = "logisticRegression";
};
}

View File

@ -1,8 +1,9 @@
#include <AggregateFunctions/AggregateFunctionLeastSqr.h>
#include <AggregateFunctions/AggregateFunctionSimpleLinearRegression.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <Core/TypeListNumber.h>
namespace DB
{
@ -10,7 +11,7 @@ namespace DB
namespace
{
AggregateFunctionPtr createAggregateFunctionLeastSqr(
AggregateFunctionPtr createAggregateFunctionSimpleLinearRegression(
const String & name,
const DataTypes & arguments,
const Array & params
@ -20,16 +21,11 @@ AggregateFunctionPtr createAggregateFunctionLeastSqr(
assertBinary(name, arguments);
const IDataType * x_arg = arguments.front().get();
WhichDataType which_x {
x_arg
};
WhichDataType which_x = x_arg;
const IDataType * y_arg = arguments.back().get();
WhichDataType which_y = y_arg;
WhichDataType which_y {
y_arg
};
#define FOR_LEASTSQR_TYPES_2(M, T) \
M(T, UInt8) \
@ -55,7 +51,7 @@ AggregateFunctionPtr createAggregateFunctionLeastSqr(
FOR_LEASTSQR_TYPES_2(M, Float64)
#define DISPATCH(T1, T2) \
if (which_x.idx == TypeIndex::T1 && which_y.idx == TypeIndex::T2) \
return std::make_shared<AggregateFunctionLeastSqr<T1, T2>>( \
return std::make_shared<AggregateFunctionSimpleLinearRegression<T1, T2>>( \
arguments, \
params \
);
@ -77,9 +73,9 @@ AggregateFunctionPtr createAggregateFunctionLeastSqr(
}
void registerAggregateFunctionLeastSqr(AggregateFunctionFactory & factory)
void registerAggregateFunctionSimpleLinearRegression(AggregateFunctionFactory & factory)
{
factory.registerFunction("leastSqr", createAggregateFunctionLeastSqr);
factory.registerFunction("simpleLinearRegression", createAggregateFunctionSimpleLinearRegression);
}
}

View File

@ -19,7 +19,7 @@ namespace ErrorCodes
}
template <typename X, typename Y, typename Ret>
struct AggregateFunctionLeastSqrData final
struct AggregateFunctionSimpleLinearRegressionData final
{
size_t count = 0;
Ret sum_x = 0;
@ -36,7 +36,7 @@ struct AggregateFunctionLeastSqrData final
sum_xy += x * y;
}
void merge(const AggregateFunctionLeastSqrData & other)
void merge(const AggregateFunctionSimpleLinearRegressionData & other)
{
count += other.count;
sum_x += other.sum_x;
@ -85,19 +85,19 @@ struct AggregateFunctionLeastSqrData final
/// Calculates simple linear regression parameters.
/// Result is a tuple (k, b) for y = k * x + b equation, solved by least squares approximation.
template <typename X, typename Y, typename Ret = Float64>
class AggregateFunctionLeastSqr final : public IAggregateFunctionDataHelper<
AggregateFunctionLeastSqrData<X, Y, Ret>,
AggregateFunctionLeastSqr<X, Y, Ret>
class AggregateFunctionSimpleLinearRegression final : public IAggregateFunctionDataHelper<
AggregateFunctionSimpleLinearRegressionData<X, Y, Ret>,
AggregateFunctionSimpleLinearRegression<X, Y, Ret>
>
{
public:
AggregateFunctionLeastSqr(
AggregateFunctionSimpleLinearRegression(
const DataTypes & arguments,
const Array & params
):
IAggregateFunctionDataHelper<
AggregateFunctionLeastSqrData<X, Y, Ret>,
AggregateFunctionLeastSqr<X, Y, Ret>
AggregateFunctionSimpleLinearRegressionData<X, Y, Ret>,
AggregateFunctionSimpleLinearRegression<X, Y, Ret>
> {arguments, params}
{
// notice: arguments has been checked before
@ -105,7 +105,7 @@ public:
String getName() const override
{
return "leastSqr";
return "simpleLinearRegression";
}
const char * getHeaderFilePath() const override
@ -120,12 +120,8 @@ public:
Arena *
) const override
{
auto col_x {
static_cast<const ColumnVector<X> *>(columns[0])
};
auto col_y {
static_cast<const ColumnVector<Y> *>(columns[1])
};
auto col_x = static_cast<const ColumnVector<X> *>(columns[0]);
auto col_y = static_cast<const ColumnVector<Y> *>(columns[1]);
X x = col_x->getData()[row_num];
Y y = col_y->getData()[row_num];
@ -159,12 +155,14 @@ public:
DataTypePtr getReturnType() const override
{
DataTypes types {
DataTypes types
{
std::make_shared<DataTypeNumber<Ret>>(),
std::make_shared<DataTypeNumber<Ret>>(),
};
Strings names {
Strings names
{
"k",
"b",
};

View File

@ -30,7 +30,7 @@ void registerAggregateFunctionsBitmap(AggregateFunctionFactory &);
void registerAggregateFunctionsMaxIntersections(AggregateFunctionFactory &);
void registerAggregateFunctionMLMethod(AggregateFunctionFactory &);
void registerAggregateFunctionEntropy(AggregateFunctionFactory &);
void registerAggregateFunctionLeastSqr(AggregateFunctionFactory &);
void registerAggregateFunctionSimpleLinearRegression(AggregateFunctionFactory &);
void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory &);
void registerAggregateFunctionCombinatorArray(AggregateFunctionCombinatorFactory &);
@ -73,7 +73,7 @@ void registerAggregateFunctions()
registerAggregateFunctionTSgroupSum(factory);
registerAggregateFunctionMLMethod(factory);
registerAggregateFunctionEntropy(factory);
registerAggregateFunctionLeastSqr(factory);
registerAggregateFunctionSimpleLinearRegression(factory);
}
{

View File

@ -1,9 +1,9 @@
select arrayReduce('leastSqr', [1, 2, 3, 4], [100, 110, 120, 130]);
select arrayReduce('leastSqr', [1, 2, 3, 4], [100, 110, 120, 131]);
select arrayReduce('leastSqr', [-1, -2, -3, -4], [-100, -110, -120, -130]);
select arrayReduce('leastSqr', [5, 5.1], [6, 6.1]);
select arrayReduce('leastSqr', [0], [0]);
select arrayReduce('leastSqr', [3, 4], [3, 3]);
select arrayReduce('leastSqr', [3, 3], [3, 4]);
select arrayReduce('leastSqr', emptyArrayUInt8(), emptyArrayUInt8());
select arrayReduce('simpleLinearRegression', [1, 2, 3, 4], [100, 110, 120, 130]);
select arrayReduce('simpleLinearRegression', [1, 2, 3, 4], [100, 110, 120, 131]);
select arrayReduce('simpleLinearRegression', [-1, -2, -3, -4], [-100, -110, -120, -130]);
select arrayReduce('simpleLinearRegression', [5, 5.1], [6, 6.1]);
select arrayReduce('simpleLinearRegression', [0], [0]);
select arrayReduce('simpleLinearRegression', [3, 4], [3, 3]);
select arrayReduce('simpleLinearRegression', [3, 3], [3, 4]);
select arrayReduce('simpleLinearRegression', emptyArrayUInt8(), emptyArrayUInt8());

File diff suppressed because one or more lines are too long

View File

@ -11,7 +11,7 @@ CREATE TABLE IF NOT EXISTS test.defaults
insert into test.defaults values (-3.273, -1.452, 4.267, 20.0, 40.0), (0.121, -0.615, 4.290, 20.0, 40.0);
DROP TABLE IF EXISTS test.model;
create table test.model engine = Memory as select LinearRegressionState(0.1, 0.0, 2, 'SGD')(target, param1, param2) as state from test.defaults;
create table test.model engine = Memory as select linearRegressionState(0.1, 0.0, 2, 'SGD')(target, param1, param2) as state from test.defaults;
select ans < -61.374 and ans > -61.375 from
(with (select state from remote('127.0.0.1', test.model)) as model select evalMLMethod(model, predict1, predict2) as ans from remote('127.0.0.1', test.defaults));

View File

@ -10,7 +10,7 @@ CREATE TABLE IF NOT EXISTS test.defaults
) ENGINE = Memory;
insert into test.defaults values (1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2)
DROP TABLE IF EXISTS test.model;
create table test.model engine = Memory as select LogisticRegressionState(0.1, 0.0, 1.0, 'SGD')(target, param1, param2) as state from test.defaults;
create table test.model engine = Memory as select logisticRegressionState(0.1, 0.0, 1.0, 'SGD')(target, param1, param2) as state from test.defaults;
select ans < 1.1 and ans > 0.9 from
(with (select state from test.model) as model select evalMLMethod(model, predict1, predict2) as ans from test.defaults limit 2);

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long