2018-11-28 11:37:12 +00:00
|
|
|
#include "CatBoostModel.h"
|
|
|
|
|
2017-11-24 13:55:31 +00:00
|
|
|
#include <Common/FieldVisitors.h>
|
2017-10-06 18:05:30 +00:00
|
|
|
#include <mutex>
|
2017-10-09 20:13:44 +00:00
|
|
|
#include <Columns/ColumnString.h>
|
|
|
|
#include <Columns/ColumnFixedString.h>
|
|
|
|
#include <Columns/ColumnVector.h>
|
2018-12-26 16:44:57 +00:00
|
|
|
#include <Columns/ColumnTuple.h>
|
2017-10-09 20:13:44 +00:00
|
|
|
#include <Common/typeid_cast.h>
|
|
|
|
#include <IO/WriteBufferFromString.h>
|
|
|
|
#include <IO/Operators.h>
|
|
|
|
#include <Common/PODArray.h>
|
2017-10-26 19:00:27 +00:00
|
|
|
#include <Common/SharedLibrary.h>
|
2018-12-26 16:44:57 +00:00
|
|
|
#include <DataTypes/DataTypesNumber.h>
|
|
|
|
#include <DataTypes/DataTypeTuple.h>
|
2017-10-06 14:48:33 +00:00
|
|
|
|
|
|
|
namespace DB
|
|
|
|
{
|
|
|
|
|
2017-10-09 20:13:44 +00:00
|
|
|
/// Error codes thrown from this file; the definitions live elsewhere in the project.
namespace ErrorCodes
{
    extern const int BAD_ARGUMENTS;
    extern const int CANNOT_APPLY_CATBOOST_MODEL;
    extern const int CANNOT_LOAD_CATBOOST_MODEL;
    extern const int LOGICAL_ERROR;
}
|
|
|
|
|
2017-10-06 18:05:30 +00:00
|
|
|
|
2017-10-20 15:12:34 +00:00
|
|
|
/// CatBoost wrapper interface functions.
/// Table of function pointers resolved from the CatBoost wrapper shared library.
/// Every pointer defaults to nullptr, so a pointer that was never resolved compares equal to
/// nullptr instead of holding an indeterminate value (optional entry points such as
/// GetDimensionsCount are tested for null before use).
struct CatBoostWrapperAPI
{
    using ModelCalcerHandle = void;

    ModelCalcerHandle * (* ModelCalcerCreate)() = nullptr;

    void (* ModelCalcerDelete)(ModelCalcerHandle * calcer) = nullptr;

    const char * (* GetErrorString)() = nullptr;

    bool (* LoadFullModelFromFile)(ModelCalcerHandle * calcer, const char * filename) = nullptr;

    bool (* CalcModelPredictionFlat)(ModelCalcerHandle * calcer, size_t docCount,
                                     const float ** floatFeatures, size_t floatFeaturesSize,
                                     double * result, size_t resultSize) = nullptr;

    bool (* CalcModelPrediction)(ModelCalcerHandle * calcer, size_t docCount,
                                 const float ** floatFeatures, size_t floatFeaturesSize,
                                 const char *** catFeatures, size_t catFeaturesSize,
                                 double * result, size_t resultSize) = nullptr;

    bool (* CalcModelPredictionWithHashedCatFeatures)(ModelCalcerHandle * calcer, size_t docCount,
                                                      const float ** floatFeatures, size_t floatFeaturesSize,
                                                      const int ** catFeatures, size_t catFeaturesSize,
                                                      double * result, size_t resultSize) = nullptr;

    int (* GetStringCatFeatureHash)(const char * data, size_t size) = nullptr;
    int (* GetIntegerCatFeatureHash)(long long val) = nullptr;

    size_t (* GetFloatFeaturesCount)(ModelCalcerHandle* calcer) = nullptr;
    size_t (* GetCatFeaturesCount)(ModelCalcerHandle* calcer) = nullptr;
    size_t (* GetTreeCount)(ModelCalcerHandle* modelHandle) = nullptr;
    size_t (* GetDimensionsCount)(ModelCalcerHandle* modelHandle) = nullptr;

    bool (* CheckModelMetadataHasKey)(ModelCalcerHandle* modelHandle, const char* keyPtr, size_t keySize) = nullptr;
    size_t (*GetModelInfoValueSize)(ModelCalcerHandle* modelHandle, const char* keyPtr, size_t keySize) = nullptr;
    const char* (*GetModelInfoValue)(ModelCalcerHandle* modelHandle, const char* keyPtr, size_t keySize) = nullptr;
};
|
|
|
|
|
2017-10-09 20:13:44 +00:00
|
|
|
|
2017-10-17 10:44:46 +00:00
|
|
|
namespace
|
|
|
|
{
|
|
|
|
|
2017-10-09 20:13:44 +00:00
|
|
|
class CatBoostModelHolder
|
2017-10-06 18:05:30 +00:00
|
|
|
{
|
2017-10-09 20:13:44 +00:00
|
|
|
private:
|
2017-10-26 12:18:37 +00:00
|
|
|
CatBoostWrapperAPI::ModelCalcerHandle * handle;
|
|
|
|
const CatBoostWrapperAPI * api;
|
2017-10-06 18:05:30 +00:00
|
|
|
public:
|
2017-10-26 12:18:37 +00:00
|
|
|
explicit CatBoostModelHolder(const CatBoostWrapperAPI * api) : api(api) { handle = api->ModelCalcerCreate(); }
|
2017-10-09 20:13:44 +00:00
|
|
|
~CatBoostModelHolder() { api->ModelCalcerDelete(handle); }
|
|
|
|
|
2017-10-26 12:18:37 +00:00
|
|
|
CatBoostWrapperAPI::ModelCalcerHandle * get() { return handle; }
|
2017-10-09 20:13:44 +00:00
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
class CatBoostModelImpl : public ICatBoostModel
|
|
|
|
{
|
|
|
|
public:
|
2017-10-26 12:18:37 +00:00
|
|
|
CatBoostModelImpl(const CatBoostWrapperAPI * api, const std::string & model_path) : api(api)
|
2017-10-09 20:13:44 +00:00
|
|
|
{
|
|
|
|
auto handle_ = std::make_unique<CatBoostModelHolder>(api);
|
|
|
|
if (!handle_)
|
|
|
|
{
|
|
|
|
std::string msg = "Cannot create CatBoost model: ";
|
|
|
|
throw Exception(msg + api->GetErrorString(), ErrorCodes::CANNOT_LOAD_CATBOOST_MODEL);
|
|
|
|
}
|
2017-10-20 09:29:00 +00:00
|
|
|
if (!api->LoadFullModelFromFile(handle_->get(), model_path.c_str()))
|
2017-10-09 20:13:44 +00:00
|
|
|
{
|
|
|
|
std::string msg = "Cannot load CatBoost model: ";
|
|
|
|
throw Exception(msg + api->GetErrorString(), ErrorCodes::CANNOT_LOAD_CATBOOST_MODEL);
|
|
|
|
}
|
2017-10-31 11:18:09 +00:00
|
|
|
|
|
|
|
float_features_count = api->GetFloatFeaturesCount(handle_->get());
|
|
|
|
cat_features_count = api->GetCatFeaturesCount(handle_->get());
|
2018-12-26 16:44:57 +00:00
|
|
|
tree_count = 1;
|
|
|
|
if (api->GetDimensionsCount)
|
|
|
|
tree_count = api->GetDimensionsCount(handle_->get());
|
2017-10-31 11:18:09 +00:00
|
|
|
|
2017-10-09 20:13:44 +00:00
|
|
|
handle = std::move(handle_);
|
|
|
|
}
|
|
|
|
|
2017-12-13 01:27:53 +00:00
|
|
|
ColumnPtr evaluate(const ColumnRawPtrs & columns) const override
|
2017-10-09 20:13:44 +00:00
|
|
|
{
|
|
|
|
if (columns.empty())
|
|
|
|
throw Exception("Got empty columns list for CatBoost model.", ErrorCodes::BAD_ARGUMENTS);
|
|
|
|
|
|
|
|
if (columns.size() != float_features_count + cat_features_count)
|
|
|
|
{
|
|
|
|
std::string msg;
|
|
|
|
{
|
|
|
|
WriteBufferFromString buffer(msg);
|
|
|
|
buffer << "Number of columns is different with number of features: ";
|
|
|
|
buffer << columns.size() << " vs " << float_features_count << " + " << cat_features_count;
|
|
|
|
}
|
|
|
|
throw Exception(msg, ErrorCodes::BAD_ARGUMENTS);
|
|
|
|
}
|
|
|
|
|
|
|
|
for (size_t i = 0; i < float_features_count; ++i)
|
|
|
|
{
|
|
|
|
if (!columns[i]->isNumeric())
|
|
|
|
{
|
|
|
|
std::string msg;
|
|
|
|
{
|
|
|
|
WriteBufferFromString buffer(msg);
|
2019-02-12 23:49:32 +00:00
|
|
|
buffer << "Column " << i << " should be numeric to make float feature.";
|
2017-10-09 20:13:44 +00:00
|
|
|
}
|
|
|
|
throw Exception(msg, ErrorCodes::BAD_ARGUMENTS);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
bool cat_features_are_strings = true;
|
|
|
|
for (size_t i = float_features_count; i < float_features_count + cat_features_count; ++i)
|
|
|
|
{
|
2017-10-26 14:08:05 +00:00
|
|
|
auto column = columns[i];
|
2017-10-09 20:13:44 +00:00
|
|
|
if (column->isNumeric())
|
|
|
|
cat_features_are_strings = false;
|
2017-10-26 14:08:05 +00:00
|
|
|
else if (!(typeid_cast<const ColumnString *>(column)
|
|
|
|
|| typeid_cast<const ColumnFixedString *>(column)))
|
2017-10-09 20:13:44 +00:00
|
|
|
{
|
|
|
|
std::string msg;
|
|
|
|
{
|
|
|
|
WriteBufferFromString buffer(msg);
|
2019-02-12 23:49:32 +00:00
|
|
|
buffer << "Column " << i << " should be numeric or string.";
|
2017-10-09 20:13:44 +00:00
|
|
|
}
|
|
|
|
throw Exception(msg, ErrorCodes::BAD_ARGUMENTS);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2019-01-04 12:10:00 +00:00
|
|
|
auto result = evalImpl(columns, cat_features_are_strings);
|
2018-12-26 16:44:57 +00:00
|
|
|
|
|
|
|
if (tree_count == 1)
|
|
|
|
return result;
|
|
|
|
|
|
|
|
size_t column_size = columns.front()->size();
|
|
|
|
auto result_buf = result->getData().data();
|
|
|
|
|
|
|
|
/// Multiple trees case. Copy data to several columns.
|
|
|
|
MutableColumns mutable_columns(tree_count);
|
|
|
|
std::vector<Float64 *> column_ptrs(tree_count);
|
|
|
|
for (size_t i = 0; i < tree_count; ++i)
|
|
|
|
{
|
|
|
|
auto col = ColumnFloat64::create(column_size);
|
|
|
|
column_ptrs[i] = col->getData().data();
|
|
|
|
mutable_columns[i] = std::move(col);
|
|
|
|
}
|
|
|
|
|
|
|
|
Float64 * data = result_buf;
|
|
|
|
for (size_t row = 0; row < column_size; ++row)
|
|
|
|
{
|
|
|
|
for (size_t i = 0; i < tree_count; ++i)
|
|
|
|
{
|
|
|
|
*column_ptrs[i] = *data;
|
|
|
|
++column_ptrs[i];
|
|
|
|
++data;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return ColumnTuple::create(std::move(mutable_columns));
|
2017-10-09 20:13:44 +00:00
|
|
|
}
|
|
|
|
|
2017-10-31 11:18:09 +00:00
|
|
|
size_t getFloatFeaturesCount() const override { return float_features_count; }
|
|
|
|
size_t getCatFeaturesCount() const override { return cat_features_count; }
|
2018-12-26 16:44:57 +00:00
|
|
|
size_t getTreeCount() const override { return tree_count; }
|
2017-10-26 20:12:40 +00:00
|
|
|
|
2017-10-09 20:13:44 +00:00
|
|
|
private:
|
|
|
|
std::unique_ptr<CatBoostModelHolder> handle;
|
2017-10-26 12:18:37 +00:00
|
|
|
const CatBoostWrapperAPI * api;
|
2017-10-31 11:18:09 +00:00
|
|
|
size_t float_features_count;
|
|
|
|
size_t cat_features_count;
|
2018-12-26 16:44:57 +00:00
|
|
|
size_t tree_count;
|
2017-10-09 20:13:44 +00:00
|
|
|
|
|
|
|
/// Buffer should be allocated with features_count * column->size() elements.
|
|
|
|
/// Place column elements in positions buffer[0], buffer[features_count], ... , buffer[size * features_count]
|
|
|
|
template <typename T>
|
2017-10-26 14:08:05 +00:00
|
|
|
void placeColumnAsNumber(const IColumn * column, T * buffer, size_t features_count) const
|
2017-10-09 20:13:44 +00:00
|
|
|
{
|
|
|
|
size_t size = column->size();
|
|
|
|
FieldVisitorConvertToNumber<T> visitor;
|
|
|
|
for (size_t i = 0; i < size; ++i)
|
|
|
|
{
|
|
|
|
/// TODO: Replace with column visitor.
|
|
|
|
Field field;
|
|
|
|
column->get(i, field);
|
|
|
|
*buffer = applyVisitor(visitor, field);
|
|
|
|
buffer += features_count;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Buffer should be allocated with features_count * column->size() elements.
|
|
|
|
/// Place string pointers in positions buffer[0], buffer[features_count], ... , buffer[size * features_count]
|
2017-10-17 10:44:46 +00:00
|
|
|
void placeStringColumn(const ColumnString & column, const char ** buffer, size_t features_count) const
|
2017-10-09 20:13:44 +00:00
|
|
|
{
|
|
|
|
size_t size = column.size();
|
|
|
|
for (size_t i = 0; i < size; ++i)
|
|
|
|
{
|
|
|
|
*buffer = const_cast<char *>(column.getDataAtWithTerminatingZero(i).data);
|
|
|
|
buffer += features_count;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Buffer should be allocated with features_count * column->size() elements.
|
|
|
|
/// Place string pointers in positions buffer[0], buffer[features_count], ... , buffer[size * features_count]
|
|
|
|
/// Returns PODArray which holds data (because ColumnFixedString doesn't store terminating zero).
|
2017-10-17 10:44:46 +00:00
|
|
|
PODArray<char> placeFixedStringColumn(
|
|
|
|
const ColumnFixedString & column, const char ** buffer, size_t features_count) const
|
2017-10-09 20:13:44 +00:00
|
|
|
{
|
|
|
|
size_t size = column.size();
|
|
|
|
size_t str_size = column.getN();
|
|
|
|
PODArray<char> data(size * (str_size + 1));
|
|
|
|
char * data_ptr = data.data();
|
|
|
|
|
|
|
|
for (size_t i = 0; i < size; ++i)
|
|
|
|
{
|
|
|
|
auto ref = column.getDataAt(i);
|
|
|
|
memcpy(data_ptr, ref.data, ref.size);
|
|
|
|
data_ptr[ref.size] = 0;
|
|
|
|
*buffer = data_ptr;
|
|
|
|
data_ptr += ref.size + 1;
|
|
|
|
buffer += features_count;
|
|
|
|
}
|
|
|
|
|
|
|
|
return data;
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Place columns into buffer, returns column which holds placed data. Buffer should contains column->size() values.
|
|
|
|
template <typename T>
|
2017-12-13 01:27:53 +00:00
|
|
|
ColumnPtr placeNumericColumns(const ColumnRawPtrs & columns,
|
2017-10-26 14:08:05 +00:00
|
|
|
size_t offset, size_t size, const T** buffer) const
|
2017-10-09 20:13:44 +00:00
|
|
|
{
|
|
|
|
if (size == 0)
|
|
|
|
return nullptr;
|
|
|
|
size_t column_size = columns[offset]->size();
|
2017-12-14 01:43:19 +00:00
|
|
|
auto data_column = ColumnVector<T>::create(size * column_size);
|
2017-12-15 02:52:38 +00:00
|
|
|
T * data = data_column->getData().data();
|
2017-10-20 10:05:58 +00:00
|
|
|
for (size_t i = 0; i < size; ++i)
|
2017-10-09 20:13:44 +00:00
|
|
|
{
|
2017-10-26 14:08:05 +00:00
|
|
|
auto column = columns[offset + i];
|
2017-10-09 20:13:44 +00:00
|
|
|
if (column->isNumeric())
|
|
|
|
placeColumnAsNumber(column, data + i, size);
|
|
|
|
}
|
|
|
|
|
|
|
|
for (size_t i = 0; i < column_size; ++i)
|
|
|
|
{
|
|
|
|
*buffer = data;
|
|
|
|
++buffer;
|
|
|
|
data += size;
|
|
|
|
}
|
|
|
|
|
Get rid of useless std::move to get NRVO
http://eel.is/c++draft/class.copy.elision#:constructor,copy,elision
Some quote:
> Speaking of RVO, return std::move(w); prohibits it. It means "use move constructor or fail to compile", whereas return w; means "use RVO, and if you can't, use move constructor, and if you can't, use copy constructor, and if you can't, fail to compile."
There is one exception to this rule:
```cpp
Block FilterBlockInputStream::removeFilterIfNeed(Block && block)
{
if (block && remove_filter)
block.erase(static_cast<size_t>(filter_column));
return std::move(block);
}
```
because references are not eligible for NRVO, which is another rule "always move rvalue references and forward universal references" that takes precedence.
2018-08-27 14:04:22 +00:00
|
|
|
return data_column;
|
2017-10-09 20:13:44 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Place columns into buffer, returns data which was used for fixed string columns.
|
|
|
|
/// Buffer should contains column->size() values, each value contains size strings.
|
|
|
|
std::vector<PODArray<char>> placeStringColumns(
|
2017-12-13 01:27:53 +00:00
|
|
|
const ColumnRawPtrs & columns, size_t offset, size_t size, const char ** buffer) const
|
2017-10-09 20:13:44 +00:00
|
|
|
{
|
|
|
|
if (size == 0)
|
|
|
|
return {};
|
|
|
|
|
|
|
|
std::vector<PODArray<char>> data;
|
2017-10-20 10:05:58 +00:00
|
|
|
for (size_t i = 0; i < size; ++i)
|
2017-10-09 20:13:44 +00:00
|
|
|
{
|
2017-10-26 14:08:05 +00:00
|
|
|
auto column = columns[offset + i];
|
|
|
|
if (auto column_string = typeid_cast<const ColumnString *>(column))
|
2017-10-20 10:05:58 +00:00
|
|
|
placeStringColumn(*column_string, buffer + i, size);
|
2017-10-26 14:08:05 +00:00
|
|
|
else if (auto column_fixed_string = typeid_cast<const ColumnFixedString *>(column))
|
2017-10-20 10:05:58 +00:00
|
|
|
data.push_back(placeFixedStringColumn(*column_fixed_string, buffer + i, size));
|
2017-10-09 20:13:44 +00:00
|
|
|
else
|
|
|
|
throw Exception("Cannot place string column.", ErrorCodes::LOGICAL_ERROR);
|
|
|
|
}
|
|
|
|
|
|
|
|
return data;
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Calc hash for string cat feature at ps positions.
|
|
|
|
template <typename Column>
|
2017-10-20 10:05:58 +00:00
|
|
|
void calcStringHashes(const Column * column, size_t ps, const int ** buffer) const
|
2017-10-09 20:13:44 +00:00
|
|
|
{
|
|
|
|
size_t column_size = column->size();
|
|
|
|
for (size_t j = 0; j < column_size; ++j)
|
|
|
|
{
|
|
|
|
auto ref = column->getDataAt(j);
|
|
|
|
const_cast<int *>(*buffer)[ps] = api->GetStringCatFeatureHash(ref.data, ref.size);
|
2017-10-20 10:05:58 +00:00
|
|
|
++buffer;
|
2017-10-09 20:13:44 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Calc hash for int cat feature at ps position. Buffer at positions ps should contains unhashed values.
|
2017-10-20 10:05:58 +00:00
|
|
|
void calcIntHashes(size_t column_size, size_t ps, const int ** buffer) const
|
2017-10-09 20:13:44 +00:00
|
|
|
{
|
|
|
|
for (size_t j = 0; j < column_size; ++j)
|
|
|
|
{
|
|
|
|
const_cast<int *>(*buffer)[ps] = api->GetIntegerCatFeatureHash((*buffer)[ps]);
|
2017-10-20 10:05:58 +00:00
|
|
|
++buffer;
|
2017-10-09 20:13:44 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2017-10-20 10:05:58 +00:00
|
|
|
/// buffer contains column->size() rows and size columns.
|
|
|
|
/// For int cat features calc hash inplace.
|
|
|
|
/// For string cat features calc hash from column rows.
|
2017-12-13 01:27:53 +00:00
|
|
|
void calcHashes(const ColumnRawPtrs & columns, size_t offset, size_t size, const int ** buffer) const
|
2017-10-09 20:13:44 +00:00
|
|
|
{
|
|
|
|
if (size == 0)
|
|
|
|
return;
|
|
|
|
size_t column_size = columns[offset]->size();
|
|
|
|
|
|
|
|
std::vector<PODArray<char>> data;
|
2017-10-20 10:05:58 +00:00
|
|
|
for (size_t i = 0; i < size; ++i)
|
2017-10-09 20:13:44 +00:00
|
|
|
{
|
2017-10-26 14:08:05 +00:00
|
|
|
auto column = columns[offset + i];
|
|
|
|
if (auto column_string = typeid_cast<const ColumnString *>(column))
|
2017-10-20 10:05:58 +00:00
|
|
|
calcStringHashes(column_string, i, buffer);
|
2017-10-26 14:08:05 +00:00
|
|
|
else if (auto column_fixed_string = typeid_cast<const ColumnFixedString *>(column))
|
2017-10-20 10:05:58 +00:00
|
|
|
calcStringHashes(column_fixed_string, i, buffer);
|
2017-10-09 20:13:44 +00:00
|
|
|
else
|
2017-10-20 10:05:58 +00:00
|
|
|
calcIntHashes(column_size, i, buffer);
|
2017-10-09 20:13:44 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2017-10-20 10:05:58 +00:00
|
|
|
/// buffer[column_size * cat_features_count] -> char * => cat_features[column_size][cat_features_count] -> char *
|
2017-10-09 20:13:44 +00:00
|
|
|
void fillCatFeaturesBuffer(const char *** cat_features, const char ** buffer,
|
2019-01-04 12:10:00 +00:00
|
|
|
size_t column_size) const
|
2017-10-09 20:13:44 +00:00
|
|
|
{
|
|
|
|
for (size_t i = 0; i < column_size; ++i)
|
|
|
|
{
|
|
|
|
*cat_features = buffer;
|
|
|
|
++cat_features;
|
2019-01-04 12:10:00 +00:00
|
|
|
buffer += cat_features_count;
|
2017-10-09 20:13:44 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2017-10-20 15:12:34 +00:00
|
|
|
/// Convert values to row-oriented format and call evaluation function from CatBoost wrapper api.
|
|
|
|
/// * CalcModelPredictionFlat if no cat features
|
|
|
|
/// * CalcModelPrediction if all cat features are strings
|
|
|
|
/// * CalcModelPredictionWithHashedCatFeatures if has int cat features.
|
2018-12-26 16:44:57 +00:00
|
|
|
ColumnFloat64::MutablePtr evalImpl(
|
|
|
|
const ColumnRawPtrs & columns,
|
|
|
|
bool cat_features_are_strings) const
|
2017-10-09 20:13:44 +00:00
|
|
|
{
|
2017-10-20 15:12:34 +00:00
|
|
|
std::string error_msg = "Error occurred while applying CatBoost model: ";
|
2017-10-09 20:13:44 +00:00
|
|
|
size_t column_size = columns.front()->size();
|
|
|
|
|
2018-12-26 16:44:57 +00:00
|
|
|
auto result = ColumnFloat64::create(column_size * tree_count);
|
2017-10-09 20:13:44 +00:00
|
|
|
auto result_buf = result->getData().data();
|
|
|
|
|
2018-12-24 12:35:46 +00:00
|
|
|
if (!column_size)
|
|
|
|
return result;
|
|
|
|
|
2017-10-20 15:12:34 +00:00
|
|
|
/// Prepare float features.
|
|
|
|
PODArray<const float *> float_features(column_size);
|
|
|
|
auto float_features_buf = float_features.data();
|
|
|
|
/// Store all float data into single column. float_features is a list of pointers to it.
|
2019-01-04 12:10:00 +00:00
|
|
|
auto float_features_col = placeNumericColumns<float>(columns, 0, float_features_count, float_features_buf);
|
2017-10-09 20:13:44 +00:00
|
|
|
|
2019-01-04 12:10:00 +00:00
|
|
|
if (cat_features_count == 0)
|
2017-10-09 20:13:44 +00:00
|
|
|
{
|
2017-10-20 09:29:00 +00:00
|
|
|
if (!api->CalcModelPredictionFlat(handle->get(), column_size,
|
2019-01-04 12:10:00 +00:00
|
|
|
float_features_buf, float_features_count,
|
2018-12-26 16:44:57 +00:00
|
|
|
result_buf, column_size * tree_count))
|
2017-10-09 20:13:44 +00:00
|
|
|
{
|
|
|
|
|
|
|
|
throw Exception(error_msg + api->GetErrorString(), ErrorCodes::CANNOT_APPLY_CATBOOST_MODEL);
|
|
|
|
}
|
Get rid of useless std::move to get NRVO
http://eel.is/c++draft/class.copy.elision#:constructor,copy,elision
Some quote:
> Speaking of RVO, return std::move(w); prohibits it. It means "use move constructor or fail to compile", whereas return w; means "use RVO, and if you can't, use move constructor, and if you can't, use copy constructor, and if you can't, fail to compile."
There is one exception to this rule:
```cpp
Block FilterBlockInputStream::removeFilterIfNeed(Block && block)
{
if (block && remove_filter)
block.erase(static_cast<size_t>(filter_column));
return std::move(block);
}
```
because references are not eligible for NRVO, which is another rule "always move rvalue references and forward universal references" that takes precedence.
2018-08-27 14:04:22 +00:00
|
|
|
return result;
|
2017-10-09 20:13:44 +00:00
|
|
|
}
|
|
|
|
|
2017-10-20 15:12:34 +00:00
|
|
|
/// Prepare cat features.
|
2017-10-09 20:13:44 +00:00
|
|
|
if (cat_features_are_strings)
|
|
|
|
{
|
2017-10-20 15:12:34 +00:00
|
|
|
/// cat_features_holder stores pointers to ColumnString data or fixed_strings_data.
|
2019-01-04 12:10:00 +00:00
|
|
|
PODArray<const char *> cat_features_holder(cat_features_count * column_size);
|
2017-10-09 20:13:44 +00:00
|
|
|
PODArray<const char **> cat_features(column_size);
|
|
|
|
auto cat_features_buf = cat_features.data();
|
|
|
|
|
2019-01-04 12:10:00 +00:00
|
|
|
fillCatFeaturesBuffer(cat_features_buf, cat_features_holder.data(), column_size);
|
2017-10-26 14:08:05 +00:00
|
|
|
/// Fixed strings are stored without termination zero, so have to copy data into fixed_strings_data.
|
2019-01-04 12:10:00 +00:00
|
|
|
auto fixed_strings_data = placeStringColumns(columns, float_features_count,
|
|
|
|
cat_features_count, cat_features_holder.data());
|
2017-10-09 20:13:44 +00:00
|
|
|
|
2017-10-20 09:29:00 +00:00
|
|
|
if (!api->CalcModelPrediction(handle->get(), column_size,
|
2019-01-04 12:10:00 +00:00
|
|
|
float_features_buf, float_features_count,
|
|
|
|
cat_features_buf, cat_features_count,
|
2018-12-26 16:44:57 +00:00
|
|
|
result_buf, column_size * tree_count))
|
2017-10-09 20:13:44 +00:00
|
|
|
{
|
|
|
|
throw Exception(error_msg + api->GetErrorString(), ErrorCodes::CANNOT_APPLY_CATBOOST_MODEL);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
else
|
|
|
|
{
|
|
|
|
PODArray<const int *> cat_features(column_size);
|
|
|
|
auto cat_features_buf = cat_features.data();
|
2019-01-04 12:10:00 +00:00
|
|
|
auto cat_features_col = placeNumericColumns<int>(columns, float_features_count,
|
|
|
|
cat_features_count, cat_features_buf);
|
|
|
|
calcHashes(columns, float_features_count, cat_features_count, cat_features_buf);
|
2017-10-09 20:13:44 +00:00
|
|
|
if (!api->CalcModelPredictionWithHashedCatFeatures(
|
2017-10-20 09:29:00 +00:00
|
|
|
handle->get(), column_size,
|
2019-01-04 12:10:00 +00:00
|
|
|
float_features_buf, float_features_count,
|
|
|
|
cat_features_buf, cat_features_count,
|
2018-12-26 16:44:57 +00:00
|
|
|
result_buf, column_size * tree_count))
|
2017-10-09 20:13:44 +00:00
|
|
|
{
|
|
|
|
throw Exception(error_msg + api->GetErrorString(), ErrorCodes::CANNOT_APPLY_CATBOOST_MODEL);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
Get rid of useless std::move to get NRVO
http://eel.is/c++draft/class.copy.elision#:constructor,copy,elision
Some quote:
> Speaking of RVO, return std::move(w); prohibits it. It means "use move constructor or fail to compile", whereas return w; means "use RVO, and if you can't, use move constructor, and if you can't, use copy constructor, and if you can't, fail to compile."
There is one exception to this rule:
```cpp
Block FilterBlockInputStream::removeFilterIfNeed(Block && block)
{
if (block && remove_filter)
block.erase(static_cast<size_t>(filter_column));
return std::move(block);
}
```
because references are not eligible for NRVO, which is another rule "always move rvalue references and forward universal references" that takes precedence.
2018-08-27 14:04:22 +00:00
|
|
|
return result;
|
2017-10-09 20:13:44 +00:00
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
|
2017-10-20 15:12:34 +00:00
|
|
|
/// Holds CatBoost wrapper library and provides wrapper interface.
|
2017-10-26 12:18:37 +00:00
|
|
|
class CatBoostLibHolder: public CatBoostWrapperAPIProvider
|
2017-10-09 20:13:44 +00:00
|
|
|
{
|
|
|
|
public:
|
2017-10-31 11:18:09 +00:00
|
|
|
explicit CatBoostLibHolder(std::string lib_path_) : lib_path(std::move(lib_path_)), lib(lib_path) { initAPI(); }
|
2017-10-06 18:05:30 +00:00
|
|
|
|
2017-10-26 12:18:37 +00:00
|
|
|
const CatBoostWrapperAPI & getAPI() const override { return api; }
|
2017-10-06 18:05:30 +00:00
|
|
|
const std::string & getCurrentPath() const { return lib_path; }
|
|
|
|
|
|
|
|
private:
|
2017-10-26 12:18:37 +00:00
|
|
|
CatBoostWrapperAPI api;
|
2017-10-06 18:05:30 +00:00
|
|
|
std::string lib_path;
|
2017-10-26 19:00:27 +00:00
|
|
|
SharedLibrary lib;
|
2017-10-06 18:05:30 +00:00
|
|
|
|
2017-10-26 12:18:37 +00:00
|
|
|
void initAPI();
|
2017-10-06 18:05:30 +00:00
|
|
|
|
|
|
|
template <typename T>
|
2017-10-26 19:00:27 +00:00
|
|
|
void load(T& func, const std::string & name) { func = lib.get<T>(name); }
|
2018-12-26 16:44:57 +00:00
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
void tryLoad(T& func, const std::string & name) { func = lib.tryGet<T>(name); }
|
2017-10-06 18:05:30 +00:00
|
|
|
};
|
|
|
|
|
2017-10-26 12:18:37 +00:00
|
|
|
/// Resolves the CatBoost wrapper entry points from the shared library into `api`.
/// Symbols fetched with load() are required: lib.get presumably throws when a symbol is missing
/// (NOTE(review): confirm in SharedLibrary). Symbols fetched with tryLoad() are optional and are
/// expected to be left null when missing — the model code null-checks GetDimensionsCount before use.
void CatBoostLibHolder::initAPI()
{
    load(api.ModelCalcerCreate, "ModelCalcerCreate");
    load(api.ModelCalcerDelete, "ModelCalcerDelete");
    load(api.GetErrorString, "GetErrorString");
    load(api.LoadFullModelFromFile, "LoadFullModelFromFile");
    load(api.CalcModelPredictionFlat, "CalcModelPredictionFlat");
    load(api.CalcModelPrediction, "CalcModelPrediction");
    load(api.CalcModelPredictionWithHashedCatFeatures, "CalcModelPredictionWithHashedCatFeatures");
    load(api.GetStringCatFeatureHash, "GetStringCatFeatureHash");
    load(api.GetIntegerCatFeatureHash, "GetIntegerCatFeatureHash");
    load(api.GetFloatFeaturesCount, "GetFloatFeaturesCount");
    load(api.GetCatFeaturesCount, "GetCatFeaturesCount");
    /// Optional entry points.
    tryLoad(api.CheckModelMetadataHasKey, "CheckModelMetadataHasKey");
    tryLoad(api.GetModelInfoValueSize, "GetModelInfoValueSize");
    tryLoad(api.GetModelInfoValue, "GetModelInfoValue");
    tryLoad(api.GetTreeCount, "GetTreeCount");
    tryLoad(api.GetDimensionsCount, "GetDimensionsCount");
}
|
|
|
|
|
2017-10-09 20:13:44 +00:00
|
|
|
std::shared_ptr<CatBoostLibHolder> getCatBoostWrapperHolder(const std::string & lib_path)
|
2017-10-06 18:05:30 +00:00
|
|
|
{
|
2017-10-09 20:13:44 +00:00
|
|
|
static std::weak_ptr<CatBoostLibHolder> ptr;
|
2017-10-06 18:05:30 +00:00
|
|
|
static std::mutex mutex;
|
|
|
|
|
2019-01-02 06:44:36 +00:00
|
|
|
std::lock_guard lock(mutex);
|
2017-10-06 18:05:30 +00:00
|
|
|
auto result = ptr.lock();
|
|
|
|
|
|
|
|
if (!result || result->getCurrentPath() != lib_path)
|
|
|
|
{
|
2017-10-09 20:13:44 +00:00
|
|
|
result = std::make_shared<CatBoostLibHolder>(lib_path);
|
2017-10-06 18:05:30 +00:00
|
|
|
/// This assignment is not atomic, which prevents from creating lock only inside 'if'.
|
|
|
|
ptr = result;
|
|
|
|
}
|
|
|
|
|
|
|
|
return result;
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
2017-10-06 14:48:33 +00:00
|
|
|
|
2017-10-31 11:18:09 +00:00
|
|
|
/// Opens (or reuses) the wrapper library at lib_path, loads the model from model_path and caches the
/// model's feature and output counts.
CatBoostModel::CatBoostModel(std::string name_, std::string model_path_, std::string lib_path_,
                             const ExternalLoadableLifetime & lifetime)
    : name(std::move(name_)), model_path(std::move(model_path_)), lib_path(std::move(lib_path_)), lifetime(lifetime)
{
    api_provider = getCatBoostWrapperHolder(lib_path);
    api = &api_provider->getAPI();

    /// Build the implementation first, read its shape, then hand it over to `model`.
    auto impl = std::make_unique<CatBoostModelImpl>(api, model_path);
    float_features_count = impl->getFloatFeaturesCount();
    cat_features_count = impl->getCatFeaturesCount();
    tree_count = impl->getTreeCount();
    model = std::move(impl);
}
|
|
|
|
|
2017-10-06 14:48:33 +00:00
|
|
|
/// Lifetime settings passed at construction.
const ExternalLoadableLifetime & CatBoostModel::getLifetime() const
{
    return this->lifetime;
}
|
|
|
|
|
|
|
|
bool CatBoostModel::isModified() const
|
|
|
|
{
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
2019-06-02 12:11:01 +00:00
|
|
|
std::shared_ptr<const IExternalLoadable> CatBoostModel::clone() const
|
2017-10-06 14:48:33 +00:00
|
|
|
{
|
2019-06-02 12:11:01 +00:00
|
|
|
return std::make_shared<CatBoostModel>(name, model_path, lib_path, lifetime);
|
2017-10-06 14:48:33 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
size_t CatBoostModel::getFloatFeaturesCount() const
|
|
|
|
{
|
2017-10-09 20:13:44 +00:00
|
|
|
return float_features_count;
|
2017-10-06 14:48:33 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
size_t CatBoostModel::getCatFeaturesCount() const
|
|
|
|
{
|
2017-10-09 20:13:44 +00:00
|
|
|
return cat_features_count;
|
2017-10-06 14:48:33 +00:00
|
|
|
}
|
|
|
|
|
2018-12-26 16:44:57 +00:00
|
|
|
size_t CatBoostModel::getTreeCount() const
|
|
|
|
{
|
|
|
|
return tree_count;
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Float64 for single-output models; otherwise a Tuple of tree_count Float64 elements.
DataTypePtr CatBoostModel::getReturnType() const
{
    auto float64_type = std::make_shared<DataTypeFloat64>();
    if (tree_count != 1)
        return std::make_shared<DataTypeTuple>(DataTypes(tree_count, float64_type));
    return float64_type;
}
|
|
|
|
|
2017-12-13 01:27:53 +00:00
|
|
|
/// Delegates to the loaded model; throws LOGICAL_ERROR if no model is loaded.
ColumnPtr CatBoostModel::evaluate(const ColumnRawPtrs & columns) const
{
    if (model)
        return model->evaluate(columns);
    throw Exception("CatBoost model was not loaded.", ErrorCodes::LOGICAL_ERROR);
}
|
|
|
|
|
|
|
|
}
|