code style

Maxim Kuznetsov 2019-01-23 15:11:17 +03:00
commit fd1beddb6d
6 changed files with 129 additions and 86 deletions

View File

@@ -7,34 +7,45 @@
 namespace DB
 {
 namespace
 {
-using FuncLinearRegression = AggregateFunctionMLMethod<LinearRegressionData, NameLinearRegression>;
-template <class Method>
-AggregateFunctionPtr createAggregateFunctionMLMethod(
-    const std::string & name, const DataTypes & arguments, const Array & parameters)
-{
-    if (parameters.size() > 1)
-        throw Exception("Aggregate function " + name + " requires at most one parameter", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
-    Float64 lr;
-    if (parameters.empty())
-        lr = Float64(0.01);
-    else
-        lr = static_cast<const Float64>(parameters[0].template get<Float64>());
-    if (arguments.size() < 2)
-        throw Exception("Aggregate function " + name + " requires at least two arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
-    return std::make_shared<Method>(arguments.size() - 1, lr);
-}
+using FuncLinearRegression = AggregateFunctionMLMethod<LinearRegressionData, NameLinearRegression>;
+template <class Method>
+AggregateFunctionPtr createAggregateFunctionMLMethod(
+    const std::string & name, const DataTypes & argument_types, const Array & parameters)
+{
+    if (parameters.size() > 1)
+        throw Exception("Aggregate function " + name + " requires at most one parameter", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
+    for (size_t i = 0; i < argument_types.size(); ++i)
+    {
+        if (!WhichDataType(argument_types[i]).isFloat64())
+            throw Exception("Illegal type " + argument_types[i]->getName() + " of argument " +
+                std::to_string(i) + " for aggregate function " + name,
+                ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
+    }
+    Float64 learning_rate;
+    if (parameters.empty())
+    {
+        learning_rate = Float64(0.01);
+    } else
+    {
+        learning_rate = applyVisitor(FieldVisitorConvertToNumber<Float64>(), parameters[0]);
+    }
+    if (argument_types.size() < 2)
+        throw Exception("Aggregate function " + name + " requires at least two arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
+    return std::make_shared<Method>(argument_types.size() - 1, learning_rate);
+}
 }

 void registerAggregateFunctionMLMethod(AggregateFunctionFactory & factory) {
     factory.registerFunction("LinearRegression", createAggregateFunctionMLMethod<FuncLinearRegression>);
 }
 }
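
Note on the registration above: LinearRegression is a parametric aggregate function, so a hypothetical call such as LinearRegression(0.05)(target, x1, x2) (illustrative usage, not taken from this commit) would set learning_rate = 0.05 and create argument_types.size() - 1 = 2 weights; with no parameter the learning rate falls back to Float64(0.01), and after this change every argument must be of type Float64.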

View File

@@ -23,35 +23,41 @@
 namespace DB
 {
 namespace ErrorCodes
 {
     extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
+    extern const int BAD_ARGUMENTS;
 }

 struct LinearRegressionData
 {
+    LinearRegressionData()
+    {}
+    LinearRegressionData(Float64 learning_rate_, UInt32 param_num_)
+        : learning_rate(learning_rate_) {
+        weights.resize(param_num_);
+    }

     Float64 bias{0.0};
-    std::vector<Float64> w1;
-    Float64 learning_rate{0.01};
+    std::vector<Float64> weights;
+    Float64 learning_rate;
     UInt32 iter_num = 0;
-    UInt32 param_num = 0;

-    void add(Float64 target, std::vector<Float64>& feature, Float64 learning_rate_, UInt32 param_num_)
+    void add(Float64 target, const IColumn ** columns, size_t row_num)
     {
-        if (w1.empty())
-        {
-            learning_rate = learning_rate_;
-            param_num = param_num_;
-            w1.resize(param_num);
-        }
         Float64 derivative = (target - bias);
-        for (size_t i = 0; i < param_num; ++i)
+        for (size_t i = 0; i < weights.size(); ++i)
         {
-            derivative -= w1[i] * feature[i];
+            derivative -= weights[i] * static_cast<const ColumnVector<Float64> &>(*columns[i + 1]).getData()[row_num];
         }
         derivative *= (2 * learning_rate);
         bias += derivative;
-        for (size_t i = 0; i < param_num; ++i)
+        for (size_t i = 0; i < weights.size(); ++i)
         {
-            w1[i] += derivative * feature[i];
+            weights[i] += derivative * static_cast<const ColumnVector<Float64> &>(*columns[i + 1]).getData()[row_num];
         }
         ++iter_num;
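
The reworked add() reads the features straight from the input columns instead of a temporary vector, but the per-row update is the same stochastic gradient step for squared error as before. A sketch of the math, writing \eta for learning_rate, x_i for the i-th feature of the row and y for target:

    \hat{y} = \sum_i w_i x_i + b
    d = 2 \eta (y - \hat{y})
    b \leftarrow b + d
    w_i \leftarrow w_i + d \, x_i

which corresponds line-for-line to derivative, bias += derivative and weights[i] += derivative * x_i above.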
@@ -60,20 +66,14 @@ struct LinearRegressionData
     void merge(const LinearRegressionData & rhs)
     {
         if (iter_num == 0 && rhs.iter_num == 0)
-            throw std::runtime_error("Strange...");
-        if (param_num == 0)
-        {
-            param_num = rhs.param_num;
-            w1.resize(param_num);
-        }
+            return;

         Float64 frac = static_cast<Float64>(iter_num) / (iter_num + rhs.iter_num);
         Float64 rhs_frac = static_cast<Float64>(rhs.iter_num) / (iter_num + rhs.iter_num);

-        for (size_t i = 0; i < param_num; ++i)
+        for (size_t i = 0; i < weights.size(); ++i)
         {
-            w1[i] = w1[i] * frac + rhs.w1[i] * rhs_frac;
+            weights[i] = weights[i] * frac + rhs.weights[i] * rhs_frac;
         }

         bias = bias * frac + rhs.bias * rhs_frac;
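
merge() now returns silently when both states are empty instead of throwing, and otherwise still combines two partial models by averaging them, weighted by how many rows each state has processed. With n = iter_num and m = rhs.iter_num:

    frac = n / (n + m),  rhs_frac = m / (n + m)
    w_i \leftarrow frac \cdot w_i + rhs\_frac \cdot w_i^{rhs}
    b \leftarrow frac \cdot b + rhs\_frac \cdot b^{rhs}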
@@ -83,23 +83,23 @@ struct LinearRegressionData
     void write(WriteBuffer & buf) const
     {
         writeBinary(bias, buf);
-        writeBinary(w1, buf);
+        writeBinary(weights, buf);
         writeBinary(iter_num, buf);
     }

     void read(ReadBuffer & buf)
     {
         readBinary(bias, buf);
-        readBinary(w1, buf);
+        readBinary(weights, buf);
         readBinary(iter_num, buf);
     }

-    Float64 predict(std::vector<Float64>& predict_feature) const
+    Float64 predict(const std::vector<Float64>& predict_feature) const
     {
         Float64 res{0.0};
-        for (size_t i = 0; i < static_cast<size_t>(param_num); ++i)
+        for (size_t i = 0; i < predict_feature.size(); ++i)
         {
-            res += predict_feature[i] * w1[i];
+            res += predict_feature[i] * weights[i];
         }
         res += bias;
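
predict() is unchanged apart from iterating over predict_feature.size(): it evaluates the fitted linear model \hat{y} = \sum_i w_i \cdot predict\_feature_i + b using the stored weights and bias.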
@@ -128,18 +128,16 @@ public:
         return std::make_shared<DataTypeNumber<Float64>>();
     }

+    void create(AggregateDataPtr place) const override
+    {
+        new (place) Data(learning_rate, param_num);
+    }

     void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
     {
         const auto & target = static_cast<const ColumnVector<Float64> &>(*columns[0]);
-        std::vector<Float64> x(param_num);
-        for (size_t i = 0; i < param_num; ++i)
-        {
-            x[i] = static_cast<const ColumnVector<Float64> &>(*columns[i + 1]).getData()[row_num];
-        }
-        this->data(place).add(target.getData()[row_num], x, learning_rate, param_num);
+        this->data(place).add(target.getData()[row_num], columns, row_num);
     }

     void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
@@ -159,18 +157,23 @@ public:
     void predictResultInto(ConstAggregateDataPtr place, IColumn & to, Block & block, size_t row_num, const ColumnNumbers & arguments) const
     {
+        if (arguments.size() != param_num + 1)
+            throw Exception("Predict got incorrect number of arguments. Got: " + std::to_string(arguments.size()) + ". Required: " + std::to_string(param_num + 1),
+                ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

         auto &column = dynamic_cast<ColumnVector<Float64> &>(to);
         std::vector<Float64> predict_features(arguments.size() - 1);
         // for (size_t row_num = 0, rows = block.rows(); row_num < rows; ++row_num)
         for (size_t i = 1; i < arguments.size(); ++i)
         {
-            // predict_features[i] = array_elements[i].get<Float64>();
-            predict_features[i - 1] = applyVisitor(FieldVisitorConvertToNumber<Float64>(), (*block.getByPosition(arguments[i]).column)[row_num]);
+            const auto& element = (*block.getByPosition(arguments[i]).column)[row_num];
+            if (element.getType() != Field::Types::Float64)
+                throw Exception("Prediction arguments must be values of type Float",
+                    ErrorCodes::BAD_ARGUMENTS);
+            predict_features[i - 1] = element.get<Float64>();
         }
         // column.getData().push_back(this->data(place).predict(predict_features));
         column.getData().push_back(this->data(place).predict(predict_features));
         // }
     }

     void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override

View File

@@ -137,7 +137,7 @@ protected:
     static const Data & data(ConstAggregateDataPtr place) { return *reinterpret_cast<const Data*>(place); }

 public:
-    void create(AggregateDataPtr place) const override
+    virtual void create(AggregateDataPtr place) const override
     {
         new (place) Data;
     }

View File

@@ -17,6 +17,7 @@ namespace ErrorCodes
 {
     extern const int PARAMETER_OUT_OF_BOUND;
     extern const int SIZES_OF_COLUMNS_DOESNT_MATCH;
+    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
 }
@@ -32,6 +33,23 @@ void ColumnAggregateFunction::addArena(ArenaPtr arena_)
     arenas.push_back(arena_);
 }

+bool ColumnAggregateFunction::convertion(MutableColumnPtr* res_) const
+{
+    if (const AggregateFunctionState * function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
+    {
+        auto res = createView();
+        res->set(function_state->getNestedFunction());
+        res->data.assign(data.begin(), data.end());
+        *res_ = std::move(res);
+        return true;
+    }
+
+    MutableColumnPtr res = func->getReturnType()->createColumn();
+    res->reserve(data.size());
+    *res_ = std::move(res);
+    return false;
+}

 MutableColumnPtr ColumnAggregateFunction::convertToValues() const
 {
     /** If the aggregate function returns an unfinalized/unfinished state,
@@ -64,38 +82,46 @@ MutableColumnPtr ColumnAggregateFunction::convertToValues() const
      * AggregateFunction(quantileTiming(0.5), UInt64)
      * into UInt16 - already finished result of `quantileTiming`.
      */
-    if (const AggregateFunctionState * function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
+//    if (const AggregateFunctionState * function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
+//    {
+//        auto res = createView();
+//        res->set(function_state->getNestedFunction());
+//        res->data.assign(data.begin(), data.end());
+//        return res;
+//    }
+//
+//    MutableColumnPtr res = func->getReturnType()->createColumn();
+//    res->reserve(data.size());
+    MutableColumnPtr res;
+    if (convertion(&res))
     {
-        auto res = createView();
-        res->set(function_state->getNestedFunction());
-        res->data.assign(data.begin(), data.end());
         return res;
     }

-    MutableColumnPtr res = func->getReturnType()->createColumn();
-    res->reserve(data.size());
     for (auto val : data)
         func->insertResultInto(val, *res);

     return res;
 }

-//MutableColumnPtr ColumnAggregateFunction::predictValues(std::vector<Float64> predict_feature) const
 MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const ColumnNumbers & arguments) const
 {
-    if (const AggregateFunctionState * function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
+//    if (const AggregateFunctionState * function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
+//    {
+//        auto res = createView();
+//        res->set(function_state->getNestedFunction());
+//        res->data.assign(data.begin(), data.end());
+//        return res;
+//    }
+//
+//    MutableColumnPtr res = func->getReturnType()->createColumn();
+//    res->reserve(data.size());
+    MutableColumnPtr res;
+    if (convertion(&res))
     {
-        auto res = createView();
-        res->set(function_state->getNestedFunction());
-        res->data.assign(data.begin(), data.end());
         return res;
     }

-    MutableColumnPtr res = func->getReturnType()->createColumn();
-    res->reserve(data.size());
-//    const AggregateFunctionMLMethod * ML_function = typeid_cast<const AggregateFunctionMLMethod *>(func.get());
     auto ML_function = typeid_cast<const AggregateFunctionMLMethod<LinearRegressionData, NameLinearRegression> *>(func.get());
     if (ML_function)
     {
@@ -105,7 +131,8 @@ MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const ColumnNumbers & arguments) const
             ++row_num;
         }
     } else {
+        throw Exception("Illegal aggregate function is passed",
+            ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
     }

     return res;

View File

@@ -116,6 +116,7 @@ public:
     std::string getName() const override { return "AggregateFunction(" + func->getName() + ")"; }
     const char * getFamilyName() const override { return "AggregateFunction"; }

+    bool convertion(MutableColumnPtr* res_) const;
     MutableColumnPtr predictValues(Block & block, const ColumnNumbers & arguments) const;

     size_t size() const override

View File

@@ -59,6 +59,7 @@ namespace DB
     void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override
     {
+        // Make the ML aggregate functions a separate class, so that this can be checked right here instead of inside predictValues()
         const ColumnAggregateFunction * column_with_states
             = typeid_cast<const ColumnAggregateFunction *>(&*block.getByPosition(arguments.at(0)).column);

         if (!column_with_states)