2019-05-23 11:51:25 +00:00
|
|
|
#include <AggregateFunctions/AggregateFunctionSimpleLinearRegression.h>
|
2019-03-13 07:22:57 +00:00
|
|
|
|
|
|
|
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
|
|
|
#include <AggregateFunctions/FactoryHelpers.h>
|
|
|
|
|
2019-05-23 11:51:25 +00:00
|
|
|
#include <Core/TypeListNumber.h>
|
2019-12-15 06:34:43 +00:00
|
|
|
#include "registerAggregateFunctions.h"
|
2019-03-13 07:22:57 +00:00
|
|
|
|
|
|
|
namespace DB
|
|
|
|
{
|
2020-02-25 18:10:48 +00:00
|
|
|
/// Error codes referenced by this file. They are only declared here (`extern`);
/// the definitions live in another translation unit.
namespace ErrorCodes { extern const int ILLEGAL_TYPE_OF_ARGUMENT; }
|
2019-03-13 07:22:57 +00:00
|
|
|
|
|
|
|
namespace
|
|
|
|
{
|
|
|
|
|
2019-05-23 11:51:25 +00:00
|
|
|
/// Creator for the `simpleLinearRegression` aggregate function.
///
/// @param name       SQL-visible function name (used by the assertions and in the error message).
/// @param arguments  argument types; validated by assertBinary, then the first is treated
///                   as x and the last as y.
/// @param params     function parameters; validated by assertNoParameters (none expected).
/// @return           AggregateFunctionSimpleLinearRegression<X, Y> instantiated for the native
///                   numeric types of x and y (UInt8..UInt64, Int8..Int64, Float32, Float64).
/// @throws Exception with ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT if either argument type is
///         not one of those ten numeric types.
AggregateFunctionPtr createAggregateFunctionSimpleLinearRegression(
    const String & name,
    const DataTypes & arguments,
    const Array & params)
{
    assertNoParameters(name, params);
    assertBinary(name, arguments);

    const IDataType * x_type = arguments.front().get();
    const IDataType * y_type = arguments.back().get();

    const WhichDataType x_which(x_type);
    const WhichDataType y_which(y_type);

    /// Dispatch in two levels: pick the x type first, then scan the y types.
    /// The y-list macro is distinct from the x-level macro because a macro cannot be
    /// re-expanded inside its own expansion.
#define SLR_Y_TYPES(M, TX) \
    M(TX, UInt8) M(TX, UInt16) M(TX, UInt32) M(TX, UInt64) \
    M(TX, Int8) M(TX, Int16) M(TX, Int32) M(TX, Int64) \
    M(TX, Float32) M(TX, Float64)

#define SLR_MAKE_FOR_Y(TX, TY) \
    if (y_which.idx == TypeIndex::TY) \
        return std::make_shared<AggregateFunctionSimpleLinearRegression<TX, TY>>(/* NOLINT */ arguments, params);

#define SLR_MAKE_FOR_X(TX) \
    if (x_which.idx == TypeIndex::TX) \
    { \
        SLR_Y_TYPES(SLR_MAKE_FOR_Y, TX) \
    }

    SLR_MAKE_FOR_X(UInt8)
    SLR_MAKE_FOR_X(UInt16)
    SLR_MAKE_FOR_X(UInt32)
    SLR_MAKE_FOR_X(UInt64)
    SLR_MAKE_FOR_X(Int8)
    SLR_MAKE_FOR_X(Int16)
    SLR_MAKE_FOR_X(Int32)
    SLR_MAKE_FOR_X(Int64)
    SLR_MAKE_FOR_X(Float32)
    SLR_MAKE_FOR_X(Float64)

#undef SLR_MAKE_FOR_X
#undef SLR_MAKE_FOR_Y
#undef SLR_Y_TYPES

    /// Reaching this point means no (x, y) pair matched the supported numeric types.
    throw Exception(
        "Illegal types ("
            + x_type->getName() + ", " + y_type->getName()
            + ") of arguments of aggregate function " + name
            + ", must be Native Ints, Native UInts or Floats",
        ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
2019-05-23 11:51:25 +00:00
|
|
|
/// Registers the creator defined above with the factory under the SQL name
/// `simpleLinearRegression`.
void registerAggregateFunctionSimpleLinearRegression(AggregateFunctionFactory & factory)
{
    factory.registerFunction("simpleLinearRegression", &createAggregateFunctionSimpleLinearRegression);
}
|
|
|
|
|
|
|
|
}
|