Merge pull request #1380 from yandex/catboost-models

Catboost models
This commit is contained in:
alexey-milovidov 2017-10-30 19:16:41 +03:00 committed by GitHub
commit 56ef2e9196
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
41 changed files with 2250 additions and 557 deletions

View File

@ -384,6 +384,8 @@ namespace ErrorCodes
extern const int UNKNOWN_STATUS_OF_DISTRIBUTED_DDL_TASK = 379;
extern const int CANNOT_KILL = 380;
extern const int HTTP_LENGTH_REQUIRED = 381;
extern const int CANNOT_LOAD_CATBOOST_MODEL = 382;
extern const int CANNOT_APPLY_CATBOOST_MODEL = 383;
extern const int KEEPER_EXCEPTION = 999;
extern const int POCO_EXCEPTION = 1000;

View File

@ -26,18 +26,19 @@ void DatabaseDictionary::loadTables(Context & context, ThreadPool * thread_pool,
Tables DatabaseDictionary::loadTables()
{
const std::lock_guard<std::mutex> lock_dictionaries {external_dictionaries.dictionaries_mutex};
auto objects_map = external_dictionaries.getObjectsMap();
const auto & dictionaries = objects_map.get();
Tables tables;
for (const auto & pair : external_dictionaries.dictionaries)
for (const auto & pair : dictionaries)
{
const std::string & name = pair.first;
if (deleted_tables.count(name))
continue;
auto dict_ptr = pair.second.dict;
auto dict_ptr = std::static_pointer_cast<IDictionaryBase>(pair.second.loadable);
if (dict_ptr)
{
const DictionaryStructure & dictionary_structure = dict_ptr->get()->getStructure();
const DictionaryStructure & dictionary_structure = dict_ptr->getStructure();
auto columns = StorageDictionary::getNamesAndTypes(dictionary_structure);
tables[name] = StorageDictionary::create(name, columns, {}, {}, {}, dictionary_structure, name);
}
@ -50,26 +51,28 @@ bool DatabaseDictionary::isTableExist(
const Context & context,
const String & table_name) const
{
const std::lock_guard<std::mutex> lock_dictionaries {external_dictionaries.dictionaries_mutex};
return external_dictionaries.dictionaries.count(table_name) && !deleted_tables.count(table_name);
auto objects_map = external_dictionaries.getObjectsMap();
const auto & dictionaries = objects_map.get();
return dictionaries.count(table_name) && !deleted_tables.count(table_name);
}
StoragePtr DatabaseDictionary::tryGetTable(
const Context & context,
const String & table_name)
{
const std::lock_guard<std::mutex> lock_dictionaries {external_dictionaries.dictionaries_mutex};
auto objects_map = external_dictionaries.getObjectsMap();
const auto & dictionaries = objects_map.get();
if (deleted_tables.count(table_name))
return {};
{
auto it = external_dictionaries.dictionaries.find(table_name);
if (it != external_dictionaries.dictionaries.end())
auto it = dictionaries.find(table_name);
if (it != dictionaries.end())
{
const auto & dict_ptr = it->second.dict;
const auto & dict_ptr = std::static_pointer_cast<IDictionaryBase>(it->second.loadable);
if (dict_ptr)
{
const DictionaryStructure & dictionary_structure = dict_ptr->get()->getStructure();
const DictionaryStructure & dictionary_structure = dict_ptr->getStructure();
auto columns = StorageDictionary::getNamesAndTypes(dictionary_structure);
return StorageDictionary::create(table_name, columns, {}, {}, {}, dictionary_structure, table_name);
}
@ -86,9 +89,10 @@ DatabaseIteratorPtr DatabaseDictionary::getIterator(const Context & context)
bool DatabaseDictionary::empty(const Context & context) const
{
const std::lock_guard<std::mutex> lock_dictionaries {external_dictionaries.dictionaries_mutex};
for (const auto & pair : external_dictionaries.dictionaries)
if (pair.second.dict && !deleted_tables.count(pair.first))
auto objects_map = external_dictionaries.getObjectsMap();
const auto & dictionaries = objects_map.get();
for (const auto & pair : dictionaries)
if (pair.second.loadable && !deleted_tables.count(pair.first))
return false;
return true;
}
@ -119,7 +123,7 @@ void DatabaseDictionary::removeTable(
if (!isTableExist(context, table_name))
throw Exception("Table " + name + "." + table_name + " doesn't exist.", ErrorCodes::UNKNOWN_TABLE);
const std::lock_guard<std::mutex> lock_dictionaries {external_dictionaries.dictionaries_mutex};
auto objects_map = external_dictionaries.getObjectsMap();
deleted_tables.insert(table_name);
}
@ -156,7 +160,6 @@ ASTPtr DatabaseDictionary::getCreateQuery(
const String & table_name) const
{
throw Exception("DatabaseDictionary: getCreateQuery() is not supported", ErrorCodes::NOT_IMPLEMENTED);
return nullptr;
}
void DatabaseDictionary::shutdown()

View File

@ -54,7 +54,7 @@ public:
bool isCached() const override { return true; }
DictionaryPtr clone() const override { return std::make_unique<CacheDictionary>(*this); }
std::unique_ptr<IExternalLoadable> clone() const override { return std::make_unique<CacheDictionary>(*this); }
const IDictionarySource * getSource() const override { return source_ptr.get(); }

View File

@ -0,0 +1,523 @@
#include <Dictionaries/CatBoostModel.h>
#include <Core/FieldVisitors.h>
#include <mutex>
#include <Columns/ColumnString.h>
#include <Columns/ColumnFixedString.h>
#include <Columns/ColumnVector.h>
#include <Common/typeid_cast.h>
#include <IO/WriteBufferFromString.h>
#include <IO/Operators.h>
#include <Common/PODArray.h>
#include <Common/SharedLibrary.h>
namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
extern const int BAD_ARGUMENTS;
extern const int CANNOT_LOAD_CATBOOST_MODEL;
extern const int CANNOT_APPLY_CATBOOST_MODEL;
}
/// CatBoost wrapper interface functions.
struct CatBoostWrapperAPI
{
typedef void ModelCalcerHandle;
ModelCalcerHandle * (* ModelCalcerCreate)();
void (* ModelCalcerDelete)(ModelCalcerHandle * calcer);
const char * (* GetErrorString)();
bool (* LoadFullModelFromFile)(ModelCalcerHandle * calcer, const char * filename);
bool (* CalcModelPredictionFlat)(ModelCalcerHandle * calcer, size_t docCount,
const float ** floatFeatures, size_t floatFeaturesSize,
double * result, size_t resultSize);
bool (* CalcModelPrediction)(ModelCalcerHandle * calcer, size_t docCount,
const float ** floatFeatures, size_t floatFeaturesSize,
const char *** catFeatures, size_t catFeaturesSize,
double * result, size_t resultSize);
bool (* CalcModelPredictionWithHashedCatFeatures)(ModelCalcerHandle * calcer, size_t docCount,
const float ** floatFeatures, size_t floatFeaturesSize,
const int ** catFeatures, size_t catFeaturesSize,
double * result, size_t resultSize);
int (* GetStringCatFeatureHash)(const char * data, size_t size);
int (* GetIntegerCatFeatureHash)(long long val);
size_t (* GetFloatFeaturesCount)(ModelCalcerHandle* calcer);
size_t (* GetCatFeaturesCount)(ModelCalcerHandle* calcer);
};
namespace
{
class CatBoostModelHolder
{
private:
CatBoostWrapperAPI::ModelCalcerHandle * handle;
const CatBoostWrapperAPI * api;
public:
explicit CatBoostModelHolder(const CatBoostWrapperAPI * api) : api(api) { handle = api->ModelCalcerCreate(); }
~CatBoostModelHolder() { api->ModelCalcerDelete(handle); }
CatBoostWrapperAPI::ModelCalcerHandle * get() { return handle; }
explicit operator CatBoostWrapperAPI::ModelCalcerHandle * () { return handle; }
};
class CatBoostModelImpl : public ICatBoostModel
{
public:
CatBoostModelImpl(const CatBoostWrapperAPI * api, const std::string & model_path) : api(api)
{
auto handle_ = std::make_unique<CatBoostModelHolder>(api);
if (!handle_)
{
std::string msg = "Cannot create CatBoost model: ";
throw Exception(msg + api->GetErrorString(), ErrorCodes::CANNOT_LOAD_CATBOOST_MODEL);
}
if (!api->LoadFullModelFromFile(handle_->get(), model_path.c_str()))
{
std::string msg = "Cannot load CatBoost model: ";
throw Exception(msg + api->GetErrorString(), ErrorCodes::CANNOT_LOAD_CATBOOST_MODEL);
}
handle = std::move(handle_);
}
ColumnPtr evaluate(const ConstColumnPlainPtrs & columns,
size_t float_features_count, size_t cat_features_count) const override
{
checkFeaturesCount(float_features_count, cat_features_count);
if (columns.empty())
throw Exception("Got empty columns list for CatBoost model.", ErrorCodes::BAD_ARGUMENTS);
if (columns.size() != float_features_count + cat_features_count)
{
std::string msg;
{
WriteBufferFromString buffer(msg);
buffer << "Number of columns is different with number of features: ";
buffer << columns.size() << " vs " << float_features_count << " + " << cat_features_count;
}
throw Exception(msg, ErrorCodes::BAD_ARGUMENTS);
}
for (size_t i = 0; i < float_features_count; ++i)
{
if (!columns[i]->isNumeric())
{
std::string msg;
{
WriteBufferFromString buffer(msg);
buffer << "Column " << i << "should be numeric to make float feature.";
}
throw Exception(msg, ErrorCodes::BAD_ARGUMENTS);
}
}
bool cat_features_are_strings = true;
for (size_t i = float_features_count; i < float_features_count + cat_features_count; ++i)
{
auto column = columns[i];
if (column->isNumeric())
cat_features_are_strings = false;
else if (!(typeid_cast<const ColumnString *>(column)
|| typeid_cast<const ColumnFixedString *>(column)))
{
std::string msg;
{
WriteBufferFromString buffer(msg);
buffer << "Column " << i << "should be numeric or string.";
}
throw Exception(msg, ErrorCodes::BAD_ARGUMENTS);
}
}
return evalImpl(columns, float_features_count, cat_features_count, cat_features_are_strings);
}
void checkFeaturesCount(size_t float_features_count, size_t cat_features_count) const
{
if (api->GetFloatFeaturesCount)
{
size_t float_features_in_model = api->GetFloatFeaturesCount(handle->get());
if (float_features_count != float_features_in_model)
throw Exception("CatBoost model expected " + std::to_string(float_features_in_model) + " float features"
+ ", but " + std::to_string(float_features_count) + " was provided.");
}
if (api->GetCatFeaturesCount)
{
size_t cat_features_in_model = api->GetCatFeaturesCount(handle->get());
if (cat_features_count != cat_features_in_model)
throw Exception("CatBoost model expected " + std::to_string(cat_features_in_model) + " cat features"
+ ", but " + std::to_string(cat_features_count) + " was provided.");
}
}
private:
std::unique_ptr<CatBoostModelHolder> handle;
const CatBoostWrapperAPI * api;
/// Buffer should be allocated with features_count * column->size() elements.
/// Place column elements in positions buffer[0], buffer[features_count], ... , buffer[size * features_count]
template <typename T>
void placeColumnAsNumber(const IColumn * column, T * buffer, size_t features_count) const
{
size_t size = column->size();
FieldVisitorConvertToNumber<T> visitor;
for (size_t i = 0; i < size; ++i)
{
/// TODO: Replace with column visitor.
Field field;
column->get(i, field);
*buffer = applyVisitor(visitor, field);
buffer += features_count;
}
}
/// Buffer should be allocated with features_count * column->size() elements.
/// Place string pointers in positions buffer[0], buffer[features_count], ... , buffer[size * features_count]
void placeStringColumn(const ColumnString & column, const char ** buffer, size_t features_count) const
{
size_t size = column.size();
for (size_t i = 0; i < size; ++i)
{
*buffer = const_cast<char *>(column.getDataAtWithTerminatingZero(i).data);
buffer += features_count;
}
}
/// Buffer should be allocated with features_count * column->size() elements.
/// Place string pointers in positions buffer[0], buffer[features_count], ... , buffer[size * features_count]
/// Returns PODArray which holds data (because ColumnFixedString doesn't store terminating zero).
PODArray<char> placeFixedStringColumn(
const ColumnFixedString & column, const char ** buffer, size_t features_count) const
{
size_t size = column.size();
size_t str_size = column.getN();
PODArray<char> data(size * (str_size + 1));
char * data_ptr = data.data();
for (size_t i = 0; i < size; ++i)
{
auto ref = column.getDataAt(i);
memcpy(data_ptr, ref.data, ref.size);
data_ptr[ref.size] = 0;
*buffer = data_ptr;
data_ptr += ref.size + 1;
buffer += features_count;
}
return data;
}
/// Place columns into buffer, returns column which holds placed data. Buffer should contains column->size() values.
template <typename T>
ColumnPtr placeNumericColumns(const ConstColumnPlainPtrs & columns,
size_t offset, size_t size, const T** buffer) const
{
if (size == 0)
return nullptr;
size_t column_size = columns[offset]->size();
auto data_column = std::make_shared<ColumnVector<T>>(size * column_size);
T* data = data_column->getData().data();
for (size_t i = 0; i < size; ++i)
{
auto column = columns[offset + i];
if (column->isNumeric())
placeColumnAsNumber(column, data + i, size);
}
for (size_t i = 0; i < column_size; ++i)
{
*buffer = data;
++buffer;
data += size;
}
return data_column;
}
/// Place columns into buffer, returns data which was used for fixed string columns.
/// Buffer should contains column->size() values, each value contains size strings.
std::vector<PODArray<char>> placeStringColumns(
const ConstColumnPlainPtrs & columns, size_t offset, size_t size, const char ** buffer) const
{
if (size == 0)
return {};
std::vector<PODArray<char>> data;
for (size_t i = 0; i < size; ++i)
{
auto column = columns[offset + i];
if (auto column_string = typeid_cast<const ColumnString *>(column))
placeStringColumn(*column_string, buffer + i, size);
else if (auto column_fixed_string = typeid_cast<const ColumnFixedString *>(column))
data.push_back(placeFixedStringColumn(*column_fixed_string, buffer + i, size));
else
throw Exception("Cannot place string column.", ErrorCodes::LOGICAL_ERROR);
}
return data;
}
/// Calc hash for string cat feature at ps positions.
template <typename Column>
void calcStringHashes(const Column * column, size_t ps, const int ** buffer) const
{
size_t column_size = column->size();
for (size_t j = 0; j < column_size; ++j)
{
auto ref = column->getDataAt(j);
const_cast<int *>(*buffer)[ps] = api->GetStringCatFeatureHash(ref.data, ref.size);
++buffer;
}
}
/// Calc hash for int cat feature at ps position. Buffer at positions ps should contains unhashed values.
void calcIntHashes(size_t column_size, size_t ps, const int ** buffer) const
{
for (size_t j = 0; j < column_size; ++j)
{
const_cast<int *>(*buffer)[ps] = api->GetIntegerCatFeatureHash((*buffer)[ps]);
++buffer;
}
}
/// buffer contains column->size() rows and size columns.
/// For int cat features calc hash inplace.
/// For string cat features calc hash from column rows.
void calcHashes(const ConstColumnPlainPtrs & columns, size_t offset, size_t size, const int ** buffer) const
{
if (size == 0)
return;
size_t column_size = columns[offset]->size();
std::vector<PODArray<char>> data;
for (size_t i = 0; i < size; ++i)
{
auto column = columns[offset + i];
if (auto column_string = typeid_cast<const ColumnString *>(column))
calcStringHashes(column_string, i, buffer);
else if (auto column_fixed_string = typeid_cast<const ColumnFixedString *>(column))
calcStringHashes(column_fixed_string, i, buffer);
else
calcIntHashes(column_size, i, buffer);
}
}
/// buffer[column_size * cat_features_count] -> char * => cat_features[column_size][cat_features_count] -> char *
void fillCatFeaturesBuffer(const char *** cat_features, const char ** buffer,
size_t column_size, size_t cat_features_count) const
{
for (size_t i = 0; i < column_size; ++i)
{
*cat_features = buffer;
++cat_features;
buffer += cat_features_count;
}
}
/// Convert values to row-oriented format and call evaluation function from CatBoost wrapper api.
/// * CalcModelPredictionFlat if no cat features
/// * CalcModelPrediction if all cat features are strings
/// * CalcModelPredictionWithHashedCatFeatures if has int cat features.
ColumnPtr evalImpl(const ConstColumnPlainPtrs & columns, size_t float_features_count, size_t cat_features_count,
bool cat_features_are_strings) const
{
std::string error_msg = "Error occurred while applying CatBoost model: ";
size_t column_size = columns.front()->size();
auto result= std::make_shared<ColumnFloat64>(column_size);
auto result_buf = result->getData().data();
/// Prepare float features.
PODArray<const float *> float_features(column_size);
auto float_features_buf = float_features.data();
/// Store all float data into single column. float_features is a list of pointers to it.
auto float_features_col = placeNumericColumns<float>(columns, 0, float_features_count, float_features_buf);
if (cat_features_count == 0)
{
if (!api->CalcModelPredictionFlat(handle->get(), column_size,
float_features_buf, float_features_count,
result_buf, column_size))
{
throw Exception(error_msg + api->GetErrorString(), ErrorCodes::CANNOT_APPLY_CATBOOST_MODEL);
}
return result;
}
/// Prepare cat features.
if (cat_features_are_strings)
{
/// cat_features_holder stores pointers to ColumnString data or fixed_strings_data.
PODArray<const char *> cat_features_holder(cat_features_count * column_size);
PODArray<const char **> cat_features(column_size);
auto cat_features_buf = cat_features.data();
fillCatFeaturesBuffer(cat_features_buf, cat_features_holder.data(), column_size, cat_features_count);
/// Fixed strings are stored without termination zero, so have to copy data into fixed_strings_data.
auto fixed_strings_data = placeStringColumns(columns, float_features_count,
cat_features_count, cat_features_holder.data());
if (!api->CalcModelPrediction(handle->get(), column_size,
float_features_buf, float_features_count,
cat_features_buf, cat_features_count,
result_buf, column_size))
{
throw Exception(error_msg + api->GetErrorString(), ErrorCodes::CANNOT_APPLY_CATBOOST_MODEL);
}
}
else
{
PODArray<const int *> cat_features(column_size);
auto cat_features_buf = cat_features.data();
auto cat_features_col = placeNumericColumns<int>(columns, float_features_count,
cat_features_count, cat_features_buf);
calcHashes(columns, float_features_count, cat_features_count, cat_features_buf);
if (!api->CalcModelPredictionWithHashedCatFeatures(
handle->get(), column_size,
float_features_buf, float_features_count,
cat_features_buf, cat_features_count,
result_buf, column_size))
{
throw Exception(error_msg + api->GetErrorString(), ErrorCodes::CANNOT_APPLY_CATBOOST_MODEL);
}
}
return result;
}
};
/// Holds CatBoost wrapper library and provides wrapper interface.
class CatBoostLibHolder: public CatBoostWrapperAPIProvider
{
public:
explicit CatBoostLibHolder(const std::string & lib_path) : lib_path(lib_path), lib(lib_path) { initAPI(); }
const CatBoostWrapperAPI & getAPI() const override { return api; }
const std::string & getCurrentPath() const { return lib_path; }
private:
CatBoostWrapperAPI api;
std::string lib_path;
SharedLibrary lib;
void initAPI();
template <typename T>
void load(T& func, const std::string & name) { func = lib.get<T>(name); }
template <typename T>
void tryLoad(T& func, const std::string & name) { func = lib.tryGet<T>(name); }
};
void CatBoostLibHolder::initAPI()
{
load(api.ModelCalcerCreate, "ModelCalcerCreate");
load(api.ModelCalcerDelete, "ModelCalcerDelete");
load(api.GetErrorString, "GetErrorString");
load(api.LoadFullModelFromFile, "LoadFullModelFromFile");
load(api.CalcModelPredictionFlat, "CalcModelPredictionFlat");
load(api.CalcModelPrediction, "CalcModelPrediction");
load(api.CalcModelPredictionWithHashedCatFeatures, "CalcModelPredictionWithHashedCatFeatures");
load(api.GetStringCatFeatureHash, "GetStringCatFeatureHash");
load(api.GetIntegerCatFeatureHash, "GetIntegerCatFeatureHash");
tryLoad(api.GetFloatFeaturesCount, "GetFloatFeaturesCount");
tryLoad(api.GetCatFeaturesCount, "GetCatFeaturesCount");
}
std::shared_ptr<CatBoostLibHolder> getCatBoostWrapperHolder(const std::string & lib_path)
{
static std::weak_ptr<CatBoostLibHolder> ptr;
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
auto result = ptr.lock();
if (!result || result->getCurrentPath() != lib_path)
{
result = std::make_shared<CatBoostLibHolder>(lib_path);
/// This assignment is not atomic, which prevents from creating lock only inside 'if'.
ptr = result;
}
return result;
}
}
CatBoostModel::CatBoostModel(const std::string & name, const std::string & model_path, const std::string & lib_path,
const ExternalLoadableLifetime & lifetime,
size_t float_features_count, size_t cat_features_count)
: name(name), model_path(model_path), lib_path(lib_path), lifetime(lifetime),
float_features_count(float_features_count), cat_features_count(cat_features_count)
{
try
{
init(lib_path);
}
catch (...)
{
creation_exception = std::current_exception();
}
}
void CatBoostModel::init(const std::string & lib_path)
{
api_provider = getCatBoostWrapperHolder(lib_path);
api = &api_provider->getAPI();
model = std::make_unique<CatBoostModelImpl>(api, model_path);
}
const ExternalLoadableLifetime & CatBoostModel::getLifetime() const
{
return lifetime;
}
bool CatBoostModel::isModified() const
{
return true;
}
std::unique_ptr<IExternalLoadable> CatBoostModel::clone() const
{
return std::make_unique<CatBoostModel>(name, model_path, lib_path, lifetime, float_features_count, cat_features_count);
}
size_t CatBoostModel::getFloatFeaturesCount() const
{
return float_features_count;
}
size_t CatBoostModel::getCatFeaturesCount() const
{
return cat_features_count;
}
ColumnPtr CatBoostModel::evaluate(const ConstColumnPlainPtrs & columns) const
{
if (!model)
throw Exception("CatBoost model was not loaded.", ErrorCodes::LOGICAL_ERROR);
return model->evaluate(columns, float_features_count, cat_features_count);
}
}

View File

@ -0,0 +1,79 @@
#pragma once
#include <Interpreters/IExternalLoadable.h>
#include <Columns/IColumn.h>
#include <Columns/ColumnsNumber.h>
namespace DB
{
/// CatBoost wrapper interface functions.
struct CatBoostWrapperAPI;
class CatBoostWrapperAPIProvider
{
public:
virtual ~CatBoostWrapperAPIProvider() = default;
virtual const CatBoostWrapperAPI & getAPI() const = 0;
};
/// CatBoost model interface.
class ICatBoostModel
{
public:
virtual ~ICatBoostModel() = default;
/// Evaluate model. Use first `float_features_count` columns as float features,
/// the others `cat_features_count` as categorical features.
virtual ColumnPtr evaluate(const ConstColumnPlainPtrs & columns, size_t float_features_count, size_t cat_features_count) const = 0;
};
/// General ML model evaluator interface.
class IModel : public IExternalLoadable
{
public:
virtual ColumnPtr evaluate(const ConstColumnPlainPtrs & columns) const = 0;
};
class CatBoostModel : public IModel
{
public:
CatBoostModel(const std::string & name, const std::string & model_path,
const std::string & lib_path, const ExternalLoadableLifetime & lifetime,
size_t float_features_count, size_t cat_features_count);
ColumnPtr evaluate(const ConstColumnPlainPtrs & columns) const override;
size_t getFloatFeaturesCount() const;
size_t getCatFeaturesCount() const;
/// IExternalLoadable interface.
const ExternalLoadableLifetime & getLifetime() const override;
std::string getName() const override { return name; }
bool supportUpdates() const override { return true; }
bool isModified() const override;
std::unique_ptr<IExternalLoadable> clone() const override;
std::exception_ptr getCreationException() const override { return creation_exception; }
private:
std::string name;
std::string model_path;
std::string lib_path;
ExternalLoadableLifetime lifetime;
std::exception_ptr creation_exception;
std::shared_ptr<CatBoostWrapperAPIProvider> api_provider;
const CatBoostWrapperAPI * api;
std::unique_ptr<ICatBoostModel> model;
size_t float_features_count;
size_t cat_features_count;
void init(const std::string & lib_path);
};
}

View File

@ -98,7 +98,7 @@ public:
return true;
}
DictionaryPtr clone() const override
std::unique_ptr<IExternalLoadable> clone() const override
{
return std::make_unique<ComplexKeyCacheDictionary>(*this);
}

View File

@ -46,7 +46,7 @@ public:
bool isCached() const override { return false; }
DictionaryPtr clone() const override { return std::make_unique<ComplexKeyHashedDictionary>(*this); }
std::unique_ptr<IExternalLoadable> clone() const override { return std::make_unique<ComplexKeyHashedDictionary>(*this); }
const IDictionarySource * getSource() const override { return source_ptr.get(); }

View File

@ -13,7 +13,7 @@ class Context;
class DictionaryFactory : public ext::singleton<DictionaryFactory>
{
public:
DictionaryPtr create(const std::string & name, Poco::Util::AbstractConfiguration & config,
DictionaryPtr create(const std::string & name, const Poco::Util::AbstractConfiguration & config,
const std::string & config_prefix, Context & context) const;
};

View File

@ -90,7 +90,7 @@ DictionarySourceFactory::DictionarySourceFactory()
DictionarySourcePtr DictionarySourceFactory::create(
const std::string & name, Poco::Util::AbstractConfiguration & config, const std::string & config_prefix,
const std::string & name, const Poco::Util::AbstractConfiguration & config, const std::string & config_prefix,
const DictionaryStructure & dict_struct, Context & context) const
{
Poco::Util::AbstractConfiguration::Keys keys;

View File

@ -25,7 +25,7 @@ public:
DictionarySourceFactory();
DictionarySourcePtr create(
const std::string & name, Poco::Util::AbstractConfiguration & config, const std::string & config_prefix,
const std::string & name, const Poco::Util::AbstractConfiguration & config, const std::string & config_prefix,
const DictionaryStructure & dict_struct, Context & context) const;
};

View File

@ -105,16 +105,6 @@ std::string toString(const AttributeUnderlyingType type)
}
DictionaryLifetime::DictionaryLifetime(const Poco::Util::AbstractConfiguration & config, const std::string & config_prefix)
{
const auto & lifetime_min_key = config_prefix + ".min";
const auto has_min = config.has(lifetime_min_key);
this->min_sec = has_min ? config.getInt(lifetime_min_key) : config.getInt(config_prefix);
this->max_sec = has_min ? config.getInt(config_prefix + ".max") : this->min_sec;
}
DictionarySpecialAttribute::DictionarySpecialAttribute(const Poco::Util::AbstractConfiguration & config, const std::string & config_prefix)
: name{config.getString(config_prefix + ".name", "")},
expression{config.getString(config_prefix + ".expression", "")}

View File

@ -4,6 +4,7 @@
#include <IO/ReadBufferFromString.h>
#include <IO/WriteBuffer.h>
#include <IO/WriteHelpers.h>
#include <Interpreters/IExternalLoadable.h>
#include <Poco/Util/AbstractConfiguration.h>
#include <ext/range.h>
#include <numeric>
@ -42,13 +43,7 @@ std::string toString(const AttributeUnderlyingType type);
/// Min and max lifetimes for a dictionary or it's entry
struct DictionaryLifetime final
{
UInt64 min_sec;
UInt64 max_sec;
DictionaryLifetime(const Poco::Util::AbstractConfiguration & config, const std::string & config_prefix);
};
using DictionaryLifetime = ExternalLoadableLifetime;
/** Holds the description of a single dictionary attribute:

View File

@ -41,7 +41,7 @@ public:
bool isCached() const override { return false; }
DictionaryPtr clone() const override { return std::make_unique<FlatDictionary>(*this); }
std::unique_ptr<IExternalLoadable> clone() const override { return std::make_unique<FlatDictionary>(*this); }
const IDictionarySource * getSource() const override { return source_ptr.get(); }

View File

@ -40,7 +40,7 @@ public:
bool isCached() const override { return false; }
DictionaryPtr clone() const override { return std::make_unique<HashedDictionary>(*this); }
std::unique_ptr<IExternalLoadable> clone() const override { return std::make_unique<HashedDictionary>(*this); }
const IDictionarySource * getSource() const override { return source_ptr.get(); }

View File

@ -1,22 +1,21 @@
#pragma once
#include <Core/Field.h>
#include <Interpreters/IExternalLoadable.h>
#include <common/StringRef.h>
#include <Core/Names.h>
#include <Poco/Util/XMLConfiguration.h>
#include <Common/PODArray.h>
#include <memory>
#include <chrono>
#include <Dictionaries/IDictionarySource.h>
namespace DB
{
class IDictionarySource;
struct IDictionaryBase;
using DictionaryPtr = std::unique_ptr<IDictionaryBase>;
struct DictionaryLifetime;
struct DictionaryStructure;
class ColumnString;
@ -24,14 +23,10 @@ class IBlockInputStream;
using BlockInputStreamPtr = std::shared_ptr<IBlockInputStream>;
struct IDictionaryBase : public std::enable_shared_from_this<IDictionaryBase>
struct IDictionaryBase : public IExternalLoadable
{
using Key = UInt64;
virtual std::exception_ptr getCreationException() const = 0;
virtual std::string getName() const = 0;
virtual std::string getTypeName() const = 0;
virtual size_t getBytesAllocated() const = 0;
@ -45,12 +40,9 @@ struct IDictionaryBase : public std::enable_shared_from_this<IDictionaryBase>
virtual double getLoadFactor() const = 0;
virtual bool isCached() const = 0;
virtual DictionaryPtr clone() const = 0;
virtual const IDictionarySource * getSource() const = 0;
virtual const DictionaryLifetime & getLifetime() const = 0;
virtual const DictionaryStructure & getStructure() const = 0;
virtual std::chrono::time_point<std::chrono::system_clock> getCreationTime() const = 0;
@ -59,7 +51,23 @@ struct IDictionaryBase : public std::enable_shared_from_this<IDictionaryBase>
virtual BlockInputStreamPtr getBlockInputStream(const Names & column_names, size_t max_block_size) const = 0;
virtual ~IDictionaryBase() = default;
bool supportUpdates() const override { return !isCached(); }
bool isModified() const override
{
auto source = getSource();
return source && source->isModified();
}
std::shared_ptr<IDictionaryBase> shared_from_this()
{
return std::static_pointer_cast<IDictionaryBase>(IExternalLoadable::shared_from_this());
}
std::shared_ptr<const IDictionaryBase> shared_from_this() const
{
return std::static_pointer_cast<const IDictionaryBase>(IExternalLoadable::shared_from_this());
}
};

View File

@ -41,7 +41,7 @@ public:
bool isCached() const override { return false; }
DictionaryPtr clone() const override { return std::make_unique<RangeHashedDictionary>(*this); }
std::unique_ptr<IExternalLoadable> clone() const override { return std::make_unique<RangeHashedDictionary>(*this); }
const IDictionarySource * getSource() const override { return source_ptr.get(); }

View File

@ -50,7 +50,7 @@ public:
bool isCached() const override { return false; }
DictionaryPtr clone() const override { return std::make_unique<TrieDictionary>(*this); }
std::unique_ptr<IExternalLoadable> clone() const override { return std::make_unique<TrieDictionary>(*this); }
const IDictionarySource * getSource() const override { return source_ptr.get(); }

View File

@ -0,0 +1,62 @@
#include <Functions/FunctionsExternalModels.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/FunctionFactory.h>
#include <Interpreters/Context.h>
#include <Interpreters/ExternalModels.h>
#include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypesNumber.h>
#include <Columns/ColumnString.h>
#include <ext/range.h>
namespace DB
{
FunctionPtr FunctionModelEvaluate::create(const Context & context)
{
return std::make_shared<FunctionModelEvaluate>(context.getExternalModels());
}
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int TOO_LESS_ARGUMENTS_FOR_FUNCTION;
extern const int ILLEGAL_COLUMN;
}
DataTypePtr FunctionModelEvaluate::getReturnTypeImpl(const DataTypes & arguments) const
{
if (arguments.size() < 2)
throw Exception("Function " + getName() + " expects at least 2 arguments",
ErrorCodes::TOO_LESS_ARGUMENTS_FOR_FUNCTION);
if (!checkDataType<DataTypeString>(arguments[0].get()))
throw Exception("Illegal type " + arguments[0]->getName() + " of first argument of function " + getName()
+ ", expected a string.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return std::make_shared<DataTypeFloat64>();
}
void FunctionModelEvaluate::executeImpl(Block & block, const ColumnNumbers & arguments, size_t result)
{
const auto name_col = checkAndGetColumnConst<ColumnString>(block.getByPosition(arguments[0]).column.get());
if (!name_col)
throw Exception("First argument of function " + getName() + " must be a constant string",
ErrorCodes::ILLEGAL_COLUMN);
auto model = models.getModel(name_col->getValue<String>());
ConstColumnPlainPtrs columns;
columns.reserve(arguments.size());
for (auto i : ext::range(1, arguments.size()))
columns.push_back(block.getByPosition(arguments[i]).column.get());
block.getByPosition(result).column = model->evaluate(columns);
}
void registerFunctionsExternalModels(FunctionFactory & factory)
{
factory.registerFunction<FunctionModelEvaluate>();
}
}

View File

@ -0,0 +1,36 @@
#pragma once
#include <Functions/IFunction.h>
namespace DB
{
class ExternalModels;
/// Evaluate external model.
/// First argument - model name, the others - model arguments.
/// * for CatBoost model - float features first, then categorical
/// Result - Float64.
class FunctionModelEvaluate final : public IFunction
{
public:
static constexpr auto name = "modelEvaluate";
static FunctionPtr create(const Context & context);
explicit FunctionModelEvaluate(const ExternalModels & models) : models(models) {}
String getName() const override { return name; }
bool isVariadic() const override { return true; }
size_t getNumberOfArguments() const override { return 0; }
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override;
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) override;
private:
const ExternalModels & models;
};
}

View File

@ -20,6 +20,7 @@ void registerFunctionsConversion(FunctionFactory &);
void registerFunctionsDateTime(FunctionFactory &);
void registerFunctionsEmbeddedDictionaries(FunctionFactory &);
void registerFunctionsExternalDictionaries(FunctionFactory &);
void registerFunctionsExternalModels(FunctionFactory &);
void registerFunctionsFormatting(FunctionFactory &);
void registerFunctionsHashing(FunctionFactory &);
void registerFunctionsHigherOrder(FunctionFactory &);
@ -54,6 +55,7 @@ void registerFunctions()
registerFunctionsDateTime(factory);
registerFunctionsEmbeddedDictionaries(factory);
registerFunctionsExternalDictionaries(factory);
registerFunctionsExternalModels(factory);
registerFunctionsFormatting(factory);
registerFunctionsHashing(factory);
registerFunctionsHigherOrder(factory);

View File

@ -29,6 +29,7 @@
#include <Interpreters/Quota.h>
#include <Interpreters/EmbeddedDictionaries.h>
#include <Interpreters/ExternalDictionaries.h>
#include <Interpreters/ExternalModels.h>
#include <Interpreters/ProcessList.h>
#include <Interpreters/Cluster.h>
#include <Interpreters/InterserverIOHandler.h>
@ -94,6 +95,7 @@ struct ContextShared
/// Separate mutex for access of dictionaries. Separate mutex to avoid locks when server doing request to itself.
mutable std::mutex embedded_dictionaries_mutex;
mutable std::mutex external_dictionaries_mutex;
mutable std::mutex external_models_mutex;
/// Separate mutex for re-initialization of zookeer session. This operation could take a long time and must not interfere with another operations.
mutable std::mutex zookeeper_mutex;
@ -111,6 +113,7 @@ struct ContextShared
FormatFactory format_factory; /// Formats.
mutable std::shared_ptr<EmbeddedDictionaries> embedded_dictionaries; /// Metrica's dictionaeis. Have lazy initialization.
mutable std::shared_ptr<ExternalDictionaries> external_dictionaries;
mutable std::shared_ptr<ExternalModels> external_models;
String default_profile_name; /// Default profile name used for default values.
Users users; /// Known users.
Quotas quotas; /// Known quotas for resource use.
@ -1062,6 +1065,17 @@ ExternalDictionaries & Context::getExternalDictionaries()
}
const ExternalModels & Context::getExternalModels() const
{
return getExternalModelsImpl(false);
}
ExternalModels & Context::getExternalModels()
{
return getExternalModelsImpl(false);
}
EmbeddedDictionaries & Context::getEmbeddedDictionariesImpl(const bool throw_on_error) const
{
std::lock_guard<std::mutex> lock(shared->embedded_dictionaries_mutex);
@ -1087,6 +1101,19 @@ ExternalDictionaries & Context::getExternalDictionariesImpl(const bool throw_on_
return *shared->external_dictionaries;
}
ExternalModels & Context::getExternalModelsImpl(bool throw_on_error) const
{
std::lock_guard<std::mutex> lock(shared->external_models_mutex);
if (!shared->external_models)
{
if (!this->global_context)
throw Exception("Logical error: there is no global context", ErrorCodes::LOGICAL_ERROR);
shared->external_models = std::make_shared<ExternalModels>(*this->global_context, throw_on_error);
}
return *shared->external_models;
}
void Context::tryCreateEmbeddedDictionaries() const
{
@ -1100,6 +1127,12 @@ void Context::tryCreateExternalDictionaries() const
}
void Context::tryCreateExternalModels() const
{
static_cast<void>(getExternalModelsImpl(true));
}
void Context::setProgressCallback(ProgressCallback callback)
{
/// Callback is set to a session or to a query. In the session, only one query is processed at a time. Therefore, the lock is not needed.

View File

@ -35,6 +35,7 @@ struct ContextShared;
class QuotaForIntervals;
class EmbeddedDictionaries;
class ExternalDictionaries;
class ExternalModels;
class InterserverIOHandler;
class BackgroundProcessingPool;
class ReshardingWorker;
@ -209,10 +210,13 @@ public:
const EmbeddedDictionaries & getEmbeddedDictionaries() const;
const ExternalDictionaries & getExternalDictionaries() const;
const ExternalModels & getExternalModels() const;
EmbeddedDictionaries & getEmbeddedDictionaries();
ExternalDictionaries & getExternalDictionaries();
ExternalModels & getExternalModels();
void tryCreateEmbeddedDictionaries() const;
void tryCreateExternalDictionaries() const;
void tryCreateExternalModels() const;
/// I/O formats.
BlockInputStreamPtr getInputFormat(const String & name, ReadBuffer & buf, const Block & sample, size_t max_block_size) const;
@ -362,6 +366,7 @@ private:
EmbeddedDictionaries & getEmbeddedDictionariesImpl(bool throw_on_error) const;
ExternalDictionaries & getExternalDictionariesImpl(bool throw_on_error) const;
ExternalModels & getExternalModelsImpl(bool throw_on_error) const;
StoragePtr getTableImpl(const String & database_name, const String & table_name, Exception * exception) const;

View File

@ -24,7 +24,7 @@ namespace ErrorCodes
}
DictionaryPtr DictionaryFactory::create(const std::string & name, Poco::Util::AbstractConfiguration & config,
DictionaryPtr DictionaryFactory::create(const std::string & name, const Poco::Util::AbstractConfiguration & config,
const std::string & config_prefix, Context & context) const
{
Poco::Util::AbstractConfiguration::Keys keys;

View File

@ -1,427 +1,46 @@
#include <Interpreters/ExternalDictionaries.h>
#include <Interpreters/Context.h>
#include <Dictionaries/DictionaryFactory.h>
#include <Dictionaries/DictionaryStructure.h>
#include <Dictionaries/IDictionarySource.h>
#include <Common/StringUtils.h>
#include <Common/MemoryTracker.h>
#include <Common/getMultipleKeysFromConfig.h>
#include <ext/scope_guard.h>
#include <Poco/Util/Application.h>
#include <Poco/Glob.h>
#include <Poco/File.h>
namespace
{
const auto check_period_sec = 5;
const auto backoff_initial_sec = 5;
/// 10 minutes
const auto backoff_max_sec = 10 * 60;
}
namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
extern const int BAD_ARGUMENTS;
}
void ExternalDictionaries::reloadPeriodically()
{
setThreadName("ExterDictReload");
while (true)
{
if (destroy.tryWait(check_period_sec * 1000))
return;
reloadAndUpdate();
}
}
ExternalDictionaries::ExternalDictionaries(Context & context, const bool throw_on_error)
: context(context), log(&Logger::get("ExternalDictionaries"))
{
{
/** During synchronous loading of external dictionaries at moment of query execution,
* we should not use per query memory limit.
*/
TemporarilyDisableMemoryTracker temporarily_disable_memory_tracker;
reloadAndUpdate(throw_on_error);
}
reloading_thread = std::thread{&ExternalDictionaries::reloadPeriodically, this};
}
ExternalDictionaries::~ExternalDictionaries()
{
destroy.set();
reloading_thread.join();
}
namespace
{
std::set<std::string> getDictionariesConfigPaths(const Poco::Util::AbstractConfiguration & config)
{
std::set<std::string> files;
auto patterns = getMultipleValuesFromConfig(config, "", "dictionaries_config");
for (auto & pattern : patterns)
const ExternalLoaderUpdateSettings externalDictionariesUpdateSettings;
const ExternalLoaderConfigSettings & getExternalDictionariesConfigSettings()
{
if (pattern.empty())
continue;
static ExternalLoaderConfigSettings settings;
static std::once_flag flag;
if (pattern[0] != '/')
{
const auto app_config_path = config.getString("config-file", "config.xml");
const auto config_dir = Poco::Path{app_config_path}.parent().toString();
const auto absolute_path = config_dir + pattern;
Poco::Glob::glob(absolute_path, files, 0);
if (!files.empty())
continue;
}
std::call_once(flag, [] {
settings.external_config = "dictionary";
settings.external_name = "name";
Poco::Glob::glob(pattern, files, 0);
}
settings.path_setting_name = "dictionaries_config";
});
return files;
}
}
void ExternalDictionaries::reloadAndUpdate(bool throw_on_error)
{
reloadFromConfigFiles(throw_on_error);
/// list of recreated dictionaries to perform delayed removal from unordered_map
std::list<std::string> recreated_failed_dictionaries;
std::unique_lock<std::mutex> all_lock(all_mutex);
/// retry loading failed dictionaries
for (auto & failed_dictionary : failed_dictionaries)
{
if (std::chrono::system_clock::now() < failed_dictionary.second.next_attempt_time)
continue;
const auto & name = failed_dictionary.first;
try
{
auto dict_ptr = failed_dictionary.second.dict->clone();
if (const auto exception_ptr = dict_ptr->getCreationException())
{
/// recalculate next attempt time
std::uniform_int_distribution<UInt64> distribution(
0, static_cast<UInt64>(std::exp2(failed_dictionary.second.error_count)));
failed_dictionary.second.next_attempt_time = std::chrono::system_clock::now() +
std::chrono::seconds{
std::min<UInt64>(backoff_max_sec, backoff_initial_sec + distribution(rnd_engine))};
++failed_dictionary.second.error_count;
std::rethrow_exception(exception_ptr);
}
else
{
const std::lock_guard<std::mutex> lock{dictionaries_mutex};
const auto & lifetime = dict_ptr->getLifetime();
std::uniform_int_distribution<UInt64> distribution{lifetime.min_sec, lifetime.max_sec};
update_times[name] = std::chrono::system_clock::now() + std::chrono::seconds{distribution(rnd_engine)};
const auto dict_it = dictionaries.find(name);
if (dict_it->second.dict)
dict_it->second.dict->set(dict_ptr.release());
else
dict_it->second.dict = std::make_shared<MultiVersion<IDictionaryBase>>(dict_ptr.release());
/// clear stored exception on success
dict_it->second.exception = std::exception_ptr{};
recreated_failed_dictionaries.push_back(name);
}
}
catch (...)
{
tryLogCurrentException(log, "Failed reloading '" + name + "' dictionary");
if (throw_on_error)
throw;
}
}
/// do not undertake further attempts to recreate these dictionaries
for (const auto & name : recreated_failed_dictionaries)
failed_dictionaries.erase(name);
/// periodic update
for (auto & dictionary : dictionaries)
{
const auto & name = dictionary.first;
try
{
/// If the dictionary failed to load or even failed to initialize from the config.
if (!dictionary.second.dict)
continue;
auto current = dictionary.second.dict->get();
const auto & lifetime = current->getLifetime();
/// do not update dictionaries with zero as lifetime
if (lifetime.min_sec == 0 || lifetime.max_sec == 0)
continue;
/// update only non-cached dictionaries
if (!current->isCached())
{
auto & update_time = update_times[current->getName()];
/// check that timeout has passed
if (std::chrono::system_clock::now() < update_time)
continue;
SCOPE_EXIT({
/// calculate next update time
std::uniform_int_distribution<UInt64> distribution{lifetime.min_sec, lifetime.max_sec};
update_time = std::chrono::system_clock::now() + std::chrono::seconds{distribution(rnd_engine)};
});
/// check source modified
if (current->getSource()->isModified())
{
/// create new version of dictionary
auto new_version = current->clone();
if (const auto exception_ptr = new_version->getCreationException())
std::rethrow_exception(exception_ptr);
dictionary.second.dict->set(new_version.release());
}
}
/// erase stored exception on success
dictionary.second.exception = std::exception_ptr{};
}
catch (...)
{
dictionary.second.exception = std::current_exception();
tryLogCurrentException(log, "Cannot update external dictionary '" + name + "', leaving old version");
if (throw_on_error)
throw;
}
return settings;
}
}
void ExternalDictionaries::reloadFromConfigFiles(const bool throw_on_error, const bool force_reload, const std::string & only_dictionary)
ExternalDictionaries::ExternalDictionaries(Context & context, bool throw_on_error)
: ExternalLoader(context.getConfigRef(),
externalDictionariesUpdateSettings,
getExternalDictionariesConfigSettings(),
&Logger::get("ExternalDictionaries"),
"external dictionary"),
context(context)
{
const auto config_paths = getDictionariesConfigPaths(context.getConfigRef());
for (const auto & config_path : config_paths)
{
try
{
reloadFromConfigFile(config_path, throw_on_error, force_reload, only_dictionary);
}
catch (...)
{
tryLogCurrentException(log, "reloadFromConfigFile has thrown while reading from " + config_path);
if (throw_on_error)
throw;
}
}
init(throw_on_error);
}
void ExternalDictionaries::reloadFromConfigFile(const std::string & config_path, const bool throw_on_error, const bool force_reload,
const std::string & only_dictionary)
std::unique_ptr<IExternalLoadable> ExternalDictionaries::create(
const std::string & name, const Configuration & config, const std::string & config_prefix)
{
const Poco::File config_file{config_path};
if (config_path.empty() || !config_file.exists())
{
LOG_WARNING(log, "config file '" + config_path + "' does not exist");
}
else
{
std::unique_lock<std::mutex> all_lock(all_mutex);
auto modification_time_it = last_modification_times.find(config_path);
if (modification_time_it == std::end(last_modification_times))
modification_time_it = last_modification_times.emplace(config_path, Poco::Timestamp{0}).first;
auto & config_last_modified = modification_time_it->second;
const auto last_modified = config_file.getLastModified();
if (force_reload || last_modified > config_last_modified)
{
Poco::AutoPtr<Poco::Util::XMLConfiguration> config = new Poco::Util::XMLConfiguration(config_path);
/// Definitions of dictionaries may have changed, recreate all of them
/// If we need update only one dictionary, don't update modification time: might be other dictionaries in the config file
if (only_dictionary.empty())
config_last_modified = last_modified;
/// get all dictionaries' definitions
Poco::Util::AbstractConfiguration::Keys keys;
config->keys(keys);
/// for each dictionary defined in xml config
for (const auto & key : keys)
{
std::string name;
if (!startsWith(key, "dictionary"))
{
if (!startsWith(key.data(), "comment"))
LOG_WARNING(log,
config_path << ": unknown node in dictionaries file: '" << key + "', 'dictionary'");
continue;
}
try
{
name = config->getString(key + ".name");
if (name.empty())
{
LOG_WARNING(log, config_path << ": dictionary name cannot be empty");
continue;
}
if (!only_dictionary.empty() && name != only_dictionary)
continue;
decltype(dictionaries.begin()) dict_it;
{
std::lock_guard<std::mutex> lock{dictionaries_mutex};
dict_it = dictionaries.find(name);
}
if (dict_it != std::end(dictionaries) && dict_it->second.origin != config_path)
throw std::runtime_error{"Overriding dictionary from file " + dict_it->second.origin};
auto dict_ptr = DictionaryFactory::instance().create(name, *config, key, context);
/// If the dictionary could not be loaded.
if (const auto exception_ptr = dict_ptr->getCreationException())
{
const auto failed_dict_it = failed_dictionaries.find(name);
if (failed_dict_it != std::end(failed_dictionaries))
{
failed_dict_it->second = FailedDictionaryInfo{
std::move(dict_ptr),
std::chrono::system_clock::now() + std::chrono::seconds{backoff_initial_sec}};
}
else
failed_dictionaries.emplace(name, FailedDictionaryInfo{
std::move(dict_ptr),
std::chrono::system_clock::now() + std::chrono::seconds{backoff_initial_sec}});
std::rethrow_exception(exception_ptr);
}
else if (!dict_ptr->isCached())
{
const auto & lifetime = dict_ptr->getLifetime();
if (lifetime.min_sec != 0 && lifetime.max_sec != 0)
{
std::uniform_int_distribution<UInt64> distribution{
lifetime.min_sec,
lifetime.max_sec
};
update_times[name] = std::chrono::system_clock::now() +
std::chrono::seconds{distribution(rnd_engine)};
}
}
const std::lock_guard<std::mutex> lock{dictionaries_mutex};
/// add new dictionary or update an existing version
if (dict_it == std::end(dictionaries))
dictionaries.emplace(name, DictionaryInfo{
std::make_shared<MultiVersion<IDictionaryBase>>(dict_ptr.release()),
config_path
});
else
{
if (dict_it->second.dict)
dict_it->second.dict->set(dict_ptr.release());
else
dict_it->second.dict = std::make_shared<MultiVersion<IDictionaryBase>>(dict_ptr.release());
/// erase stored exception on success
dict_it->second.exception = std::exception_ptr{};
failed_dictionaries.erase(name);
}
}
catch (...)
{
if (!name.empty())
{
/// If the dictionary could not load data or even failed to initialize from the config.
/// - all the same we insert information into the `dictionaries`, with the zero pointer `dict`.
const std::lock_guard<std::mutex> lock{dictionaries_mutex};
const auto exception_ptr = std::current_exception();
const auto dict_it = dictionaries.find(name);
if (dict_it == std::end(dictionaries))
dictionaries.emplace(name, DictionaryInfo{nullptr, config_path, exception_ptr});
else
dict_it->second.exception = exception_ptr;
}
tryLogCurrentException(log, "Cannot create external dictionary '" + name + "' from config path " + config_path);
/// propagate exception
if (throw_on_error)
throw;
}
}
}
}
}
void ExternalDictionaries::reload()
{
reloadFromConfigFiles(true, true);
}
void ExternalDictionaries::reloadDictionary(const std::string & name)
{
reloadFromConfigFiles(true, true, name);
/// Check that specified dict was loaded
const std::lock_guard<std::mutex> lock{dictionaries_mutex};
if (!dictionaries.count(name))
throw Exception("Dictionary " + name + " wasn't loaded during the reload process", ErrorCodes::BAD_ARGUMENTS);
}
MultiVersion<IDictionaryBase>::Version ExternalDictionaries::getDictionary(const std::string & name) const
{
const std::lock_guard<std::mutex> lock{dictionaries_mutex};
const auto it = dictionaries.find(name);
if (it == std::end(dictionaries))
throw Exception{
"No such dictionary: " + name,
ErrorCodes::BAD_ARGUMENTS
};
if (!it->second.dict)
it->second.exception ? std::rethrow_exception(it->second.exception) :
throw Exception{"No dictionary", ErrorCodes::LOGICAL_ERROR};
return it->second.dict->get();
return DictionaryFactory::instance().create(name, config, config_prefix, context);
}
}

View File

@ -1,19 +1,9 @@
#pragma once
#include <Dictionaries/IDictionary.h>
#include <Common/Exception.h>
#include <Common/setThreadName.h>
#include <Common/randomSeed.h>
#include <common/MultiVersion.h>
#include <Interpreters/ExternalLoader.h>
#include <common/logger_useful.h>
#include <Poco/Event.h>
#include <unistd.h>
#include <time.h>
#include <mutex>
#include <thread>
#include <unordered_map>
#include <chrono>
#include <pcg_random.hpp>
#include <memory>
namespace DB
@ -21,93 +11,36 @@ namespace DB
class Context;
/** Manages user-defined dictionaries.
* Monitors configuration file and automatically reloads dictionaries in a separate thread.
* The monitoring thread wakes up every @check_period_sec seconds and checks
* modification time of dictionaries' configuration file. If said time is greater than
* @config_last_modified, the dictionaries are created from scratch using configuration file,
* possibly overriding currently existing dictionaries with the same name (previous versions of
* overridden dictionaries will live as long as there are any users retaining them).
*
* Apart from checking configuration file for modifications, each non-cached dictionary
* has a lifetime of its own and may be updated if it's source reports that it has been
* modified. The time of next update is calculated by choosing uniformly a random number
* distributed between lifetime.min_sec and lifetime.max_sec.
* If either of lifetime.min_sec and lifetime.max_sec is zero, such dictionary is never updated.
*/
class ExternalDictionaries
/// Manages user-defined dictionaries.
class ExternalDictionaries : public ExternalLoader
{
private:
public:
using DictPtr = std::shared_ptr<IDictionaryBase>;
/// Dictionaries will be loaded immediately and then will be updated in separate thread, each 'reload_period' seconds.
ExternalDictionaries(Context & context, bool throw_on_error);
/// Forcibly reloads specified dictionary.
void reloadDictionary(const std::string & name) { reload(name); }
DictPtr getDictionary(const std::string & name) const
{
return std::static_pointer_cast<IDictionaryBase>(getLoadable(name));
}
protected:
std::unique_ptr<IExternalLoadable> create(const std::string & name, const Configuration & config,
const std::string & config_prefix) override;
using ExternalLoader::getObjectsMap;
friend class StorageSystemDictionaries;
friend class DatabaseDictionary;
/// Protects only dictionaries map.
mutable std::mutex dictionaries_mutex;
/// Protects all data, currently used to avoid races between updating thread and SYSTEM queries
mutable std::mutex all_mutex;
using DictionaryPtr = std::shared_ptr<MultiVersion<IDictionaryBase>>;
struct DictionaryInfo final
{
DictionaryPtr dict;
std::string origin;
std::exception_ptr exception;
};
struct FailedDictionaryInfo final
{
std::unique_ptr<IDictionaryBase> dict;
std::chrono::system_clock::time_point next_attempt_time;
UInt64 error_count;
};
/** name -> dictionary.
*/
std::unordered_map<std::string, DictionaryInfo> dictionaries;
/** Here are dictionaries, that has been never loaded successfully.
* They are also in 'dictionaries', but with nullptr as 'dict'.
*/
std::unordered_map<std::string, FailedDictionaryInfo> failed_dictionaries;
/** Both for dictionaries and failed_dictionaries.
*/
std::unordered_map<std::string, std::chrono::system_clock::time_point> update_times;
pcg64 rnd_engine{randomSeed()};
private:
Context & context;
std::thread reloading_thread;
Poco::Event destroy;
Logger * log;
std::unordered_map<std::string, Poco::Timestamp> last_modification_times;
/// Check dictionaries definitions in config files and reload or/and add new ones if the definition is changed
void reloadFromConfigFiles(const bool throw_on_error, const bool force_reload = false, const std::string & only_dictionary = "");
void reloadFromConfigFile(const std::string & config_path, const bool throw_on_error, const bool force_reload,
const std::string & only_dictionary);
/// Check config files and update expired dictionaries
void reloadAndUpdate(bool throw_on_error = false);
void reloadPeriodically();
public:
/// Dictionaries will be loaded immediately and then will be updated in separate thread, each 'reload_period' seconds.
ExternalDictionaries(Context & context, const bool throw_on_error);
~ExternalDictionaries();
/// Forcibly reloads all dictionaries.
void reload();
/// Forcibly reloads specified dictionary.
void reloadDictionary(const std::string & name);
MultiVersion<IDictionaryBase>::Version getDictionary(const std::string & name) const;
};
}

View File

@ -0,0 +1,433 @@
#include <Interpreters/ExternalLoader.h>
#include <Common/StringUtils.h>
#include <Common/MemoryTracker.h>
#include <Common/Exception.h>
#include <Common/getMultipleKeysFromConfig.h>
#include <Common/setThreadName.h>
#include <ext/scope_guard.h>
#include <Poco/Util/Application.h>
#include <Poco/Glob.h>
#include <Poco/File.h>
#include <cmath>
#include <Poco/Util/XMLConfiguration.h>
namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
extern const int BAD_ARGUMENTS;
}
ExternalLoadableLifetime::ExternalLoadableLifetime(const Poco::Util::AbstractConfiguration & config,
const std::string & config_prefix)
{
const auto & lifetime_min_key = config_prefix + ".min";
const auto has_min = config.has(lifetime_min_key);
min_sec = has_min ? config.getUInt64(lifetime_min_key) : config.getUInt64(config_prefix);
max_sec = has_min ? config.getUInt64(config_prefix + ".max") : min_sec;
}
void ExternalLoader::reloadPeriodically()
{
setThreadName("ExterLdrReload");
while (true)
{
if (destroy.tryWait(update_settings.check_period_sec * 1000))
return;
reloadAndUpdate();
}
}
ExternalLoader::ExternalLoader(const Poco::Util::AbstractConfiguration & config,
const ExternalLoaderUpdateSettings & update_settings,
const ExternalLoaderConfigSettings & config_settings,
Logger * log, const std::string & loadable_object_name)
: config(config), update_settings(update_settings), config_settings(config_settings),
log(log), object_name(loadable_object_name)
{
}
void ExternalLoader::init(bool throw_on_error)
{
if (is_initialized)
return;
is_initialized = true;
{
/// During synchronous loading of external dictionaries at moment of query execution,
/// we should not use per query memory limit.
TemporarilyDisableMemoryTracker temporarily_disable_memory_tracker;
reloadAndUpdate(throw_on_error);
}
reloading_thread = std::thread{&ExternalLoader::reloadPeriodically, this};
}
ExternalLoader::~ExternalLoader()
{
destroy.set();
reloading_thread.join();
}
namespace
{
std::set<std::string> getConfigPaths(const Poco::Util::AbstractConfiguration & config,
const std::string & external_config_paths_setting)
{
std::set<std::string> files;
auto patterns = getMultipleValuesFromConfig(config, "", external_config_paths_setting);
for (auto & pattern : patterns)
{
if (pattern.empty())
continue;
if (pattern[0] != '/')
{
const auto app_config_path = config.getString("config-file", "config.xml");
const auto config_dir = Poco::Path{app_config_path}.parent().toString();
const auto absolute_path = config_dir + pattern;
Poco::Glob::glob(absolute_path, files, 0);
if (!files.empty())
continue;
}
Poco::Glob::glob(pattern, files, 0);
}
return files;
}
}
void ExternalLoader::reloadAndUpdate(bool throw_on_error)
{
reloadFromConfigFiles(throw_on_error);
/// list of recreated loadable objects to perform delayed removal from unordered_map
std::list<std::string> recreated_failed_loadable_objects;
std::unique_lock<std::mutex> all_lock(all_mutex);
/// retry loading failed loadable objects
for (auto & failed_loadable_object : failed_loadable_objects)
{
if (std::chrono::system_clock::now() < failed_loadable_object.second.next_attempt_time)
continue;
const auto & name = failed_loadable_object.first;
try
{
auto loadable_ptr = failed_loadable_object.second.loadable->clone();
if (const auto exception_ptr = loadable_ptr->getCreationException())
{
/// recalculate next attempt time
std::uniform_int_distribution<UInt64> distribution(
0, static_cast<UInt64>(std::exp2(failed_loadable_object.second.error_count)));
std::chrono::seconds delay(std::min<UInt64>(
update_settings.backoff_max_sec,
update_settings.backoff_initial_sec + distribution(rnd_engine)));
failed_loadable_object.second.next_attempt_time = std::chrono::system_clock::now() + delay;
++failed_loadable_object.second.error_count;
std::rethrow_exception(exception_ptr);
}
else
{
const std::lock_guard<std::mutex> lock{map_mutex};
const auto & lifetime = loadable_ptr->getLifetime();
std::uniform_int_distribution<UInt64> distribution{lifetime.min_sec, lifetime.max_sec};
update_times[name] = std::chrono::system_clock::now() + std::chrono::seconds{distribution(rnd_engine)};
const auto dict_it = loadable_objects.find(name);
dict_it->second.loadable.reset();
dict_it->second.loadable = std::move(loadable_ptr);
/// clear stored exception on success
dict_it->second.exception = std::exception_ptr{};
recreated_failed_loadable_objects.push_back(name);
}
}
catch (...)
{
tryLogCurrentException(log, "Failed reloading '" + name + "' " + object_name);
if (throw_on_error)
throw;
}
}
/// do not undertake further attempts to recreate these loadable objects
for (const auto & name : recreated_failed_loadable_objects)
failed_loadable_objects.erase(name);
/// periodic update
for (auto & loadable_object : loadable_objects)
{
const auto & name = loadable_object.first;
try
{
/// If the loadable objects failed to load or even failed to initialize from the config.
if (!loadable_object.second.loadable)
continue;
auto current = loadable_object.second.loadable;
const auto & lifetime = current->getLifetime();
/// do not update loadable objects with zero as lifetime
if (lifetime.min_sec == 0 || lifetime.max_sec == 0)
continue;
if (current->supportUpdates())
{
auto & update_time = update_times[current->getName()];
/// check that timeout has passed
if (std::chrono::system_clock::now() < update_time)
continue;
SCOPE_EXIT({
/// calculate next update time
std::uniform_int_distribution<UInt64> distribution{lifetime.min_sec, lifetime.max_sec};
update_time = std::chrono::system_clock::now() + std::chrono::seconds{distribution(rnd_engine)};
});
/// check source modified
if (current->isModified())
{
/// create new version of loadable object
auto new_version = current->clone();
if (const auto exception_ptr = new_version->getCreationException())
std::rethrow_exception(exception_ptr);
loadable_object.second.loadable.reset();
loadable_object.second.loadable = std::move(new_version);
}
}
/// erase stored exception on success
loadable_object.second.exception = std::exception_ptr{};
}
catch (...)
{
loadable_object.second.exception = std::current_exception();
tryLogCurrentException(log, "Cannot update " + object_name + " '" + name + "', leaving old version");
if (throw_on_error)
throw;
}
}
}
void ExternalLoader::reloadFromConfigFiles(const bool throw_on_error, const bool force_reload, const std::string & only_dictionary)
{
const auto config_paths = getConfigPaths(config, config_settings.path_setting_name);
for (const auto & config_path : config_paths)
{
try
{
reloadFromConfigFile(config_path, throw_on_error, force_reload, only_dictionary);
}
catch (...)
{
tryLogCurrentException(log, "reloadFromConfigFile has thrown while reading from " + config_path);
if (throw_on_error)
throw;
}
}
}
void ExternalLoader::reloadFromConfigFile(const std::string & config_path, const bool throw_on_error,
const bool force_reload, const std::string & loadable_name)
{
const Poco::File config_file{config_path};
if (config_path.empty() || !config_file.exists())
{
LOG_WARNING(log, "config file '" + config_path + "' does not exist");
}
else
{
std::unique_lock<std::mutex> all_lock(all_mutex);
auto modification_time_it = last_modification_times.find(config_path);
if (modification_time_it == std::end(last_modification_times))
modification_time_it = last_modification_times.emplace(config_path, Poco::Timestamp{0}).first;
auto & config_last_modified = modification_time_it->second;
const auto last_modified = config_file.getLastModified();
if (force_reload || last_modified > config_last_modified)
{
Poco::AutoPtr<Poco::Util::XMLConfiguration> config = new Poco::Util::XMLConfiguration(config_path);
/// Definitions of loadable objects may have changed, recreate all of them
/// If we need update only one object, don't update modification time: might be other objects in the config file
if (loadable_name.empty())
config_last_modified = last_modified;
/// get all objects' definitions
Poco::Util::AbstractConfiguration::Keys keys;
config->keys(keys);
/// for each loadable object defined in xml config
for (const auto & key : keys)
{
std::string name;
if (!startsWith(key, config_settings.external_config))
{
if (!startsWith(key, "comment"))
LOG_WARNING(log, config_path << ": unknown node in file: '" << key
<< "', expected '" << config_settings.external_config << "'");
continue;
}
try
{
name = config->getString(key + "." + config_settings.external_name);
if (name.empty())
{
LOG_WARNING(log, config_path << ": " + config_settings.external_name + " name cannot be empty");
continue;
}
if (!loadable_name.empty() && name != loadable_name)
continue;
decltype(loadable_objects.begin()) object_it;
{
std::lock_guard<std::mutex> lock{map_mutex};
object_it = loadable_objects.find(name);
}
if (object_it != std::end(loadable_objects) && object_it->second.origin != config_path)
throw std::runtime_error{"Overriding " + object_name + " from file " + object_it->second.origin};
auto object_ptr = create(name, *config, key);
/// If the object could not be loaded.
if (const auto exception_ptr = object_ptr->getCreationException())
{
std::chrono::seconds delay(update_settings.backoff_initial_sec);
const auto failed_dict_it = failed_loadable_objects.find(name);
FailedLoadableInfo info{std::move(object_ptr), std::chrono::system_clock::now() + delay, 0};
if (failed_dict_it != std::end(failed_loadable_objects))
(*failed_dict_it).second = std::move(info);
else
failed_loadable_objects.emplace(name, std::move(info));
std::rethrow_exception(exception_ptr);
}
else if (object_ptr->supportUpdates())
{
const auto & lifetime = object_ptr->getLifetime();
if (lifetime.min_sec != 0 && lifetime.max_sec != 0)
{
std::uniform_int_distribution<UInt64> distribution(lifetime.min_sec, lifetime.max_sec);
update_times[name] = std::chrono::system_clock::now() +
std::chrono::seconds{distribution(rnd_engine)};
}
}
const std::lock_guard<std::mutex> lock{map_mutex};
/// add new loadable object or update an existing version
if (object_it == std::end(loadable_objects))
loadable_objects.emplace(name, LoadableInfo{std::move(object_ptr), config_path});
else
{
if (object_it->second.loadable)
object_it->second.loadable.reset();
object_it->second.loadable = std::move(object_ptr);
/// erase stored exception on success
object_it->second.exception = std::exception_ptr{};
failed_loadable_objects.erase(name);
}
}
catch (...)
{
if (!name.empty())
{
/// If the loadable object could not load data or even failed to initialize from the config.
/// - all the same we insert information into the `loadable_objects`, with the zero pointer `loadable`.
const std::lock_guard<std::mutex> lock{map_mutex};
const auto exception_ptr = std::current_exception();
const auto loadable_it = loadable_objects.find(name);
if (loadable_it == std::end(loadable_objects))
loadable_objects.emplace(name, LoadableInfo{nullptr, config_path, exception_ptr});
else
loadable_it->second.exception = exception_ptr;
}
tryLogCurrentException(log, "Cannot create " + object_name + " '"
+ name + "' from config path " + config_path);
/// propagate exception
if (throw_on_error)
throw;
}
}
}
}
}
void ExternalLoader::reload()
{
reloadFromConfigFiles(true, true);
}
void ExternalLoader::reload(const std::string & name)
{
reloadFromConfigFiles(true, true, name);
/// Check that specified object was loaded
const std::lock_guard<std::mutex> lock{map_mutex};
if (!loadable_objects.count(name))
throw Exception("Failed to load " + object_name + " '" + name + "' during the reload process", ErrorCodes::BAD_ARGUMENTS);
}
ExternalLoader::LoadablePtr ExternalLoader::getLoadable(const std::string & name) const
{
const std::lock_guard<std::mutex> lock{map_mutex};
const auto it = loadable_objects.find(name);
if (it == std::end(loadable_objects))
throw Exception("No such " + object_name + ": " + name, ErrorCodes::BAD_ARGUMENTS);
if (!it->second.loadable)
it->second.exception ? std::rethrow_exception(it->second.exception) :
throw Exception{object_name + " '" + name + "' is not loaded", ErrorCodes::LOGICAL_ERROR};
return it->second.loadable;
}
ExternalLoader::LockedObjectsMap ExternalLoader::getObjectsMap() const
{
return LockedObjectsMap(map_mutex, loadable_objects);
}
}

View File

@ -0,0 +1,173 @@
#pragma once
#include <common/logger_useful.h>
#include <Poco/Event.h>
#include <mutex>
#include <thread>
#include <unordered_map>
#include <chrono>
#include <tuple>
#include <Interpreters/IExternalLoadable.h>
#include <Core/Types.h>
#include <pcg_random.hpp>
#include <Common/randomSeed.h>
namespace DB
{
class Context;
struct ExternalLoaderUpdateSettings
{
UInt64 check_period_sec = 5;
UInt64 backoff_initial_sec = 5;
/// 10 minutes
UInt64 backoff_max_sec = 10 * 60;
ExternalLoaderUpdateSettings() = default;
ExternalLoaderUpdateSettings(UInt64 check_period_sec, UInt64 backoff_initial_sec, UInt64 backoff_max_sec)
: check_period_sec(check_period_sec),
backoff_initial_sec(backoff_initial_sec),
backoff_max_sec(backoff_max_sec) {}
};
/* External configuration structure.
*
* <external_group>
* <external_config>
* <external_name>name</external_name>
* ....
* </external_config>
* </external_group>
*/
struct ExternalLoaderConfigSettings
{
std::string external_config;
std::string external_name;
std::string path_setting_name;
};
/** Manages user-defined objects.
* Monitors configuration file and automatically reloads objects in a separate thread.
* The monitoring thread wakes up every @check_period_sec seconds and checks
* modification time of objects' configuration file. If said time is greater than
* @config_last_modified, the objects are created from scratch using configuration file,
* possibly overriding currently existing objects with the same name (previous versions of
* overridden objects will live as long as there are any users retaining them).
*
* Apart from checking configuration file for modifications, each object
* has a lifetime of its own and may be updated if it supportUpdates.
* The time of next update is calculated by choosing uniformly a random number
* distributed between lifetime.min_sec and lifetime.max_sec.
* If either of lifetime.min_sec and lifetime.max_sec is zero, such object is never updated.
*/
class ExternalLoader
{
public:
using LoadablePtr = std::shared_ptr<IExternalLoadable>;
private:
struct LoadableInfo final
{
LoadablePtr loadable;
std::string origin;
std::exception_ptr exception;
};
struct FailedLoadableInfo final
{
std::unique_ptr<IExternalLoadable> loadable;
std::chrono::system_clock::time_point next_attempt_time;
UInt64 error_count;
};
public:
using Configuration = Poco::Util::AbstractConfiguration;
using ObjectsMap = std::unordered_map<std::string, LoadableInfo>;
/// Objects will be loaded immediately and then will be updated in separate thread, each 'reload_period' seconds.
ExternalLoader(const Configuration & config,
const ExternalLoaderUpdateSettings & update_settings,
const ExternalLoaderConfigSettings & config_settings,
Logger * log, const std::string & loadable_object_name);
virtual ~ExternalLoader();
/// Forcibly reloads all loadable objects.
void reload();
/// Forcibly reloads specified loadable object.
void reload(const std::string & name);
LoadablePtr getLoadable(const std::string & name) const;
protected:
virtual std::unique_ptr<IExternalLoadable> create(const std::string & name, const Configuration & config,
const std::string & config_prefix) = 0;
class LockedObjectsMap
{
public:
LockedObjectsMap(std::mutex & mutex, const ObjectsMap & objectsMap) : lock(mutex), objectsMap(objectsMap) {}
const ObjectsMap & get() { return objectsMap; }
private:
std::unique_lock<std::mutex> lock;
const ObjectsMap & objectsMap;
};
/// Direct access to objects.
LockedObjectsMap getObjectsMap() const;
/// Should be called in derived constructor (to avoid pure virtual call).
void init(bool throw_on_error);
private:
bool is_initialized = false;
/// Protects only objects map.
mutable std::mutex map_mutex;
/// Protects all data, currently used to avoid races between updating thread and SYSTEM queries
mutable std::mutex all_mutex;
/// name -> loadable.
ObjectsMap loadable_objects;
/// Here are loadable objects, that has been never loaded successfully.
/// They are also in 'loadable_objects', but with nullptr as 'loadable'.
std::unordered_map<std::string, FailedLoadableInfo> failed_loadable_objects;
/// Both for loadable_objects and failed_loadable_objects.
std::unordered_map<std::string, std::chrono::system_clock::time_point> update_times;
pcg64 rnd_engine{randomSeed()};
const Configuration & config;
const ExternalLoaderUpdateSettings & update_settings;
const ExternalLoaderConfigSettings & config_settings;
std::thread reloading_thread;
Poco::Event destroy;
Logger * log;
/// Loadable object name to use in log messages.
std::string object_name;
std::unordered_map<std::string, Poco::Timestamp> last_modification_times;
/// Check objects definitions in config files and reload or/and add new ones if the definition is changed
/// If loadable_name is not empty, load only loadable object with name loadable_name
void reloadFromConfigFiles(bool throw_on_error, bool force_reload = false, const std::string & loadable_name = "");
void reloadFromConfigFile(const std::string & config_path, bool throw_on_error, bool force_reload,
const std::string & loadable_name);
/// Check config files and update expired loadable objects
void reloadAndUpdate(bool throw_on_error = false);
void reloadPeriodically();
};
}

View File

@ -0,0 +1,67 @@
#include <Interpreters/ExternalModels.h>
#include <Interpreters/Context.h>
namespace DB
{
namespace ErrorCodes
{
extern const int INVALID_CONFIG_PARAMETER;
}
namespace
{
const ExternalLoaderUpdateSettings externalModelsUpdateSettings;
const ExternalLoaderConfigSettings & getExternalModelsConfigSettings()
{
static ExternalLoaderConfigSettings settings;
static std::once_flag flag;
std::call_once(flag, [] {
settings.external_config = "model";
settings.external_name = "name";
settings.path_setting_name = "models_config";
});
return settings;
}
}
ExternalModels::ExternalModels(Context & context, bool throw_on_error)
: ExternalLoader(context.getConfigRef(),
externalModelsUpdateSettings,
getExternalModelsConfigSettings(),
&Logger::get("ExternalModels"),
"external model"),
context(context)
{
init(throw_on_error);
}
std::unique_ptr<IExternalLoadable> ExternalModels::create(
const std::string & name, const Configuration & config, const std::string & config_prefix)
{
String type = config.getString(config_prefix + ".type");
ExternalLoadableLifetime lifetime(config, config_prefix + ".lifetime");
/// TODO: add models factory.
if (type == "catboost")
{
return std::make_unique<CatBoostModel>(
name, config.getString(config_prefix + ".path"),
context.getConfigRef().getString("catboost_dynamic_library_path"),
lifetime, config.getUInt(config_prefix + ".float_features_count"),
config.getUInt(config_prefix + ".cat_features_count")
);
}
else
{
throw Exception("Unknown model type: " + type, ErrorCodes::INVALID_CONFIG_PARAMETER);
}
}
}

View File

@ -0,0 +1,41 @@
#pragma once
#include <Dictionaries/CatBoostModel.h>
#include <Interpreters/ExternalLoader.h>
#include <common/logger_useful.h>
#include <memory>
namespace DB
{
class Context;
/// Manages user-defined models.
class ExternalModels : public ExternalLoader
{
public:
using ModelPtr = std::shared_ptr<IModel>;
/// Models will be loaded immediately and then will be updated in separate thread, each 'reload_period' seconds.
ExternalModels(Context & context, bool throw_on_error);
/// Forcibly reloads specified model.
void reloadModel(const std::string & name) { reload(name); }
ModelPtr getModel(const std::string & name) const
{
return std::static_pointer_cast<IModel>(getLoadable(name));
}
protected:
std::unique_ptr<IExternalLoadable> create(const std::string & name, const Configuration & config,
const std::string & config_prefix) override;
private:
Context & context;
};
}

View File

@ -0,0 +1,45 @@
#pragma once
#include <string>
#include <memory>
#include <Core/Types.h>
namespace Poco::Util
{
class AbstractConfiguration;
}
namespace DB
{
/// Min and max lifetimes for a loadable object or it's entry
struct ExternalLoadableLifetime final
{
UInt64 min_sec;
UInt64 max_sec;
ExternalLoadableLifetime(const Poco::Util::AbstractConfiguration & config, const std::string & config_prefix);
};
/// Basic interface for external loadable objects. Is used in ExternalLoader.
class IExternalLoadable : public std::enable_shared_from_this<IExternalLoadable>
{
public:
virtual ~IExternalLoadable() = default;
virtual const ExternalLoadableLifetime & getLifetime() const = 0;
virtual std::string getName() const = 0;
/// True if object can be updated when lifetime exceeded.
virtual bool supportUpdates() const = 0;
/// If lifetime exceeded and isModified() ExternalLoader replace current object with the result of clone().
virtual bool isModified() const = 0;
/// Returns new object with the same configuration. Is used to update modified object when lifetime exceeded.
virtual std::unique_ptr<IExternalLoadable> clone() const = 0;
virtual std::exception_ptr getCreationException() const = 0;
};
}

View File

@ -76,16 +76,17 @@ BlockInputStreams StorageSystemDictionaries::read(
ColumnWithTypeAndName col_source{std::make_shared<ColumnString>(), std::make_shared<DataTypeString>(), "source"};
const auto & external_dictionaries = context.getExternalDictionaries();
const std::lock_guard<std::mutex> lock{external_dictionaries.dictionaries_mutex};
auto objects_map = external_dictionaries.getObjectsMap();
const auto & dictionaries = objects_map.get();
for (const auto & dict_info : external_dictionaries.dictionaries)
for (const auto & dict_info : dictionaries)
{
col_name.column->insert(dict_info.first);
col_origin.column->insert(dict_info.second.origin);
if (dict_info.second.dict)
if (dict_info.second.loadable)
{
const auto dict_ptr = dict_info.second.dict->get();
const auto dict_ptr = std::static_pointer_cast<IDictionaryBase>(dict_info.second.loadable);
col_type.column->insert(dict_ptr->getTypeName());

View File

@ -0,0 +1,17 @@
#!/usr/bin/env bash
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
cd $DIR
git clone https://github.com/catboost/catboost.git
cd "${DIR}/catboost/catboost/libs/model_interface"
../../../ya make -r -o "${DIR}/build/lib" -j4
cd $DIR
ln -sf "${DIR}/build/lib/catboost/libs/model_interface/libcatboostmodel.so" libcatboostmodel.so
cd "${DIR}/catboost/catboost/python-package/catboost"
../../../ya make -r -DUSE_ARCADIA_PYTHON=no -DPYTHON_CONFIG=python2-config -j4
cd $DIR
ln -sf "${DIR}/catboost/catboost/python-package" python-package

View File

@ -0,0 +1,42 @@
import subprocess
import threading
import os
class ClickHouseClient:
def __init__(self, binary_path, port):
self.binary_path = binary_path
self.port = port
def query(self, query, timeout=10, pipe=None):
result = []
process = []
def run(path, port, text, result, in_pipe, process):
if in_pipe is None:
in_pipe = subprocess.PIPE
pipe = subprocess.Popen([path, 'client', '--port', str(port), '-q', text],
stdin=in_pipe, stdout=subprocess.PIPE, stderr=subprocess.PIPE, close_fds=True)
process.append(pipe)
stdout_data, stderr_data = pipe.communicate()
if stderr_data:
raise Exception('Error while executing query: {}\nstdout:\n{}\nstderr:\n{}'
.format(text, stdout_data, stderr_data))
result.append(stdout_data)
thread = threading.Thread(target=run, args=(self.binary_path, self.port, query, result, pipe, process))
thread.start()
thread.join(timeout)
if thread.isAlive():
if len(process):
process[0].kill()
thread.join()
raise Exception('timeout exceed for query: ' + query)
if len(result):
return result[0]

View File

@ -0,0 +1,15 @@
import numpy as np
def generate_uniform_int_column(size, low, high, seed=0):
np.random.seed(seed)
return np.random.randint(low, high, size)
def generate_uniform_float_column(size, low, high, seed=0):
np.random.seed(seed)
return np.random.random(size) * (high - low) + low
def generate_uniform_string_column(size, samples, seed):
return np.array(samples)[generate_uniform_int_column(size, 0, len(samples), seed)]

View File

@ -0,0 +1,67 @@
import subprocess
import threading
import socket
import time
class ClickHouseServer:
def __init__(self, binary_path, config_path, stdout_file=None, stderr_file=None, shutdown_timeout=10):
self.binary_path = binary_path
self.config_path = config_path
self.pipe = None
self.stdout_file = stdout_file
self.stderr_file = stderr_file
self.shutdown_timeout = shutdown_timeout
def start(self):
cmd = [self.binary_path, 'server', '--config', self.config_path]
out_pipe = None
err_pipe = None
if self.stdout_file is not None:
out_pipe = open(self.stdout_file, 'w')
if self.stderr_file is not None:
err_pipe = open(self.stderr_file, 'w')
self.pipe = subprocess.Popen(cmd, stdout=out_pipe, stderr=err_pipe)
def wait_for_request(self, port, timeout=1):
try:
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# is not working
# s.settimeout(timeout)
step = 0.01
for iter in range(int(timeout / step)):
if s.connect_ex(('127.0.0.1', port)) == 0:
return
time.sleep(step)
s.connect(('127.0.0.1', port))
except socket.error as socketerror:
print "Error: ", socketerror
raise
def shutdown(self, timeout=10):
def wait(pipe):
pipe.wait()
if self.pipe is not None:
self.pipe.terminate()
thread = threading.Thread(target=wait, args=(self.pipe,))
thread.start()
thread.join(timeout)
if thread.isAlive():
self.pipe.kill()
thread.join()
if self.pipe.stdout is not None:
self.pipe.stdout.close()
if self.pipe.stderr is not None:
self.pipe.stderr.close()
def __enter__(self):
self.start()
return self
def __exit__(self, type, value, traceback):
self.shutdown(self.shutdown_timeout)

View File

@ -0,0 +1,166 @@
from server import ClickHouseServer
from client import ClickHouseClient
from table import ClickHouseTable
import os
import errno
from shutil import rmtree
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
CATBOOST_ROOT = os.path.dirname(SCRIPT_DIR)
CLICKHOUSE_CONFIG = \
'''
<yandex>
<timezone>Europe/Moscow</timezone>
<listen_host>::</listen_host>
<path>{path}</path>
<tmp_path>{tmp_path}</tmp_path>
<models_config>{models_config}</models_config>
<mark_cache_size>5368709120</mark_cache_size>
<users_config>users.xml</users_config>
<tcp_port>{tcp_port}</tcp_port>
<catboost_dynamic_library_path>{catboost_dynamic_library_path}</catboost_dynamic_library_path>
</yandex>
'''
CLICKHOUSE_USERS = \
'''
<yandex>
<profiles>
<default>
</default>
<readonly>
<readonly>1</readonly>
</readonly>
</profiles>
<users>
<readonly>
<password></password>
<profile>readonly</profile>
<quota>default</quota>
</readonly>
<default>
<password></password>
<profile>default</profile>
<quota>default</quota>
<networks incl="networks" replace="replace">
<ip>::1</ip>
<ip>127.0.0.1</ip>
</networks>
</default>
</users>
<quotas>
<default>
</default>
</quotas>
</yandex>
'''
CATBOOST_MODEL_CONFIG = \
'''
<models>
<model>
<type>catboost</type>
<name>{name}</name>
<path>{path}</path>
<float_features_count>{float_features_count}</float_features_count>
<cat_features_count>{cat_features_count}</cat_features_count>
<lifetime>0</lifetime>
</model>
</models>
'''
class ClickHouseServerWithCatboostModels:
def __init__(self, name, binary_path, port, shutdown_timeout=10, clean_folder=False):
self.models = {}
self.name = name
self.binary_path = binary_path
self.port = port
self.shutdown_timeout = shutdown_timeout
self.clean_folder = clean_folder
self.root = os.path.join(CATBOOST_ROOT, 'data', 'servers')
self.config_path = os.path.join(self.root, 'config.xml')
self.users_path = os.path.join(self.root, 'users.xml')
self.models_dir = os.path.join(self.root, 'models')
self.server = None
def _get_server(self):
stdout_file = os.path.join(self.root, 'server_stdout.txt')
stderr_file = os.path.join(self.root, 'server_stderr.txt')
return ClickHouseServer(self.binary_path, self.config_path, stdout_file, stderr_file, self.shutdown_timeout)
def add_model(self, model_name, model, float_features_count, cat_features_count):
self.models[model_name] = (float_features_count, cat_features_count, model)
def apply_model(self, name, df, cat_feature_names):
names = list(df)
float_feature_names = tuple(name for name in names if name not in cat_feature_names)
with ClickHouseTable(self.server, self.port, name, df) as table:
return table.apply_model(name, cat_feature_names, float_feature_names)
def _create_root(self):
try:
os.makedirs(self.root)
except OSError as exc: # Python >2.5
if exc.errno == errno.EEXIST and os.path.isdir(self.root):
pass
else:
raise
def _clean_root(self):
rmtree(self.root)
def _save_config(self):
params = {
'tcp_port': self.port,
'path': os.path.join(self.root, 'clickhouse'),
'tmp_path': os.path.join(self.root, 'clickhouse', 'tmp'),
'models_config': os.path.join(self.models_dir, '*_model.xml'),
'catboost_dynamic_library_path': os.path.join(CATBOOST_ROOT, 'data', 'libcatboostmodel.so')
}
config = CLICKHOUSE_CONFIG.format(**params)
with open(self.config_path, 'w') as f:
f.write(config)
with open(self.users_path, 'w') as f:
f.write(CLICKHOUSE_USERS)
def _save_models(self):
if not os.path.exists(self.models_dir):
os.makedirs(self.models_dir)
for name, params in self.models.items():
float_features_count, cat_features_count, model = params
model_path = os.path.join(self.models_dir, name + '.cbm')
config_path = os.path.join(self.models_dir, name + '_model.xml')
params = {
'name': name,
'path': model_path,
'float_features_count': float_features_count,
'cat_features_count': cat_features_count
}
config = CATBOOST_MODEL_CONFIG.format(**params)
with open(config_path, 'w') as f:
f.write(config)
model.save_model(model_path)
def __enter__(self):
self._create_root()
self._save_config()
self._save_models()
self.server = self._get_server().__enter__()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
res = self.server.__exit__(exc_type, exc_val, exc_tb)
if self.clean_folder:
self._clean_root()
return res

View File

@ -0,0 +1,69 @@
from server import ClickHouseServer
from client import ClickHouseClient
from pandas import DataFrame
import os
import threading
import tempfile
class ClickHouseTable:
def __init__(self, server, port, table_name, df):
self.server = server
self.port = port
self.table_name = table_name
self.df = df
if not isinstance(self.server, ClickHouseServer):
raise Exception('Expected ClickHouseServer, got ' + repr(self.server))
if not isinstance(self.df, DataFrame):
raise Exception('Expected DataFrame, got ' + repr(self.df))
self.server.wait_for_request(port)
self.client = ClickHouseClient(server.binary_path, port)
def _convert(self, name):
types_map = {
'float64': 'Float64',
'int64': 'Int64',
'float32': 'Float32',
'int32': 'Int32'
}
if name in types_map:
return types_map[name]
return 'String'
def _create_table_from_df(self):
self.client.query('create database if not exists test')
self.client.query('drop table if exists test.{}'.format(self.table_name))
column_types = list(self.df.dtypes)
column_names = list(self.df)
schema = ', '.join((name + ' ' + self._convert(str(t)) for name, t in zip(column_names, column_types)))
print 'schema:', schema
create_query = 'create table test.{} (date Date DEFAULT today(), {}) engine = MergeTree(date, (date), 8192)'
self.client.query(create_query.format(self.table_name, schema))
insert_query = 'insert into test.{} ({}) format CSV'
with tempfile.TemporaryFile() as tmp_file:
self.df.to_csv(tmp_file, header=False, index=False)
tmp_file.seek(0)
self.client.query(insert_query.format(self.table_name, ', '.join(column_names)), pipe=tmp_file)
def apply_model(self, model_name, float_columns, cat_columns):
columns = ', '.join(list(float_columns) + list(cat_columns))
query = "select modelEvaluate('{}', {}) from test.{} format TSV"
result = self.client.query(query.format(model_name, columns, self.table_name))
return tuple(map(float, filter(len, map(str.strip, result.split()))))
def _drop_table(self):
self.client.query('drop table test.{}'.format(self.table_name))
def __enter__(self):
self._create_table_from_df()
return self
def __exit__(self, type, value, traceback):
self._drop_table()

View File

@ -0,0 +1,28 @@
import os
import sys
from pandas import DataFrame
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
CATBOOST_ROOT = os.path.dirname(SCRIPT_DIR)
CATBOOST_PYTHON_DIR = os.path.join(CATBOOST_ROOT, 'data', 'python-package')
if CATBOOST_PYTHON_DIR not in sys.path:
sys.path.append(CATBOOST_PYTHON_DIR)
import catboost
from catboost import CatBoostClassifier
def train_catboost_model(df, target, cat_features, params, verbose=True):
if not isinstance(df, DataFrame):
raise Exception('DataFrame object expected, but got ' + repr(df))
print 'features:', df.columns.tolist()
cat_features_index = list(df.columns.get_loc(feature) for feature in cat_features)
print 'cat features:', cat_features_index
model = CatBoostClassifier(**params)
model.fit(df, target, cat_features=cat_features_index, verbose=verbose)
return model

View File

@ -0,0 +1,3 @@
[pytest]
python_files = test.py
norecursedirs=data

View File

@ -0,0 +1,236 @@
from helpers.server_with_models import ClickHouseServerWithCatboostModels
from helpers.generate import generate_uniform_string_column, generate_uniform_float_column, generate_uniform_int_column
from helpers.train import train_catboost_model
import os
import numpy as np
from pandas import DataFrame
PORT = int(os.environ.get('CLICKHOUSE_TESTS_PORT', '9000'))
CLICKHOUSE_TESTS_SERVER_BIN_PATH = os.environ.get('CLICKHOUSE_TESTS_SERVER_BIN_PATH', '/usr/bin/clickhouse')
def add_noise_to_target(target, seed, threshold=0.05):
col = generate_uniform_float_column(len(target), 0., 1., seed + 1) < threshold
return target * (1 - col) + (1 - target) * col
def check_predictions(test_name, target, pred_python, pred_ch, acc_threshold):
ch_class = pred_ch.astype(int)
python_class = pred_python.astype(int)
if not np.array_equal(ch_class, python_class):
raise Exception('Got different results:\npython:\n' + str(python_class) + '\nClickHouse:\n' + str(ch_class))
acc = 1 - np.sum(np.abs(ch_class - np.array(target))) / (len(target) + .0)
assert acc >= acc_threshold
print test_name, 'accuracy: {:.10f}'.format(acc)
def test_apply_float_features_only():
name = 'test_apply_float_features_only'
train_size = 10000
test_size = 10000
def gen_data(size, seed):
data = {
'a': generate_uniform_float_column(size, 0., 1., seed + 1),
'b': generate_uniform_float_column(size, 0., 1., seed + 2),
'c': generate_uniform_float_column(size, 0., 1., seed + 3)
}
return DataFrame.from_dict(data)
def get_target(df):
def target_filter(row):
return 1 if (row['a'] > .3 and row['b'] > .3) or (row['c'] < .4 and row['a'] * row['b'] > 0.1) else 0
return df.apply(target_filter, axis=1).as_matrix()
train_df = gen_data(train_size, 42)
test_df = gen_data(test_size, 43)
train_target = get_target(train_df)
test_target = get_target(test_df)
print
print 'train target', train_target
print 'test target', test_target
params = {
'iterations': 4,
'depth': 2,
'learning_rate': 1,
'loss_function': 'Logloss'
}
model = train_catboost_model(train_df, train_target, [], params)
pred_python = model.predict(test_df)
server = ClickHouseServerWithCatboostModels(name, CLICKHOUSE_TESTS_SERVER_BIN_PATH, PORT)
server.add_model(name, model, 3, 0)
with server:
pred_ch = (np.array(server.apply_model(name, test_df, [])) > 0).astype(int)
print 'python predictions', pred_python
print 'clickhouse predictions', pred_ch
check_predictions(name, test_target, pred_python, pred_ch, 0.9)
def test_apply_float_features_with_string_cat_features():
name = 'test_apply_float_features_with_string_cat_features'
train_size = 10000
test_size = 10000
def gen_data(size, seed):
data = {
'a': generate_uniform_float_column(size, 0., 1., seed + 1),
'b': generate_uniform_float_column(size, 0., 1., seed + 2),
'c': generate_uniform_string_column(size, ['a', 'b', 'c'], seed + 3),
'd': generate_uniform_string_column(size, ['e', 'f', 'g'], seed + 4)
}
return DataFrame.from_dict(data)
def get_target(df):
def target_filter(row):
return 1 if (row['a'] > .3 and row['b'] > .3 and row['c'] != 'a') \
or (row['a'] * row['b'] > 0.1 and row['c'] != 'b' and row['d'] != 'e') else 0
return df.apply(target_filter, axis=1).as_matrix()
train_df = gen_data(train_size, 42)
test_df = gen_data(test_size, 43)
train_target = get_target(train_df)
test_target = get_target(test_df)
print
print 'train target', train_target
print 'test target', test_target
params = {
'iterations': 6,
'depth': 2,
'learning_rate': 1,
'loss_function': 'Logloss'
}
model = train_catboost_model(train_df, train_target, ['c', 'd'], params)
pred_python = model.predict(test_df)
server = ClickHouseServerWithCatboostModels(name, CLICKHOUSE_TESTS_SERVER_BIN_PATH, PORT)
server.add_model(name, model, 2, 2)
with server:
pred_ch = (np.array(server.apply_model(name, test_df, [])) > 0).astype(int)
print 'python predictions', pred_python
print 'clickhouse predictions', pred_ch
check_predictions(name, test_target, pred_python, pred_ch, 0.9)
def test_apply_float_features_with_int_cat_features():
name = 'test_apply_float_features_with_int_cat_features'
train_size = 10000
test_size = 10000
def gen_data(size, seed):
data = {
'a': generate_uniform_float_column(size, 0., 1., seed + 1),
'b': generate_uniform_float_column(size, 0., 1., seed + 2),
'c': generate_uniform_int_column(size, 1, 4, seed + 3),
'd': generate_uniform_int_column(size, 1, 4, seed + 4)
}
return DataFrame.from_dict(data)
def get_target(df):
def target_filter(row):
return 1 if (row['a'] > .3 and row['b'] > .3 and row['c'] != 1) \
or (row['a'] * row['b'] > 0.1 and row['c'] != 2 and row['d'] != 3) else 0
return df.apply(target_filter, axis=1).as_matrix()
train_df = gen_data(train_size, 42)
test_df = gen_data(test_size, 43)
train_target = get_target(train_df)
test_target = get_target(test_df)
print
print 'train target', train_target
print 'test target', test_target
params = {
'iterations': 6,
'depth': 4,
'learning_rate': 1,
'loss_function': 'Logloss'
}
model = train_catboost_model(train_df, train_target, ['c', 'd'], params)
pred_python = model.predict(test_df)
server = ClickHouseServerWithCatboostModels(name, CLICKHOUSE_TESTS_SERVER_BIN_PATH, PORT)
server.add_model(name, model, 2, 2)
with server:
pred_ch = (np.array(server.apply_model(name, test_df, [])) > 0).astype(int)
print 'python predictions', pred_python
print 'clickhouse predictions', pred_ch
check_predictions(name, test_target, pred_python, pred_ch, 0.9)
def test_apply_float_features_with_mixed_cat_features():
name = 'test_apply_float_features_with_mixed_cat_features'
train_size = 10000
test_size = 10000
def gen_data(size, seed):
data = {
'a': generate_uniform_float_column(size, 0., 1., seed + 1),
'b': generate_uniform_float_column(size, 0., 1., seed + 2),
'c': generate_uniform_string_column(size, ['a', 'b', 'c'], seed + 3),
'd': generate_uniform_int_column(size, 1, 4, seed + 4)
}
return DataFrame.from_dict(data)
def get_target(df):
def target_filter(row):
return 1 if (row['a'] > .3 and row['b'] > .3 and row['c'] != 'a') \
or (row['a'] * row['b'] > 0.1 and row['c'] != 'b' and row['d'] != 2) else 0
return df.apply(target_filter, axis=1).as_matrix()
train_df = gen_data(train_size, 42)
test_df = gen_data(test_size, 43)
train_target = get_target(train_df)
test_target = get_target(test_df)
print
print 'train target', train_target
print 'test target', test_target
params = {
'iterations': 6,
'depth': 4,
'learning_rate': 1,
'loss_function': 'Logloss'
}
model = train_catboost_model(train_df, train_target, ['c', 'd'], params)
pred_python = model.predict(test_df)
server = ClickHouseServerWithCatboostModels(name, CLICKHOUSE_TESTS_SERVER_BIN_PATH, PORT)
server.add_model(name, model, 2, 2)
with server:
pred_ch = (np.array(server.apply_model(name, test_df, [])) > 0).astype(int)
print 'python predictions', pred_python
print 'clickhouse predictions', pred_ch
check_predictions(name, test_target, pred_python, pred_ch, 0.9)