Mirror of https://github.com/ClickHouse/ClickHouse.git (synced 2024-12-11 17:02:25 +00:00)

commit fd1beddb6d ("code style")
@@ -7,34 +7,45 @@
 namespace DB
 {

 namespace
 {

 using FuncLinearRegression = AggregateFunctionMLMethod<LinearRegressionData, NameLinearRegression>;

 template <class Method>
 AggregateFunctionPtr createAggregateFunctionMLMethod(
-    const std::string & name, const DataTypes & arguments, const Array & parameters)
+    const std::string & name, const DataTypes & argument_types, const Array & parameters)
 {
     if (parameters.size() > 1)
         throw Exception("Aggregate function " + name + " requires at most one parameter", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

-    Float64 lr;
-    if (parameters.empty())
-        lr = Float64(0.01);
-    else
-        lr = static_cast<const Float64>(parameters[0].template get<Float64>());
+    for (size_t i = 0; i < argument_types.size(); ++i)
+    {
+        if (!WhichDataType(argument_types[i]).isFloat64())
+            throw Exception("Illegal type " + argument_types[i]->getName() + " of argument " +
+                            std::to_string(i) + "for aggregate function " + name,
+                            ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
+    }
+
+    Float64 learning_rate;
+    if (parameters.empty())
+    {
+        learning_rate = Float64(0.01);
+    } else
+    {
+        learning_rate = applyVisitor(FieldVisitorConvertToNumber<Float64>(), parameters[0]);
+    }

-    if (arguments.size() < 2)
+    if (argument_types.size() < 2)
         throw Exception("Aggregate function " + name + " requires at least two arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

-    return std::make_shared<Method>(arguments.size() - 1, lr);
+    return std::make_shared<Method>(argument_types.size() - 1, learning_rate);
 }

 }

 void registerAggregateFunctionMLMethod(AggregateFunctionFactory & factory) {
     factory.registerFunction("LinearRegression", createAggregateFunctionMLMethod<FuncLinearRegression>);
 }

 }
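The creator above rejects more than one parameter, requires every argument to be Float64, falls back to a learning rate of 0.01 when no parameter is given, and sizes the model by the number of feature columns (all arguments except the target). A standalone sketch of that same validation flow, with standard-library types standing in for the ClickHouse DataTypes/Array/Exception machinery (ModelConfig and makeConfig are illustrative names, not part of ClickHouse):

#include <cstddef>
#include <optional>
#include <stdexcept>
#include <string>
#include <vector>

struct ModelConfig
{
    std::size_t param_num;   // number of weights: one per feature column
    double learning_rate;
};

ModelConfig makeConfig(const std::string & name,
                       const std::vector<std::string> & argument_types,
                       const std::optional<double> & parameter)
{
    // every argument (target and features) must be Float64
    for (std::size_t i = 0; i < argument_types.size(); ++i)
        if (argument_types[i] != "Float64")
            throw std::invalid_argument("Illegal type " + argument_types[i] + " of argument "
                                        + std::to_string(i) + " for aggregate function " + name);

    // default learning rate when no parameter is supplied
    double learning_rate = parameter ? *parameter : 0.01;

    // at least one target column and one feature column
    if (argument_types.size() < 2)
        throw std::invalid_argument("Aggregate function " + name + " requires at least two arguments");

    return {argument_types.size() - 1, learning_rate};
}

int main()
{
    ModelConfig config = makeConfig("LinearRegression", {"Float64", "Float64", "Float64"}, std::nullopt);
    return (config.param_num == 2 && config.learning_rate == 0.01) ? 0 : 1;
}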
@@ -23,35 +23,41 @@
 namespace DB
 {

 namespace ErrorCodes
 {
     extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
+    extern const int BAD_ARGUMENTS;
 }


 struct LinearRegressionData
 {
+    LinearRegressionData()
+    {}
+    LinearRegressionData(Float64 learning_rate_, UInt32 param_num_)
+    : learning_rate(learning_rate_) {
+        weights.resize(param_num_);
+    }
+
     Float64 bias{0.0};
-    std::vector<Float64> w1;
-    Float64 learning_rate{0.01};
+    std::vector<Float64> weights;
+    Float64 learning_rate;
     UInt32 iter_num = 0;
-    UInt32 param_num = 0;


-    void add(Float64 target, std::vector<Float64>& feature, Float64 learning_rate_, UInt32 param_num_)
+    void add(Float64 target, const IColumn ** columns, size_t row_num)
     {
-        if (w1.empty())
-        {
-            learning_rate = learning_rate_;
-            param_num = param_num_;
-            w1.resize(param_num);
-        }
-
         Float64 derivative = (target - bias);
-        for (size_t i = 0; i < param_num; ++i)
+        for (size_t i = 0; i < weights.size(); ++i)
         {
-            derivative -= w1[i] * feature[i];
+            derivative -= weights[i] * static_cast<const ColumnVector<Float64> &>(*columns[i + 1]).getData()[row_num];
         }
         derivative *= (2 * learning_rate);

         bias += derivative;
-        for (size_t i = 0; i < param_num; ++i)
+        for (size_t i = 0; i < weights.size(); ++i)
         {
-            w1[i] += derivative * feature[i];
+            weights[i] += derivative * static_cast<const ColumnVector<Float64> &>(*columns[i + 1]).getData()[row_num];;
         }

         ++iter_num;
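add() performs one stochastic-gradient-descent step for squared loss: take the residual of the current prediction, scale it by twice the learning rate, and nudge the bias and each weight by that amount times the corresponding feature. A minimal self-contained sketch of the same update, with std::vector in place of ClickHouse columns (Model is an illustrative name):

#include <cstddef>
#include <cstdio>
#include <vector>

struct Model
{
    double bias = 0.0;
    std::vector<double> weights;
    double learning_rate = 0.01;

    // one SGD step for squared loss, mirroring LinearRegressionData::add
    void add(double target, const std::vector<double> & features)
    {
        double derivative = target - bias;               // residual of the current prediction
        for (std::size_t i = 0; i < weights.size(); ++i)
            derivative -= weights[i] * features[i];
        derivative *= 2 * learning_rate;

        bias += derivative;                              // move bias and weights along the gradient
        for (std::size_t i = 0; i < weights.size(); ++i)
            weights[i] += derivative * features[i];
    }
};

int main()
{
    Model m;
    m.weights.resize(1);
    for (int i = 0; i < 10000; ++i)
    {
        double x = i % 10;
        m.add(2.0 * x + 1.0, {x});   // learn y = 2x + 1
    }
    std::printf("w = %f, b = %f\n", m.weights[0], m.bias);   // should approach w = 2, b = 1
}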
@@ -60,20 +66,14 @@ struct LinearRegressionData
     void merge(const LinearRegressionData & rhs)
     {
         if (iter_num == 0 && rhs.iter_num == 0)
-            throw std::runtime_error("Strange...");
-
-        if (param_num == 0)
-        {
-            param_num = rhs.param_num;
-            w1.resize(param_num);
-        }
+            return;

         Float64 frac = static_cast<Float64>(iter_num) / (iter_num + rhs.iter_num);
         Float64 rhs_frac = static_cast<Float64>(rhs.iter_num) / (iter_num + rhs.iter_num);

-        for (size_t i = 0; i < param_num; ++i)
+        for (size_t i = 0; i < weights.size(); ++i)
         {
-            w1[i] = w1[i] * frac + rhs.w1[i] * rhs_frac;
+            weights[i] = weights[i] * frac + rhs.weights[i] * rhs_frac;
         }

         bias = bias * frac + rhs.bias * rhs_frac;
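merge() combines two partially trained states by weighting each side with its share of the processed rows, which is what lets the aggregate run distributed or multi-threaded and still produce a single model. A plain sketch of that weighting; the iter_num accumulation is an assumption here, since it happens outside the lines shown above:

#include <cstddef>
#include <cstdio>
#include <vector>

struct PartialModel
{
    double bias = 0.0;
    std::vector<double> weights;
    unsigned iter_num = 0;
};

// combine two partial states, each weighted by its share of processed rows
void merge(PartialModel & lhs, const PartialModel & rhs)
{
    if (lhs.iter_num == 0 && rhs.iter_num == 0)
        return;

    double frac = double(lhs.iter_num) / (lhs.iter_num + rhs.iter_num);
    double rhs_frac = double(rhs.iter_num) / (lhs.iter_num + rhs.iter_num);

    for (std::size_t i = 0; i < lhs.weights.size(); ++i)
        lhs.weights[i] = lhs.weights[i] * frac + rhs.weights[i] * rhs_frac;
    lhs.bias = lhs.bias * frac + rhs.bias * rhs_frac;
    lhs.iter_num += rhs.iter_num;   // assumed: the real code accumulates counts outside the shown hunk
}

int main()
{
    PartialModel a;
    a.bias = 1.0; a.weights = {2.0}; a.iter_num = 100;   // trained on 100 rows
    PartialModel b;
    b.bias = 3.0; b.weights = {4.0}; b.iter_num = 300;   // trained on 300 rows
    merge(a, b);
    std::printf("bias = %f, w0 = %f\n", a.bias, a.weights[0]);   // 2.5 and 3.5
}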
@@ -83,23 +83,23 @@ struct LinearRegressionData
     void write(WriteBuffer & buf) const
     {
         writeBinary(bias, buf);
-        writeBinary(w1, buf);
+        writeBinary(weights, buf);
         writeBinary(iter_num, buf);
     }

     void read(ReadBuffer & buf)
     {
         readBinary(bias, buf);
-        readBinary(w1, buf);
+        readBinary(weights, buf);
         readBinary(iter_num, buf);
     }

-    Float64 predict(std::vector<Float64>& predict_feature) const
+    Float64 predict(const std::vector<Float64>& predict_feature) const
     {
         Float64 res{0.0};
-        for (size_t i = 0; i < static_cast<size_t>(param_num); ++i)
+        for (size_t i = 0; i < predict_feature.size(); ++i)
         {
-            res += predict_feature[i] * w1[i];
+            res += predict_feature[i] * weights[i];
         }
         res += bias;
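predict() is just the dot product of the stored weights with the supplied features plus the bias term; a small sketch:

#include <cstddef>
#include <cstdio>
#include <vector>

// what LinearRegressionData::predict computes: weights . features + bias
double predict(const std::vector<double> & weights, double bias, const std::vector<double> & features)
{
    double res = bias;
    for (std::size_t i = 0; i < features.size(); ++i)
        res += features[i] * weights[i];
    return res;
}

int main()
{
    // model y = 2*x0 + 3*x1 + 1, evaluated at (10, 100)
    std::printf("%f\n", predict({2.0, 3.0}, 1.0, {10.0, 100.0}));   // prints 321
}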
@@ -128,18 +128,16 @@ public:
         return std::make_shared<DataTypeNumber<Float64>>();
     }

+    void create(AggregateDataPtr place) const override
+    {
+        new (place) Data(learning_rate, param_num);
+    }
+
     void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
     {
         const auto & target = static_cast<const ColumnVector<Float64> &>(*columns[0]);

-        std::vector<Float64> x(param_num);
-        for (size_t i = 0; i < param_num; ++i)
-        {
-            x[i] = static_cast<const ColumnVector<Float64> &>(*columns[i + 1]).getData()[row_num];
-        }
-
-        this->data(place).add(target.getData()[row_num], x, learning_rate, param_num);
-
+        this->data(place).add(target.getData()[row_num], columns, row_num);
     }

     void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
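The new create() constructs the per-group state with placement new inside memory handed out by the aggregation framework, passing the learning rate and weight count straight to the constructor. A minimal standalone illustration of the placement-new idiom, with a raw stack buffer standing in for the Arena:

#include <cstdio>
#include <new>
#include <vector>

struct State
{
    double learning_rate;
    std::vector<double> weights;

    State(double learning_rate_, unsigned param_num_)
        : learning_rate(learning_rate_), weights(param_num_) {}
};

int main()
{
    alignas(State) unsigned char buf[sizeof(State)];   // stands in for arena-provided memory
    State * state = new (buf) State(0.01, 3);          // the placement-new step done by create()
    std::printf("lr = %f, weights = %zu\n", state->learning_rate, state->weights.size());
    state->~State();                                   // counterpart that destroy() would run
    return 0;
}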
@@ -159,18 +157,23 @@ public:

     void predictResultInto(ConstAggregateDataPtr place, IColumn & to, Block & block, size_t row_num, const ColumnNumbers & arguments) const
     {
         if (arguments.size() != param_num + 1)
             throw Exception("Predict got incorrect number of arguments. Got: " + std::to_string(arguments.size()) + ". Required: " + std::to_string(param_num + 1),
                             ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

         auto &column = dynamic_cast<ColumnVector<Float64> &>(to);

         std::vector<Float64> predict_features(arguments.size() - 1);
-//        for (size_t row_num = 0, rows = block.rows(); row_num < rows; ++row_num)
         for (size_t i = 1; i < arguments.size(); ++i)
         {
-//            predict_features[i] = array_elements[i].get<Float64>();
-            predict_features[i - 1] = applyVisitor(FieldVisitorConvertToNumber<Float64>(), (*block.getByPosition(arguments[i]).column)[row_num]);
+            const auto& element = (*block.getByPosition(arguments[i]).column)[row_num];
+            if (element.getType() != Field::Types::Float64)
+                throw Exception("Prediction arguments must be values of type Float",
+                                ErrorCodes::BAD_ARGUMENTS);
+
+            predict_features[i - 1] = element.get<Float64>();
         }
-//        column.getData().push_back(this->data(place).predict(predict_features));
         column.getData().push_back(this->data(place).predict(predict_features));
-//        }
     }

     void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
@@ -137,7 +137,7 @@ protected:
     static const Data & data(ConstAggregateDataPtr place) { return *reinterpret_cast<const Data*>(place); }

 public:
-    void create(AggregateDataPtr place) const override
+    virtual void create(AggregateDataPtr place) const override
     {
         new (place) Data;
     }
@@ -17,6 +17,7 @@ namespace ErrorCodes
 {
     extern const int PARAMETER_OUT_OF_BOUND;
     extern const int SIZES_OF_COLUMNS_DOESNT_MATCH;
+    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
 }

@@ -32,6 +33,23 @@ void ColumnAggregateFunction::addArena(ArenaPtr arena_)
     arenas.push_back(arena_);
 }

+bool ColumnAggregateFunction::convertion(MutableColumnPtr* res_) const
+{
+    if (const AggregateFunctionState * function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
+    {
+        auto res = createView();
+        res->set(function_state->getNestedFunction());
+        res->data.assign(data.begin(), data.end());
+        *res_ = std::move(res);
+        return true;
+    }
+
+    MutableColumnPtr res = func->getReturnType()->createColumn();
+    res->reserve(data.size());
+    *res_ = std::move(res);
+    return false;
+}
+
 MutableColumnPtr ColumnAggregateFunction::convertToValues() const
 {
     /** If the aggregate function returns an unfinalized/unfinished state,
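convertion() factors the step shared by convertToValues() and predictValues() into one helper: if the function only stores an unfinished state, hand back a view over the existing data and tell the caller to return it as-is; otherwise allocate an empty result column for the caller to fill. A toy sketch of that control flow, under hypothetical names:

#include <cstdio>
#include <string>

// toy analogue: prepare the result and report whether the caller should
// return it immediately (true) or go on to fill it with finalized values (false)
bool prepareResult(bool keep_unfinished_state, std::string * out)
{
    if (keep_unfinished_state)
    {
        *out = "<raw aggregate state>";
        return true;
    }
    out->clear();
    return false;
}

int main()
{
    std::string res;
    if (!prepareResult(false, &res))
        res = "finalized values";
    std::printf("%s\n", res.c_str());
}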
@@ -64,38 +82,46 @@ MutableColumnPtr ColumnAggregateFunction::convertToValues() const
      * AggregateFunction(quantileTiming(0.5), UInt64)
      * into UInt16 - already finished result of `quantileTiming`.
      */
-    if (const AggregateFunctionState * function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
+//    if (const AggregateFunctionState * function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
+//    {
+//        auto res = createView();
+//        res->set(function_state->getNestedFunction());
+//        res->data.assign(data.begin(), data.end());
+//        return res;
+//    }
+//
+//    MutableColumnPtr res = func->getReturnType()->createColumn();
+//    res->reserve(data.size());
+    MutableColumnPtr res;
+    if (convertion(&res))
     {
-        auto res = createView();
-        res->set(function_state->getNestedFunction());
-        res->data.assign(data.begin(), data.end());
         return res;
     }

-    MutableColumnPtr res = func->getReturnType()->createColumn();
-    res->reserve(data.size());
-
     for (auto val : data)
         func->insertResultInto(val, *res);

     return res;
 }

-//MutableColumnPtr ColumnAggregateFunction::predictValues(std::vector<Float64> predict_feature) const
 MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const ColumnNumbers & arguments) const
 {
-    if (const AggregateFunctionState * function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
+//    if (const AggregateFunctionState * function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
+//    {
+//        auto res = createView();
+//        res->set(function_state->getNestedFunction());
+//        res->data.assign(data.begin(), data.end());
+//        return res;
+//    }
+//
+//    MutableColumnPtr res = func->getReturnType()->createColumn();
+//    res->reserve(data.size());
+    MutableColumnPtr res;
+    if (convertion(&res))
     {
-        auto res = createView();
-        res->set(function_state->getNestedFunction());
-        res->data.assign(data.begin(), data.end());
         return res;
     }

-    MutableColumnPtr res = func->getReturnType()->createColumn();
-    res->reserve(data.size());
-
-//    const AggregateFunctionMLMethod * ML_function = typeid_cast<const AggregateFunctionMLMethod *>(func.get());
     auto ML_function = typeid_cast<const AggregateFunctionMLMethod<LinearRegressionData, NameLinearRegression> *>(func.get());
     if (ML_function)
     {
@@ -105,7 +131,8 @@ MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const ColumnNumbers & arguments) const
             ++row_num;
         }
-    } else {
-
+    }
+    else
+    {
         throw Exception("Illegal aggregate function is passed",
                         ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
     }

     return res;
@@ -116,6 +116,7 @@ public:
     std::string getName() const override { return "AggregateFunction(" + func->getName() + ")"; }
     const char * getFamilyName() const override { return "AggregateFunction"; }

+    bool convertion(MutableColumnPtr* res_) const;
     MutableColumnPtr predictValues(Block & block, const ColumnNumbers & arguments) const;

     size_t size() const override
@@ -59,6 +59,7 @@ namespace DB

     void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override
     {
+        // introduce ML aggregate functions as a separate class, so this can be checked right here instead of inside predictValues()
         const ColumnAggregateFunction * column_with_states
             = typeid_cast<const ColumnAggregateFunction *>(&*block.getByPosition(arguments.at(0)).column);
         if (!column_with_states)