mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-09-20 16:50:48 +00:00
linear regression
This commit is contained in:
parent
a9ae6d0681
commit
c93aae6741
@ -459,10 +459,10 @@ private:
|
||||
|
||||
|
||||
template <
|
||||
/// Implemented Machine Learning method
|
||||
typename Data,
|
||||
/// Name of the method
|
||||
typename Name
|
||||
/// Implemented Machine Learning method
|
||||
typename Data,
|
||||
/// Name of the method
|
||||
typename Name
|
||||
>
|
||||
class AggregateFunctionMLMethod final : public IAggregateFunctionDataHelper<Data, AggregateFunctionMLMethod<Data, Name>>
|
||||
{
|
||||
|
@ -152,7 +152,7 @@ public:
|
||||
IAggregateFunctionDataHelper(const DataTypes & argument_types_, const Array & parameters_)
|
||||
: IAggregateFunctionHelper<Derived>(argument_types_, parameters_) {}
|
||||
|
||||
virtual void create(AggregateDataPtr place) const override
|
||||
void create(AggregateDataPtr place) const override
|
||||
{
|
||||
new (place) Data;
|
||||
}
|
||||
|
@ -84,16 +84,11 @@ MutableColumnPtr ColumnAggregateFunction::convertToValues() const
|
||||
* AggregateFunction(quantileTiming(0.5), UInt64)
|
||||
* into UInt16 - already finished result of `quantileTiming`.
|
||||
*/
|
||||
// if (const AggregateFunctionState * function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
|
||||
// {
|
||||
// auto res = createView();
|
||||
// res->set(function_state->getNestedFunction());
|
||||
// res->data.assign(data.begin(), data.end());
|
||||
// return res;
|
||||
// }
|
||||
//
|
||||
// MutableColumnPtr res = func->getReturnType()->createColumn();
|
||||
// res->reserve(data.size());
|
||||
|
||||
/** Convertion function is used in convertToValues and predictValues
|
||||
* in the similar part of both functions
|
||||
*/
|
||||
|
||||
MutableColumnPtr res;
|
||||
if (convertion(&res))
|
||||
{
|
||||
@ -114,8 +109,6 @@ MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const Col
|
||||
return res;
|
||||
}
|
||||
|
||||
/// На моих тестах дважды в эту функцию приходит нечтно, имеющее data.size() == 0 однако оно по сути ничего не делает в следующих строках
|
||||
|
||||
auto ML_function_Linear = typeid_cast<AggregateFunctionMLMethod<LinearModelData, NameLinearRegression> *>(func.get());
|
||||
auto ML_function_Logistic = typeid_cast<AggregateFunctionMLMethod<LinearModelData, NameLogisticRegression> *>(func.get());
|
||||
if (ML_function_Linear)
|
||||
|
@ -25,84 +25,59 @@ namespace DB
|
||||
/** finalizeAggregation(agg_state) - get the result from the aggregation state.
|
||||
* Takes state of aggregate function. Returns result of aggregation (finalized state).
|
||||
*/
|
||||
class FunctionEvalMLMethod : public IFunction
|
||||
class FunctionEvalMLMethod : public IFunction
|
||||
{
|
||||
public:
|
||||
static constexpr auto name = "evalMLMethod";
|
||||
static FunctionPtr create(const Context &)
|
||||
{
|
||||
public:
|
||||
static constexpr auto name = "evalMLMethod";
|
||||
static FunctionPtr create(const Context &)
|
||||
{
|
||||
return std::make_shared<FunctionEvalMLMethod>();
|
||||
}
|
||||
|
||||
String getName() const override
|
||||
{
|
||||
return name;
|
||||
}
|
||||
|
||||
bool isVariadic() const override {
|
||||
return true;
|
||||
}
|
||||
size_t getNumberOfArguments() const override
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
|
||||
{
|
||||
const DataTypeAggregateFunction * type = checkAndGetDataType<DataTypeAggregateFunction>(arguments[0].get());
|
||||
if (!type)
|
||||
throw Exception("Argument for function " + getName() + " must have type AggregateFunction - state of aggregate function.",
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
return type->getReturnType();
|
||||
}
|
||||
|
||||
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override
|
||||
{
|
||||
// const ColumnAggregateFunction * column_with_states =
|
||||
// typeid_cast<const ColumnAggregateFunction *>(static_cast<const ColumnConst *>(&*block.getByPosition(arguments.at(0)).column));
|
||||
|
||||
|
||||
// завести МЛ_аггр_функции как отдельный класс, чтобы тут сразу это проверять, а не делать это внутри predictValues()
|
||||
|
||||
// const ColumnAggregateFunction * column_with_states
|
||||
// = typeid_cast<const ColumnAggregateFunction *>(&*block.getByPosition(arguments.at(0)).column);
|
||||
|
||||
|
||||
const ColumnConst * column_with_states
|
||||
= typeid_cast<const ColumnConst *>(&*block.getByPosition(arguments.at(0)).column);
|
||||
|
||||
|
||||
if (!column_with_states)
|
||||
throw Exception("Illegal column " + block.getByPosition(arguments.at(0)).column->getName()
|
||||
+ " of first argument of function "
|
||||
+ getName(),
|
||||
ErrorCodes::ILLEGAL_COLUMN);
|
||||
|
||||
// const ColumnArray * col_array = checkAndGetColumnConstData<ColumnArray>(block.getByPosition(arguments[1]).column.get());
|
||||
// if (!col_array)
|
||||
// throw std::runtime_error("wtf");
|
||||
|
||||
// const IColumn & array_elements = col_array->getData();
|
||||
|
||||
/*
|
||||
std::vector<Float64> predict_features(arguments.size());
|
||||
for (size_t i = 1; i < arguments.size(); ++i)
|
||||
{
|
||||
// predict_features[i] = array_elements[i].get<Float64>();
|
||||
predict_features[i - 1] = typeid_cast<const ColumnConst *>(block.getByPosition(arguments[i]).column.get())->getValue<Float64>();
|
||||
}
|
||||
block.getByPosition(result).column = column_with_states->predictValues(predict_features);
|
||||
*/
|
||||
block.getByPosition(result).column =
|
||||
typeid_cast<const ColumnAggregateFunction *>(&*column_with_states->getDataColumnPtr())->predictValues(block, arguments);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
void registerFunctionEvalMLMethod(FunctionFactory & factory)
|
||||
{
|
||||
factory.registerFunction<FunctionEvalMLMethod>();
|
||||
return std::make_shared<FunctionEvalMLMethod>();
|
||||
}
|
||||
|
||||
String getName() const override
|
||||
{
|
||||
return name;
|
||||
}
|
||||
|
||||
bool isVariadic() const override {
|
||||
return true;
|
||||
}
|
||||
size_t getNumberOfArguments() const override
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
|
||||
{
|
||||
const DataTypeAggregateFunction * type = checkAndGetDataType<DataTypeAggregateFunction>(arguments[0].get());
|
||||
if (!type)
|
||||
throw Exception("Argument for function " + getName() + " must have type AggregateFunction - state of aggregate function.",
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
return type->getReturnType();
|
||||
}
|
||||
|
||||
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override
|
||||
{
|
||||
const ColumnConst * column_with_states
|
||||
= typeid_cast<const ColumnConst *>(&*block.getByPosition(arguments.at(0)).column);
|
||||
|
||||
|
||||
if (!column_with_states)
|
||||
throw Exception("Illegal column " + block.getByPosition(arguments.at(0)).column->getName()
|
||||
+ " of first argument of function "
|
||||
+ getName(),
|
||||
ErrorCodes::ILLEGAL_COLUMN);
|
||||
|
||||
block.getByPosition(result).column =
|
||||
typeid_cast<const ColumnAggregateFunction *>(&*column_with_states->getDataColumnPtr())->predictValues(block, arguments);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
void registerFunctionEvalMLMethod(FunctionFactory & factory)
|
||||
{
|
||||
factory.registerFunction<FunctionEvalMLMethod>();
|
||||
}
|
||||
|
||||
}
|
@ -1,370 +0,0 @@
|
||||
0.72
|
||||
0.72
|
||||
0.72
|
||||
0.72
|
||||
0.72
|
||||
0.72
|
||||
0.72
|
||||
0.72
|
||||
0.72
|
||||
0.72
|
||||
0.72
|
||||
0.72
|
||||
0.72
|
||||
0.89
|
||||
0.89
|
||||
0.89
|
||||
0.89
|
||||
0.89
|
||||
0.89
|
||||
0.89
|
||||
0.89
|
||||
0.89
|
||||
0.89
|
||||
0.89
|
||||
0.89
|
||||
0.95
|
||||
0.95
|
||||
0.95
|
||||
0.95
|
||||
0.95
|
||||
0.95
|
||||
0.95
|
||||
0.95
|
||||
0.95
|
||||
0.95
|
||||
0.95
|
||||
0.95
|
||||
0.79
|
||||
0.79
|
||||
0.79
|
||||
0.79
|
||||
0.79
|
||||
0.79
|
||||
0.79
|
||||
0.79
|
||||
0.79
|
||||
0.79
|
||||
0.79
|
||||
0.79
|
||||
0.79
|
||||
0.39
|
||||
0.39
|
||||
0.39
|
||||
0.39
|
||||
0.39
|
||||
0.39
|
||||
0.39
|
||||
0.39
|
||||
0.39
|
||||
0.39
|
||||
0.39
|
||||
0.39
|
||||
0.38
|
||||
0.38
|
||||
0.38
|
||||
0.38
|
||||
0.38
|
||||
0.38
|
||||
0.38
|
||||
0.38
|
||||
0.38
|
||||
0.38
|
||||
0.38
|
||||
0.38
|
||||
0.34
|
||||
0.34
|
||||
0.34
|
||||
0.34
|
||||
0.34
|
||||
0.34
|
||||
0.34
|
||||
0.34
|
||||
0.34
|
||||
0.34
|
||||
0.34
|
||||
0.34
|
||||
0.34
|
||||
0.47
|
||||
0.47
|
||||
0.47
|
||||
0.47
|
||||
0.47
|
||||
0.47
|
||||
0.47
|
||||
0.47
|
||||
0.47
|
||||
0.47
|
||||
0.47
|
||||
0.47
|
||||
0.56
|
||||
0.56
|
||||
0.56
|
||||
0.56
|
||||
0.56
|
||||
0.56
|
||||
0.56
|
||||
0.56
|
||||
0.56
|
||||
0.56
|
||||
0.56
|
||||
0.56
|
||||
0.71
|
||||
0.71
|
||||
0.71
|
||||
0.71
|
||||
0.71
|
||||
0.71
|
||||
0.71
|
||||
0.71
|
||||
0.71
|
||||
0.71
|
||||
0.71
|
||||
0.71
|
||||
0.71
|
||||
0.78
|
||||
0.78
|
||||
0.78
|
||||
0.78
|
||||
0.78
|
||||
0.78
|
||||
0.78
|
||||
0.78
|
||||
0.78
|
||||
0.78
|
||||
0.78
|
||||
0.78
|
||||
0.73
|
||||
0.73
|
||||
0.73
|
||||
0.73
|
||||
0.73
|
||||
0.73
|
||||
0.73
|
||||
0.73
|
||||
0.73
|
||||
0.73
|
||||
0.73
|
||||
0.73
|
||||
0.82
|
||||
0.82
|
||||
0.82
|
||||
0.82
|
||||
0.82
|
||||
0.82
|
||||
0.82
|
||||
0.82
|
||||
0.82
|
||||
0.82
|
||||
0.82
|
||||
0.82
|
||||
0.82
|
||||
0.62
|
||||
0.62
|
||||
0.62
|
||||
0.62
|
||||
0.62
|
||||
0.62
|
||||
0.62
|
||||
0.62
|
||||
0.62
|
||||
0.62
|
||||
0.62
|
||||
0.62
|
||||
0.96
|
||||
0.96
|
||||
0.96
|
||||
0.96
|
||||
0.96
|
||||
0.96
|
||||
0.96
|
||||
0.96
|
||||
0.96
|
||||
0.96
|
||||
0.96
|
||||
0.96
|
||||
0.96
|
||||
0.96
|
||||
0.96
|
||||
0.96
|
||||
0.96
|
||||
0.96
|
||||
0.96
|
||||
0.96
|
||||
0.96
|
||||
0.96
|
||||
0.96
|
||||
0.96
|
||||
0.96
|
||||
0.46
|
||||
0.46
|
||||
0.46
|
||||
0.46
|
||||
0.46
|
||||
0.46
|
||||
0.46
|
||||
0.46
|
||||
0.46
|
||||
0.46
|
||||
0.46
|
||||
0.46
|
||||
0.53
|
||||
0.53
|
||||
0.53
|
||||
0.53
|
||||
0.53
|
||||
0.53
|
||||
0.53
|
||||
0.53
|
||||
0.53
|
||||
0.53
|
||||
0.53
|
||||
0.53
|
||||
0.49
|
||||
0.49
|
||||
0.49
|
||||
0.49
|
||||
0.49
|
||||
0.49
|
||||
0.49
|
||||
0.49
|
||||
0.49
|
||||
0.49
|
||||
0.49
|
||||
0.49
|
||||
0.49
|
||||
0.76
|
||||
0.76
|
||||
0.76
|
||||
0.76
|
||||
0.76
|
||||
0.76
|
||||
0.76
|
||||
0.76
|
||||
0.76
|
||||
0.76
|
||||
0.76
|
||||
0.76
|
||||
0.64
|
||||
0.64
|
||||
0.64
|
||||
0.64
|
||||
0.64
|
||||
0.64
|
||||
0.64
|
||||
0.64
|
||||
0.64
|
||||
0.64
|
||||
0.64
|
||||
0.64
|
||||
0.71
|
||||
0.71
|
||||
0.71
|
||||
0.71
|
||||
0.71
|
||||
0.71
|
||||
0.71
|
||||
0.71
|
||||
0.71
|
||||
0.71
|
||||
0.71
|
||||
0.71
|
||||
0.71
|
||||
0.84
|
||||
0.84
|
||||
0.84
|
||||
0.84
|
||||
0.84
|
||||
0.84
|
||||
0.84
|
||||
0.84
|
||||
0.84
|
||||
0.84
|
||||
0.84
|
||||
0.84
|
||||
0.77
|
||||
0.77
|
||||
0.77
|
||||
0.77
|
||||
0.77
|
||||
0.77
|
||||
0.77
|
||||
0.77
|
||||
0.77
|
||||
0.77
|
||||
0.77
|
||||
0.77
|
||||
0.89
|
||||
0.89
|
||||
0.89
|
||||
0.89
|
||||
0.89
|
||||
0.89
|
||||
0.89
|
||||
0.89
|
||||
0.89
|
||||
0.89
|
||||
0.89
|
||||
0.89
|
||||
0.89
|
||||
0.82
|
||||
0.82
|
||||
0.82
|
||||
0.82
|
||||
0.82
|
||||
0.82
|
||||
0.82
|
||||
0.82
|
||||
0.82
|
||||
0.82
|
||||
0.82
|
||||
0.82
|
||||
0.84
|
||||
0.84
|
||||
0.84
|
||||
0.84
|
||||
0.84
|
||||
0.84
|
||||
0.84
|
||||
0.84
|
||||
0.84
|
||||
0.84
|
||||
0.84
|
||||
0.84
|
||||
0.91
|
||||
0.91
|
||||
0.91
|
||||
0.91
|
||||
0.91
|
||||
0.91
|
||||
0.91
|
||||
0.91
|
||||
0.91
|
||||
0.91
|
||||
0.91
|
||||
0.91
|
||||
0.91
|
||||
0.67
|
||||
0.67
|
||||
0.67
|
||||
0.67
|
||||
0.67
|
||||
0.67
|
||||
0.67
|
||||
0.67
|
||||
0.67
|
||||
0.67
|
||||
0.67
|
||||
0.67
|
||||
0.95
|
||||
0.95
|
||||
0.95
|
||||
0.95
|
||||
0.95
|
||||
0.95
|
||||
0.95
|
||||
0.95
|
||||
0.95
|
||||
0.95
|
||||
0.95
|
||||
0.95
|
@ -1 +0,0 @@
|
||||
0.00676015
|
Loading…
Reference in New Issue
Block a user