mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-09-21 09:10:48 +00:00
style
This commit is contained in:
parent
12132b8fdf
commit
19021e76bb
@ -25,9 +25,7 @@ AggregateFunctionPtr createAggregateFunctionMLMethod(
|
||||
for (size_t i = 0; i < argument_types.size(); ++i)
|
||||
{
|
||||
if (!WhichDataType(argument_types[i]).isFloat64())
|
||||
throw Exception("Illegal type " + argument_types[i]->getName() + " of argument "
|
||||
+ std::to_string(i) + "for aggregate function " + name,
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
throw Exception("Illegal type " + argument_types[i]->getName() + " of argument " + std::to_string(i) + "for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
}
|
||||
|
||||
Float64 learning_rate = Float64(0.01);
|
||||
@ -55,17 +53,22 @@ AggregateFunctionPtr createAggregateFunctionMLMethod(
|
||||
if (applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[3]) == Float64{1.0})
|
||||
{
|
||||
wu = std::make_shared<StochasticGradientDescent>();
|
||||
} else if (applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[3]) == Float64{2.0})
|
||||
}
|
||||
else if (applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[3]) == Float64{2.0})
|
||||
{
|
||||
wu = std::make_shared<Momentum>();
|
||||
} else if (applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[3]) == Float64{3.0})
|
||||
}
|
||||
else if (applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[3]) == Float64{3.0})
|
||||
{
|
||||
wu = std::make_shared<Nesterov>();
|
||||
|
||||
} else {
|
||||
}
|
||||
else
|
||||
{
|
||||
throw Exception("Invalid parameter for weights updater", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
}
|
||||
} else
|
||||
}
|
||||
else
|
||||
{
|
||||
wu = std::make_unique<StochasticGradientDescent>();
|
||||
}
|
||||
@ -73,10 +76,12 @@ AggregateFunctionPtr createAggregateFunctionMLMethod(
|
||||
if (std::is_same<Method, FuncLinearRegression>::value)
|
||||
{
|
||||
gc = std::make_shared<LinearRegression>();
|
||||
} else if (std::is_same<Method, FuncLogisticRegression>::value)
|
||||
}
|
||||
else if (std::is_same<Method, FuncLogisticRegression>::value)
|
||||
{
|
||||
gc = std::make_shared<LogisticRegression>();
|
||||
} else
|
||||
}
|
||||
else
|
||||
{
|
||||
throw Exception("Such gradient computer is not implemented yet", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
}
|
||||
|
@ -483,8 +483,8 @@ public:
|
||||
l2_reg_coef(l2_reg_coef),
|
||||
batch_size(batch_size),
|
||||
gc(std::move(gradient_computer)),
|
||||
wu(std::move(weights_updater)) {
|
||||
}
|
||||
wu(std::move(weights_updater))
|
||||
{}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
|
@ -120,7 +120,8 @@ MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const Col
|
||||
++row_num;
|
||||
}
|
||||
|
||||
} else if (ML_function_Logistic)
|
||||
}
|
||||
else if (ML_function_Logistic)
|
||||
{
|
||||
size_t row_num = 0;
|
||||
for (auto val : data)
|
||||
@ -128,7 +129,8 @@ MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const Col
|
||||
ML_function_Logistic->predictResultInto(val, *res, block, arguments);
|
||||
++row_num;
|
||||
}
|
||||
} else
|
||||
}
|
||||
else
|
||||
{
|
||||
throw Exception("Illegal aggregate function is passed",
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
@ -39,7 +39,8 @@ public:
|
||||
return name;
|
||||
}
|
||||
|
||||
bool isVariadic() const override {
|
||||
bool isVariadic() const override
|
||||
{
|
||||
return true;
|
||||
}
|
||||
size_t getNumberOfArguments() const override
|
||||
|
Loading…
Reference in New Issue
Block a user