linear regression

This commit is contained in:
Alexander Kozhikhov 2019-04-09 01:40:37 +03:00
parent a9ae6d0681
commit c93aae6741
16 changed files with 62 additions and 465 deletions

View File

@ -459,10 +459,10 @@ private:
template <
/// Implemented Machine Learning method
typename Data,
/// Name of the method
typename Name
/// Implemented Machine Learning method
typename Data,
/// Name of the method
typename Name
>
class AggregateFunctionMLMethod final : public IAggregateFunctionDataHelper<Data, AggregateFunctionMLMethod<Data, Name>>
{

View File

@ -152,7 +152,7 @@ public:
/// Passes the argument types and the parameters through, unchanged, to the IAggregateFunctionHelper<Derived> base.
IAggregateFunctionDataHelper(const DataTypes & argument_types_, const Array & parameters_)
    : IAggregateFunctionHelper<Derived>(argument_types_, parameters_)
{
}
/// Default-constructs a Data object in the memory that `place` points to (placement new).
/// `override` already implies the function is virtual, so the keyword `virtual` is not repeated.
void create(AggregateDataPtr place) const override
{
    new (place) Data;
}

View File

@ -84,16 +84,11 @@ MutableColumnPtr ColumnAggregateFunction::convertToValues() const
* AggregateFunction(quantileTiming(0.5), UInt64)
* into UInt16 - already finished result of `quantileTiming`.
*/
// if (const AggregateFunctionState * function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
// {
// auto res = createView();
// res->set(function_state->getNestedFunction());
// res->data.assign(data.begin(), data.end());
// return res;
// }
//
// MutableColumnPtr res = func->getReturnType()->createColumn();
// res->reserve(data.size());
/** The convertion() helper is used by both convertToValues and predictValues:
* it implements the part that the two functions have in common.
*/
MutableColumnPtr res;
if (convertion(&res))
{
@ -114,8 +109,6 @@ MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const Col
return res;
}
/// In my tests this function is called twice with something that has data.size() == 0; however, in that case the following lines effectively do nothing.
auto ML_function_Linear = typeid_cast<AggregateFunctionMLMethod<LinearModelData, NameLinearRegression> *>(func.get());
auto ML_function_Logistic = typeid_cast<AggregateFunctionMLMethod<LinearModelData, NameLogisticRegression> *>(func.get());
if (ML_function_Linear)

View File

@ -25,84 +25,59 @@ namespace DB
/** finalizeAggregation(agg_state) - get the result from the aggregation state.
* Takes state of aggregate function. Returns result of aggregation (finalized state).
*/
class FunctionEvalMLMethod : public IFunction
class FunctionEvalMLMethod : public IFunction
{
public:
static constexpr auto name = "evalMLMethod";
static FunctionPtr create(const Context &)
{
public:
static constexpr auto name = "evalMLMethod";
static FunctionPtr create(const Context &)
{
return std::make_shared<FunctionEvalMLMethod>();
}
String getName() const override
{
return name;
}
bool isVariadic() const override {
return true;
}
size_t getNumberOfArguments() const override
{
return 0;
}
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
const DataTypeAggregateFunction * type = checkAndGetDataType<DataTypeAggregateFunction>(arguments[0].get());
if (!type)
throw Exception("Argument for function " + getName() + " must have type AggregateFunction - state of aggregate function.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return type->getReturnType();
}
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override
{
// const ColumnAggregateFunction * column_with_states =
// typeid_cast<const ColumnAggregateFunction *>(static_cast<const ColumnConst *>(&*block.getByPosition(arguments.at(0)).column));
// завести МЛ_аггрункции как отдельный класс, чтобы тут сразу это проверять, а не делать это внутри predictValues()
// const ColumnAggregateFunction * column_with_states
// = typeid_cast<const ColumnAggregateFunction *>(&*block.getByPosition(arguments.at(0)).column);
const ColumnConst * column_with_states
= typeid_cast<const ColumnConst *>(&*block.getByPosition(arguments.at(0)).column);
if (!column_with_states)
throw Exception("Illegal column " + block.getByPosition(arguments.at(0)).column->getName()
+ " of first argument of function "
+ getName(),
ErrorCodes::ILLEGAL_COLUMN);
// const ColumnArray * col_array = checkAndGetColumnConstData<ColumnArray>(block.getByPosition(arguments[1]).column.get());
// if (!col_array)
// throw std::runtime_error("wtf");
// const IColumn & array_elements = col_array->getData();
/*
std::vector<Float64> predict_features(arguments.size());
for (size_t i = 1; i < arguments.size(); ++i)
{
// predict_features[i] = array_elements[i].get<Float64>();
predict_features[i - 1] = typeid_cast<const ColumnConst *>(block.getByPosition(arguments[i]).column.get())->getValue<Float64>();
}
block.getByPosition(result).column = column_with_states->predictValues(predict_features);
*/
block.getByPosition(result).column =
typeid_cast<const ColumnAggregateFunction *>(&*column_with_states->getDataColumnPtr())->predictValues(block, arguments);
}
};
void registerFunctionEvalMLMethod(FunctionFactory & factory)
{
factory.registerFunction<FunctionEvalMLMethod>();
return std::make_shared<FunctionEvalMLMethod>();
}
String getName() const override
{
return name;
}
bool isVariadic() const override {
return true;
}
size_t getNumberOfArguments() const override
{
return 0;
}
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
const DataTypeAggregateFunction * type = checkAndGetDataType<DataTypeAggregateFunction>(arguments[0].get());
if (!type)
throw Exception("Argument for function " + getName() + " must have type AggregateFunction - state of aggregate function.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return type->getReturnType();
}
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override
{
const ColumnConst * column_with_states
= typeid_cast<const ColumnConst *>(&*block.getByPosition(arguments.at(0)).column);
if (!column_with_states)
throw Exception("Illegal column " + block.getByPosition(arguments.at(0)).column->getName()
+ " of first argument of function "
+ getName(),
ErrorCodes::ILLEGAL_COLUMN);
block.getByPosition(result).column =
typeid_cast<const ColumnAggregateFunction *>(&*column_with_states->getDataColumnPtr())->predictValues(block, arguments);
}
};
void registerFunctionEvalMLMethod(FunctionFactory & factory)
{
factory.registerFunction<FunctionEvalMLMethod>();
}
}

View File

@ -1,370 +0,0 @@
0.72
0.72
0.72
0.72
0.72
0.72
0.72
0.72
0.72
0.72
0.72
0.72
0.72
0.89
0.89
0.89
0.89
0.89
0.89
0.89
0.89
0.89
0.89
0.89
0.89
0.95
0.95
0.95
0.95
0.95
0.95
0.95
0.95
0.95
0.95
0.95
0.95
0.79
0.79
0.79
0.79
0.79
0.79
0.79
0.79
0.79
0.79
0.79
0.79
0.79
0.39
0.39
0.39
0.39
0.39
0.39
0.39
0.39
0.39
0.39
0.39
0.39
0.38
0.38
0.38
0.38
0.38
0.38
0.38
0.38
0.38
0.38
0.38
0.38
0.34
0.34
0.34
0.34
0.34
0.34
0.34
0.34
0.34
0.34
0.34
0.34
0.34
0.47
0.47
0.47
0.47
0.47
0.47
0.47
0.47
0.47
0.47
0.47
0.47
0.56
0.56
0.56
0.56
0.56
0.56
0.56
0.56
0.56
0.56
0.56
0.56
0.71
0.71
0.71
0.71
0.71
0.71
0.71
0.71
0.71
0.71
0.71
0.71
0.71
0.78
0.78
0.78
0.78
0.78
0.78
0.78
0.78
0.78
0.78
0.78
0.78
0.73
0.73
0.73
0.73
0.73
0.73
0.73
0.73
0.73
0.73
0.73
0.73
0.82
0.82
0.82
0.82
0.82
0.82
0.82
0.82
0.82
0.82
0.82
0.82
0.82
0.62
0.62
0.62
0.62
0.62
0.62
0.62
0.62
0.62
0.62
0.62
0.62
0.96
0.96
0.96
0.96
0.96
0.96
0.96
0.96
0.96
0.96
0.96
0.96
0.96
0.96
0.96
0.96
0.96
0.96
0.96
0.96
0.96
0.96
0.96
0.96
0.96
0.46
0.46
0.46
0.46
0.46
0.46
0.46
0.46
0.46
0.46
0.46
0.46
0.53
0.53
0.53
0.53
0.53
0.53
0.53
0.53
0.53
0.53
0.53
0.53
0.49
0.49
0.49
0.49
0.49
0.49
0.49
0.49
0.49
0.49
0.49
0.49
0.49
0.76
0.76
0.76
0.76
0.76
0.76
0.76
0.76
0.76
0.76
0.76
0.76
0.64
0.64
0.64
0.64
0.64
0.64
0.64
0.64
0.64
0.64
0.64
0.64
0.71
0.71
0.71
0.71
0.71
0.71
0.71
0.71
0.71
0.71
0.71
0.71
0.71
0.84
0.84
0.84
0.84
0.84
0.84
0.84
0.84
0.84
0.84
0.84
0.84
0.77
0.77
0.77
0.77
0.77
0.77
0.77
0.77
0.77
0.77
0.77
0.77
0.89
0.89
0.89
0.89
0.89
0.89
0.89
0.89
0.89
0.89
0.89
0.89
0.89
0.82
0.82
0.82
0.82
0.82
0.82
0.82
0.82
0.82
0.82
0.82
0.82
0.84
0.84
0.84
0.84
0.84
0.84
0.84
0.84
0.84
0.84
0.84
0.84
0.91
0.91
0.91
0.91
0.91
0.91
0.91
0.91
0.91
0.91
0.91
0.91
0.91
0.67
0.67
0.67
0.67
0.67
0.67
0.67
0.67
0.67
0.67
0.67
0.67
0.95
0.95
0.95
0.95
0.95
0.95
0.95
0.95
0.95
0.95
0.95
0.95