From 922872fb9a975bd4d9e9236807c06c12b7022598 Mon Sep 17 00:00:00 2001 From: Alexey Milovidov Date: Sat, 2 Oct 2021 11:39:26 +0300 Subject: [PATCH] Move modelEvaluate function to its own file --- src/Functions/FunctionsExternalModels.cpp | 141 ----------------- src/Functions/FunctionsExternalModels.h | 44 ------ src/Functions/modelEvaluate.cpp | 177 ++++++++++++++++++++++ src/Interpreters/CatBoostModel.h | 1 + 4 files changed, 178 insertions(+), 185 deletions(-) delete mode 100644 src/Functions/FunctionsExternalModels.cpp delete mode 100644 src/Functions/FunctionsExternalModels.h create mode 100644 src/Functions/modelEvaluate.cpp diff --git a/src/Functions/FunctionsExternalModels.cpp b/src/Functions/FunctionsExternalModels.cpp deleted file mode 100644 index e3b0e852731..00000000000 --- a/src/Functions/FunctionsExternalModels.cpp +++ /dev/null @@ -1,141 +0,0 @@ -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - - -namespace DB -{ - -FunctionPtr FunctionModelEvaluate::create(ContextPtr context) -{ - return std::make_shared(context->getExternalModelsLoader()); -} - -namespace ErrorCodes -{ - extern const int ILLEGAL_TYPE_OF_ARGUMENT; - extern const int TOO_FEW_ARGUMENTS_FOR_FUNCTION; - extern const int ILLEGAL_COLUMN; -} - -DataTypePtr FunctionModelEvaluate::getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const -{ - if (arguments.size() < 2) - throw Exception("Function " + getName() + " expects at least 2 arguments", - ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION); - - if (!isString(arguments[0].type)) - throw Exception("Illegal type " + arguments[0].type->getName() + " of first argument of function " + getName() - + ", expected a string.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); - - const auto * name_col = checkAndGetColumnConst(arguments[0].column.get()); - if (!name_col) - throw Exception("First argument of function " + getName() + " must be a constant string", - ErrorCodes::ILLEGAL_COLUMN); - - bool has_nullable = false; - for (size_t i = 1; i < arguments.size(); ++i) - has_nullable = has_nullable || arguments[i].type->isNullable(); - - auto model = models_loader.getModel(name_col->getValue()); - auto type = model->getReturnType(); - - if (has_nullable) - { - if (const auto * tuple = typeid_cast(type.get())) - { - auto elements = tuple->getElements(); - for (auto & element : elements) - element = makeNullable(element); - - type = std::make_shared(elements); - } - else - type = makeNullable(type); - } - - return type; -} - -ColumnPtr FunctionModelEvaluate::executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t /*input_rows_count*/) const -{ - const auto * name_col = checkAndGetColumnConst(arguments[0].column.get()); - if (!name_col) - throw Exception("First argument of function " + getName() + " must be a constant string", - ErrorCodes::ILLEGAL_COLUMN); - - auto model = models_loader.getModel(name_col->getValue()); - - ColumnRawPtrs column_ptrs; - Columns materialized_columns; - ColumnPtr null_map; - - column_ptrs.reserve(arguments.size()); - for (auto arg : collections::range(1, arguments.size())) - { - const auto & column = arguments[arg].column; - column_ptrs.push_back(column.get()); - if (auto full_column = column->convertToFullColumnIfConst()) - { - materialized_columns.push_back(full_column); - column_ptrs.back() = full_column.get(); - } - if (const auto * col_nullable = checkAndGetColumn(*column_ptrs.back())) - { - if (!null_map) - null_map = col_nullable->getNullMapColumnPtr(); - else - { - auto mut_null_map = IColumn::mutate(std::move(null_map)); - - NullMap & result_null_map = assert_cast(*mut_null_map).getData(); - const NullMap & src_null_map = col_nullable->getNullMapColumn().getData(); - - for (size_t i = 0, size = result_null_map.size(); i < size; ++i) - if (src_null_map[i]) - result_null_map[i] = 1; - - null_map = std::move(mut_null_map); - } - - column_ptrs.back() = &col_nullable->getNestedColumn(); - } - } - - auto res = model->evaluate(column_ptrs); - - if (null_map) - { - if (const auto * tuple = typeid_cast(res.get())) - { - auto nested = tuple->getColumns(); - for (auto & col : nested) - col = ColumnNullable::create(col, null_map); - - res = ColumnTuple::create(nested); - } - else - res = ColumnNullable::create(res, null_map); - } - - return res; -} - -void registerFunctionsExternalModels(FunctionFactory & factory) -{ - factory.registerFunction(); -} - -} diff --git a/src/Functions/FunctionsExternalModels.h b/src/Functions/FunctionsExternalModels.h deleted file mode 100644 index ecfb4179638..00000000000 --- a/src/Functions/FunctionsExternalModels.h +++ /dev/null @@ -1,44 +0,0 @@ -#pragma once - -#include -#include - -namespace DB -{ - -class ExternalModelsLoader; - -/// Evaluate external model. -/// First argument - model name, the others - model arguments. -/// * for CatBoost model - float features first, then categorical -/// Result - Float64. -class FunctionModelEvaluate final : public IFunction -{ -public: - static constexpr auto name = "modelEvaluate"; - - static FunctionPtr create(ContextPtr context); - - explicit FunctionModelEvaluate(const ExternalModelsLoader & models_loader_) : models_loader(models_loader_) {} - - String getName() const override { return name; } - - bool isVariadic() const override { return true; } - - bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; } - - bool isDeterministic() const override { return false; } - - bool useDefaultImplementationForNulls() const override { return false; } - - size_t getNumberOfArguments() const override { return 0; } - - DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override; - - ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override; - -private: - const ExternalModelsLoader & models_loader; -}; - -} diff --git a/src/Functions/modelEvaluate.cpp b/src/Functions/modelEvaluate.cpp new file mode 100644 index 00000000000..9b9c431981d --- /dev/null +++ b/src/Functions/modelEvaluate.cpp @@ -0,0 +1,177 @@ +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int ILLEGAL_TYPE_OF_ARGUMENT; + extern const int TOO_FEW_ARGUMENTS_FOR_FUNCTION; + extern const int ILLEGAL_COLUMN; +} + +class ExternalModelsLoader; + + +/// Evaluate external model. +/// First argument - model name, the others - model arguments. +/// * for CatBoost model - float features first, then categorical +/// Result - Float64. +class FunctionModelEvaluate final : public IFunction +{ +public: + static constexpr auto name = "modelEvaluate"; + + static FunctionPtr create(ContextPtr context) + { + return std::make_shared(context->getExternalModelsLoader()); + } + + explicit FunctionModelEvaluate(const ExternalModelsLoader & models_loader_) + : models_loader(models_loader_) {} + + String getName() const override { return name; } + + bool isVariadic() const override { return true; } + + bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; } + + bool isDeterministic() const override { return false; } + + bool useDefaultImplementationForNulls() const override { return false; } + + size_t getNumberOfArguments() const override { return 0; } + + + DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override + { + if (arguments.size() < 2) + throw Exception("Function " + getName() + " expects at least 2 arguments", + ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION); + + if (!isString(arguments[0].type)) + throw Exception("Illegal type " + arguments[0].type->getName() + " of first argument of function " + getName() + + ", expected a string.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + const auto * name_col = checkAndGetColumnConst(arguments[0].column.get()); + if (!name_col) + throw Exception("First argument of function " + getName() + " must be a constant string", + ErrorCodes::ILLEGAL_COLUMN); + + bool has_nullable = false; + for (size_t i = 1; i < arguments.size(); ++i) + has_nullable = has_nullable || arguments[i].type->isNullable(); + + auto model = models_loader.getModel(name_col->getValue()); + auto type = model->getReturnType(); + + if (has_nullable) + { + if (const auto * tuple = typeid_cast(type.get())) + { + auto elements = tuple->getElements(); + for (auto & element : elements) + element = makeNullable(element); + + type = std::make_shared(elements); + } + else + type = makeNullable(type); + } + + return type; + } + + + ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override + { + const auto * name_col = checkAndGetColumnConst(arguments[0].column.get()); + if (!name_col) + throw Exception("First argument of function " + getName() + " must be a constant string", + ErrorCodes::ILLEGAL_COLUMN); + + auto model = models_loader.getModel(name_col->getValue()); + + ColumnRawPtrs column_ptrs; + Columns materialized_columns; + ColumnPtr null_map; + + column_ptrs.reserve(arguments.size()); + for (auto arg : collections::range(1, arguments.size())) + { + const auto & column = arguments[arg].column; + column_ptrs.push_back(column.get()); + if (auto full_column = column->convertToFullColumnIfConst()) + { + materialized_columns.push_back(full_column); + column_ptrs.back() = full_column.get(); + } + if (const auto * col_nullable = checkAndGetColumn(*column_ptrs.back())) + { + if (!null_map) + null_map = col_nullable->getNullMapColumnPtr(); + else + { + auto mut_null_map = IColumn::mutate(std::move(null_map)); + + NullMap & result_null_map = assert_cast(*mut_null_map).getData(); + const NullMap & src_null_map = col_nullable->getNullMapColumn().getData(); + + for (size_t i = 0, size = result_null_map.size(); i < size; ++i) + if (src_null_map[i]) + result_null_map[i] = 1; + + null_map = std::move(mut_null_map); + } + + column_ptrs.back() = &col_nullable->getNestedColumn(); + } + } + + auto res = model->evaluate(column_ptrs); + + if (null_map) + { + if (const auto * tuple = typeid_cast(res.get())) + { + auto nested = tuple->getColumns(); + for (auto & col : nested) + col = ColumnNullable::create(col, null_map); + + res = ColumnTuple::create(nested); + } + else + res = ColumnNullable::create(res, null_map); + } + + return res; + } + +private: + const ExternalModelsLoader & models_loader; +}; + + +void registerFunctionsExternalModels(FunctionFactory & factory) +{ + factory.registerFunction(); +} + +} diff --git a/src/Interpreters/CatBoostModel.h b/src/Interpreters/CatBoostModel.h index 820c26a1fb4..eb599b43ef2 100644 --- a/src/Interpreters/CatBoostModel.h +++ b/src/Interpreters/CatBoostModel.h @@ -1,4 +1,5 @@ #pragma once + #include #include #include