mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-24 08:32:02 +00:00
Rename regression methods.
This commit is contained in:
parent
6fa907c089
commit
54a52853e8
@ -110,8 +110,8 @@ namespace
|
||||
|
||||
void registerAggregateFunctionMLMethod(AggregateFunctionFactory & factory)
|
||||
{
|
||||
factory.registerFunction("LinearRegression", createAggregateFunctionMLMethod<FuncLinearRegression>);
|
||||
factory.registerFunction("LogisticRegression", createAggregateFunctionMLMethod<FuncLogisticRegression>);
|
||||
factory.registerFunction("linearRegression", createAggregateFunctionMLMethod<FuncLinearRegression>);
|
||||
factory.registerFunction("logisticRegression", createAggregateFunctionMLMethod<FuncLogisticRegression>);
|
||||
}
|
||||
|
||||
LinearModelData::LinearModelData(
|
||||
|
@ -321,10 +321,10 @@ private:
|
||||
|
||||
struct NameLinearRegression
|
||||
{
|
||||
static constexpr auto name = "LinearRegression";
|
||||
static constexpr auto name = "linearRegression";
|
||||
};
|
||||
struct NameLogisticRegression
|
||||
{
|
||||
static constexpr auto name = "LogisticRegression";
|
||||
static constexpr auto name = "logisticRegression";
|
||||
};
|
||||
}
|
||||
|
@ -1,8 +1,9 @@
|
||||
#include <AggregateFunctions/AggregateFunctionLeastSqr.h>
|
||||
#include <AggregateFunctions/AggregateFunctionSimpleLinearRegression.h>
|
||||
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
|
||||
#include <Core/TypeListNumber.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
@ -10,7 +11,7 @@ namespace DB
|
||||
namespace
|
||||
{
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionLeastSqr(
|
||||
AggregateFunctionPtr createAggregateFunctionSimpleLinearRegression(
|
||||
const String & name,
|
||||
const DataTypes & arguments,
|
||||
const Array & params
|
||||
@ -20,16 +21,11 @@ AggregateFunctionPtr createAggregateFunctionLeastSqr(
|
||||
assertBinary(name, arguments);
|
||||
|
||||
const IDataType * x_arg = arguments.front().get();
|
||||
|
||||
WhichDataType which_x {
|
||||
x_arg
|
||||
};
|
||||
WhichDataType which_x = x_arg;
|
||||
|
||||
const IDataType * y_arg = arguments.back().get();
|
||||
WhichDataType which_y = y_arg;
|
||||
|
||||
WhichDataType which_y {
|
||||
y_arg
|
||||
};
|
||||
|
||||
#define FOR_LEASTSQR_TYPES_2(M, T) \
|
||||
M(T, UInt8) \
|
||||
@ -55,7 +51,7 @@ AggregateFunctionPtr createAggregateFunctionLeastSqr(
|
||||
FOR_LEASTSQR_TYPES_2(M, Float64)
|
||||
#define DISPATCH(T1, T2) \
|
||||
if (which_x.idx == TypeIndex::T1 && which_y.idx == TypeIndex::T2) \
|
||||
return std::make_shared<AggregateFunctionLeastSqr<T1, T2>>( \
|
||||
return std::make_shared<AggregateFunctionSimpleLinearRegression<T1, T2>>( \
|
||||
arguments, \
|
||||
params \
|
||||
);
|
||||
@ -77,9 +73,9 @@ AggregateFunctionPtr createAggregateFunctionLeastSqr(
|
||||
|
||||
}
|
||||
|
||||
void registerAggregateFunctionLeastSqr(AggregateFunctionFactory & factory)
|
||||
void registerAggregateFunctionSimpleLinearRegression(AggregateFunctionFactory & factory)
|
||||
{
|
||||
factory.registerFunction("leastSqr", createAggregateFunctionLeastSqr);
|
||||
factory.registerFunction("simpleLinearRegression", createAggregateFunctionSimpleLinearRegression);
|
||||
}
|
||||
|
||||
}
|
@ -19,7 +19,7 @@ namespace ErrorCodes
|
||||
}
|
||||
|
||||
template <typename X, typename Y, typename Ret>
|
||||
struct AggregateFunctionLeastSqrData final
|
||||
struct AggregateFunctionSimpleLinearRegressionData final
|
||||
{
|
||||
size_t count = 0;
|
||||
Ret sum_x = 0;
|
||||
@ -36,7 +36,7 @@ struct AggregateFunctionLeastSqrData final
|
||||
sum_xy += x * y;
|
||||
}
|
||||
|
||||
void merge(const AggregateFunctionLeastSqrData & other)
|
||||
void merge(const AggregateFunctionSimpleLinearRegressionData & other)
|
||||
{
|
||||
count += other.count;
|
||||
sum_x += other.sum_x;
|
||||
@ -85,19 +85,19 @@ struct AggregateFunctionLeastSqrData final
|
||||
/// Calculates simple linear regression parameters.
|
||||
/// Result is a tuple (k, b) for y = k * x + b equation, solved by least squares approximation.
|
||||
template <typename X, typename Y, typename Ret = Float64>
|
||||
class AggregateFunctionLeastSqr final : public IAggregateFunctionDataHelper<
|
||||
AggregateFunctionLeastSqrData<X, Y, Ret>,
|
||||
AggregateFunctionLeastSqr<X, Y, Ret>
|
||||
class AggregateFunctionSimpleLinearRegression final : public IAggregateFunctionDataHelper<
|
||||
AggregateFunctionSimpleLinearRegressionData<X, Y, Ret>,
|
||||
AggregateFunctionSimpleLinearRegression<X, Y, Ret>
|
||||
>
|
||||
{
|
||||
public:
|
||||
AggregateFunctionLeastSqr(
|
||||
AggregateFunctionSimpleLinearRegression(
|
||||
const DataTypes & arguments,
|
||||
const Array & params
|
||||
):
|
||||
IAggregateFunctionDataHelper<
|
||||
AggregateFunctionLeastSqrData<X, Y, Ret>,
|
||||
AggregateFunctionLeastSqr<X, Y, Ret>
|
||||
AggregateFunctionSimpleLinearRegressionData<X, Y, Ret>,
|
||||
AggregateFunctionSimpleLinearRegression<X, Y, Ret>
|
||||
> {arguments, params}
|
||||
{
|
||||
// notice: arguments has been checked before
|
||||
@ -105,7 +105,7 @@ public:
|
||||
|
||||
String getName() const override
|
||||
{
|
||||
return "leastSqr";
|
||||
return "simpleLinearRegression";
|
||||
}
|
||||
|
||||
const char * getHeaderFilePath() const override
|
||||
@ -120,12 +120,8 @@ public:
|
||||
Arena *
|
||||
) const override
|
||||
{
|
||||
auto col_x {
|
||||
static_cast<const ColumnVector<X> *>(columns[0])
|
||||
};
|
||||
auto col_y {
|
||||
static_cast<const ColumnVector<Y> *>(columns[1])
|
||||
};
|
||||
auto col_x = static_cast<const ColumnVector<X> *>(columns[0]);
|
||||
auto col_y = static_cast<const ColumnVector<Y> *>(columns[1]);
|
||||
|
||||
X x = col_x->getData()[row_num];
|
||||
Y y = col_y->getData()[row_num];
|
||||
@ -159,12 +155,14 @@ public:
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
DataTypes types {
|
||||
DataTypes types
|
||||
{
|
||||
std::make_shared<DataTypeNumber<Ret>>(),
|
||||
std::make_shared<DataTypeNumber<Ret>>(),
|
||||
};
|
||||
|
||||
Strings names {
|
||||
Strings names
|
||||
{
|
||||
"k",
|
||||
"b",
|
||||
};
|
@ -30,7 +30,7 @@ void registerAggregateFunctionsBitmap(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionsMaxIntersections(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionMLMethod(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionEntropy(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionLeastSqr(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionSimpleLinearRegression(AggregateFunctionFactory &);
|
||||
|
||||
void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory &);
|
||||
void registerAggregateFunctionCombinatorArray(AggregateFunctionCombinatorFactory &);
|
||||
@ -73,7 +73,7 @@ void registerAggregateFunctions()
|
||||
registerAggregateFunctionTSgroupSum(factory);
|
||||
registerAggregateFunctionMLMethod(factory);
|
||||
registerAggregateFunctionEntropy(factory);
|
||||
registerAggregateFunctionLeastSqr(factory);
|
||||
registerAggregateFunctionSimpleLinearRegression(factory);
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -1,9 +1,9 @@
|
||||
select arrayReduce('leastSqr', [1, 2, 3, 4], [100, 110, 120, 130]);
|
||||
select arrayReduce('leastSqr', [1, 2, 3, 4], [100, 110, 120, 131]);
|
||||
select arrayReduce('leastSqr', [-1, -2, -3, -4], [-100, -110, -120, -130]);
|
||||
select arrayReduce('leastSqr', [5, 5.1], [6, 6.1]);
|
||||
select arrayReduce('leastSqr', [0], [0]);
|
||||
select arrayReduce('leastSqr', [3, 4], [3, 3]);
|
||||
select arrayReduce('leastSqr', [3, 3], [3, 4]);
|
||||
select arrayReduce('leastSqr', emptyArrayUInt8(), emptyArrayUInt8());
|
||||
select arrayReduce('simpleLinearRegression', [1, 2, 3, 4], [100, 110, 120, 130]);
|
||||
select arrayReduce('simpleLinearRegression', [1, 2, 3, 4], [100, 110, 120, 131]);
|
||||
select arrayReduce('simpleLinearRegression', [-1, -2, -3, -4], [-100, -110, -120, -130]);
|
||||
select arrayReduce('simpleLinearRegression', [5, 5.1], [6, 6.1]);
|
||||
select arrayReduce('simpleLinearRegression', [0], [0]);
|
||||
select arrayReduce('simpleLinearRegression', [3, 4], [3, 3]);
|
||||
select arrayReduce('simpleLinearRegression', [3, 3], [3, 4]);
|
||||
select arrayReduce('simpleLinearRegression', emptyArrayUInt8(), emptyArrayUInt8());
|
||||
|
||||
|
File diff suppressed because one or more lines are too long
@ -11,7 +11,7 @@ CREATE TABLE IF NOT EXISTS test.defaults
|
||||
insert into test.defaults values (-3.273, -1.452, 4.267, 20.0, 40.0), (0.121, -0.615, 4.290, 20.0, 40.0);
|
||||
|
||||
DROP TABLE IF EXISTS test.model;
|
||||
create table test.model engine = Memory as select LinearRegressionState(0.1, 0.0, 2, 'SGD')(target, param1, param2) as state from test.defaults;
|
||||
create table test.model engine = Memory as select linearRegressionState(0.1, 0.0, 2, 'SGD')(target, param1, param2) as state from test.defaults;
|
||||
|
||||
select ans < -61.374 and ans > -61.375 from
|
||||
(with (select state from remote('127.0.0.1', test.model)) as model select evalMLMethod(model, predict1, predict2) as ans from remote('127.0.0.1', test.defaults));
|
||||
|
@ -10,7 +10,7 @@ CREATE TABLE IF NOT EXISTS test.defaults
|
||||
) ENGINE = Memory;
|
||||
insert into test.defaults values (1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2),(1,2,1,-1,-2),(-1,-2,-1,1,2)
|
||||
DROP TABLE IF EXISTS test.model;
|
||||
create table test.model engine = Memory as select LogisticRegressionState(0.1, 0.0, 1.0, 'SGD')(target, param1, param2) as state from test.defaults;
|
||||
create table test.model engine = Memory as select logisticRegressionState(0.1, 0.0, 1.0, 'SGD')(target, param1, param2) as state from test.defaults;
|
||||
|
||||
select ans < 1.1 and ans > 0.9 from
|
||||
(with (select state from test.model) as model select evalMLMethod(model, predict1, predict2) as ans from test.defaults limit 2);
|
||||
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user