mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-22 15:42:02 +00:00
commit
56ef2e9196
@ -384,6 +384,8 @@ namespace ErrorCodes
|
|||||||
extern const int UNKNOWN_STATUS_OF_DISTRIBUTED_DDL_TASK = 379;
|
extern const int UNKNOWN_STATUS_OF_DISTRIBUTED_DDL_TASK = 379;
|
||||||
extern const int CANNOT_KILL = 380;
|
extern const int CANNOT_KILL = 380;
|
||||||
extern const int HTTP_LENGTH_REQUIRED = 381;
|
extern const int HTTP_LENGTH_REQUIRED = 381;
|
||||||
|
extern const int CANNOT_LOAD_CATBOOST_MODEL = 382;
|
||||||
|
extern const int CANNOT_APPLY_CATBOOST_MODEL = 383;
|
||||||
|
|
||||||
extern const int KEEPER_EXCEPTION = 999;
|
extern const int KEEPER_EXCEPTION = 999;
|
||||||
extern const int POCO_EXCEPTION = 1000;
|
extern const int POCO_EXCEPTION = 1000;
|
||||||
|
@ -26,18 +26,19 @@ void DatabaseDictionary::loadTables(Context & context, ThreadPool * thread_pool,
|
|||||||
|
|
||||||
Tables DatabaseDictionary::loadTables()
|
Tables DatabaseDictionary::loadTables()
|
||||||
{
|
{
|
||||||
const std::lock_guard<std::mutex> lock_dictionaries {external_dictionaries.dictionaries_mutex};
|
auto objects_map = external_dictionaries.getObjectsMap();
|
||||||
|
const auto & dictionaries = objects_map.get();
|
||||||
|
|
||||||
Tables tables;
|
Tables tables;
|
||||||
for (const auto & pair : external_dictionaries.dictionaries)
|
for (const auto & pair : dictionaries)
|
||||||
{
|
{
|
||||||
const std::string & name = pair.first;
|
const std::string & name = pair.first;
|
||||||
if (deleted_tables.count(name))
|
if (deleted_tables.count(name))
|
||||||
continue;
|
continue;
|
||||||
auto dict_ptr = pair.second.dict;
|
auto dict_ptr = std::static_pointer_cast<IDictionaryBase>(pair.second.loadable);
|
||||||
if (dict_ptr)
|
if (dict_ptr)
|
||||||
{
|
{
|
||||||
const DictionaryStructure & dictionary_structure = dict_ptr->get()->getStructure();
|
const DictionaryStructure & dictionary_structure = dict_ptr->getStructure();
|
||||||
auto columns = StorageDictionary::getNamesAndTypes(dictionary_structure);
|
auto columns = StorageDictionary::getNamesAndTypes(dictionary_structure);
|
||||||
tables[name] = StorageDictionary::create(name, columns, {}, {}, {}, dictionary_structure, name);
|
tables[name] = StorageDictionary::create(name, columns, {}, {}, {}, dictionary_structure, name);
|
||||||
}
|
}
|
||||||
@ -50,26 +51,28 @@ bool DatabaseDictionary::isTableExist(
|
|||||||
const Context & context,
|
const Context & context,
|
||||||
const String & table_name) const
|
const String & table_name) const
|
||||||
{
|
{
|
||||||
const std::lock_guard<std::mutex> lock_dictionaries {external_dictionaries.dictionaries_mutex};
|
auto objects_map = external_dictionaries.getObjectsMap();
|
||||||
return external_dictionaries.dictionaries.count(table_name) && !deleted_tables.count(table_name);
|
const auto & dictionaries = objects_map.get();
|
||||||
|
return dictionaries.count(table_name) && !deleted_tables.count(table_name);
|
||||||
}
|
}
|
||||||
|
|
||||||
StoragePtr DatabaseDictionary::tryGetTable(
|
StoragePtr DatabaseDictionary::tryGetTable(
|
||||||
const Context & context,
|
const Context & context,
|
||||||
const String & table_name)
|
const String & table_name)
|
||||||
{
|
{
|
||||||
const std::lock_guard<std::mutex> lock_dictionaries {external_dictionaries.dictionaries_mutex};
|
auto objects_map = external_dictionaries.getObjectsMap();
|
||||||
|
const auto & dictionaries = objects_map.get();
|
||||||
|
|
||||||
if (deleted_tables.count(table_name))
|
if (deleted_tables.count(table_name))
|
||||||
return {};
|
return {};
|
||||||
{
|
{
|
||||||
auto it = external_dictionaries.dictionaries.find(table_name);
|
auto it = dictionaries.find(table_name);
|
||||||
if (it != external_dictionaries.dictionaries.end())
|
if (it != dictionaries.end())
|
||||||
{
|
{
|
||||||
const auto & dict_ptr = it->second.dict;
|
const auto & dict_ptr = std::static_pointer_cast<IDictionaryBase>(it->second.loadable);
|
||||||
if (dict_ptr)
|
if (dict_ptr)
|
||||||
{
|
{
|
||||||
const DictionaryStructure & dictionary_structure = dict_ptr->get()->getStructure();
|
const DictionaryStructure & dictionary_structure = dict_ptr->getStructure();
|
||||||
auto columns = StorageDictionary::getNamesAndTypes(dictionary_structure);
|
auto columns = StorageDictionary::getNamesAndTypes(dictionary_structure);
|
||||||
return StorageDictionary::create(table_name, columns, {}, {}, {}, dictionary_structure, table_name);
|
return StorageDictionary::create(table_name, columns, {}, {}, {}, dictionary_structure, table_name);
|
||||||
}
|
}
|
||||||
@ -86,9 +89,10 @@ DatabaseIteratorPtr DatabaseDictionary::getIterator(const Context & context)
|
|||||||
|
|
||||||
bool DatabaseDictionary::empty(const Context & context) const
|
bool DatabaseDictionary::empty(const Context & context) const
|
||||||
{
|
{
|
||||||
const std::lock_guard<std::mutex> lock_dictionaries {external_dictionaries.dictionaries_mutex};
|
auto objects_map = external_dictionaries.getObjectsMap();
|
||||||
for (const auto & pair : external_dictionaries.dictionaries)
|
const auto & dictionaries = objects_map.get();
|
||||||
if (pair.second.dict && !deleted_tables.count(pair.first))
|
for (const auto & pair : dictionaries)
|
||||||
|
if (pair.second.loadable && !deleted_tables.count(pair.first))
|
||||||
return false;
|
return false;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@ -119,7 +123,7 @@ void DatabaseDictionary::removeTable(
|
|||||||
if (!isTableExist(context, table_name))
|
if (!isTableExist(context, table_name))
|
||||||
throw Exception("Table " + name + "." + table_name + " doesn't exist.", ErrorCodes::UNKNOWN_TABLE);
|
throw Exception("Table " + name + "." + table_name + " doesn't exist.", ErrorCodes::UNKNOWN_TABLE);
|
||||||
|
|
||||||
const std::lock_guard<std::mutex> lock_dictionaries {external_dictionaries.dictionaries_mutex};
|
auto objects_map = external_dictionaries.getObjectsMap();
|
||||||
deleted_tables.insert(table_name);
|
deleted_tables.insert(table_name);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -156,7 +160,6 @@ ASTPtr DatabaseDictionary::getCreateQuery(
|
|||||||
const String & table_name) const
|
const String & table_name) const
|
||||||
{
|
{
|
||||||
throw Exception("DatabaseDictionary: getCreateQuery() is not supported", ErrorCodes::NOT_IMPLEMENTED);
|
throw Exception("DatabaseDictionary: getCreateQuery() is not supported", ErrorCodes::NOT_IMPLEMENTED);
|
||||||
return nullptr;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void DatabaseDictionary::shutdown()
|
void DatabaseDictionary::shutdown()
|
||||||
|
@ -54,7 +54,7 @@ public:
|
|||||||
|
|
||||||
bool isCached() const override { return true; }
|
bool isCached() const override { return true; }
|
||||||
|
|
||||||
DictionaryPtr clone() const override { return std::make_unique<CacheDictionary>(*this); }
|
std::unique_ptr<IExternalLoadable> clone() const override { return std::make_unique<CacheDictionary>(*this); }
|
||||||
|
|
||||||
const IDictionarySource * getSource() const override { return source_ptr.get(); }
|
const IDictionarySource * getSource() const override { return source_ptr.get(); }
|
||||||
|
|
||||||
|
523
dbms/src/Dictionaries/CatBoostModel.cpp
Normal file
523
dbms/src/Dictionaries/CatBoostModel.cpp
Normal file
@ -0,0 +1,523 @@
|
|||||||
|
#include <Dictionaries/CatBoostModel.h>
|
||||||
|
#include <Core/FieldVisitors.h>
|
||||||
|
#include <mutex>
|
||||||
|
#include <Columns/ColumnString.h>
|
||||||
|
#include <Columns/ColumnFixedString.h>
|
||||||
|
#include <Columns/ColumnVector.h>
|
||||||
|
#include <Common/typeid_cast.h>
|
||||||
|
#include <IO/WriteBufferFromString.h>
|
||||||
|
#include <IO/Operators.h>
|
||||||
|
#include <Common/PODArray.h>
|
||||||
|
#include <Common/SharedLibrary.h>
|
||||||
|
|
||||||
|
namespace DB
|
||||||
|
{
|
||||||
|
|
||||||
|
namespace ErrorCodes
|
||||||
|
{
|
||||||
|
extern const int LOGICAL_ERROR;
|
||||||
|
extern const int BAD_ARGUMENTS;
|
||||||
|
extern const int CANNOT_LOAD_CATBOOST_MODEL;
|
||||||
|
extern const int CANNOT_APPLY_CATBOOST_MODEL;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/// Table of function pointers into the CatBoost wrapper library.
/// Member names match the symbol names they are resolved from (see CatBoostLibHolder::initAPI).
/// GetFloatFeaturesCount and GetCatFeaturesCount are resolved with tryLoad there and are
/// tested for null before being called.
struct CatBoostWrapperAPI
{
    /// Opaque handle type; only ever used through a pointer.
    using ModelCalcerHandle = void;

    /// Handle lifetime.
    ModelCalcerHandle * (*ModelCalcerCreate)();
    void (*ModelCalcerDelete)(ModelCalcerHandle * calcer);

    /// Text of the last error; appended to exception messages when a call below returns false.
    const char * (*GetErrorString)();

    bool (*LoadFullModelFromFile)(ModelCalcerHandle * calcer, const char * filename);

    /// Evaluation with float features only.
    bool (*CalcModelPredictionFlat)(
        ModelCalcerHandle * calcer, size_t docCount,
        const float ** floatFeatures, size_t floatFeaturesSize,
        double * result, size_t resultSize);

    /// Evaluation with float features and string categorical features.
    bool (*CalcModelPrediction)(
        ModelCalcerHandle * calcer, size_t docCount,
        const float ** floatFeatures, size_t floatFeaturesSize,
        const char *** catFeatures, size_t catFeaturesSize,
        double * result, size_t resultSize);

    /// Evaluation with float features and already hashed categorical features.
    bool (*CalcModelPredictionWithHashedCatFeatures)(
        ModelCalcerHandle * calcer, size_t docCount,
        const float ** floatFeatures, size_t floatFeaturesSize,
        const int ** catFeatures, size_t catFeaturesSize,
        double * result, size_t resultSize);

    /// Hash functions for categorical feature values.
    int (*GetStringCatFeatureHash)(const char * data, size_t size);
    int (*GetIntegerCatFeatureHash)(long long val);

    /// Feature counts of a loaded model.
    size_t (*GetFloatFeaturesCount)(ModelCalcerHandle * calcer);
    size_t (*GetCatFeaturesCount)(ModelCalcerHandle * calcer);
};
|
||||||
|
|
||||||
|
|
||||||
|
namespace
|
||||||
|
{
|
||||||
|
|
||||||
|
class CatBoostModelHolder
|
||||||
|
{
|
||||||
|
private:
|
||||||
|
CatBoostWrapperAPI::ModelCalcerHandle * handle;
|
||||||
|
const CatBoostWrapperAPI * api;
|
||||||
|
public:
|
||||||
|
explicit CatBoostModelHolder(const CatBoostWrapperAPI * api) : api(api) { handle = api->ModelCalcerCreate(); }
|
||||||
|
~CatBoostModelHolder() { api->ModelCalcerDelete(handle); }
|
||||||
|
|
||||||
|
CatBoostWrapperAPI::ModelCalcerHandle * get() { return handle; }
|
||||||
|
explicit operator CatBoostWrapperAPI::ModelCalcerHandle * () { return handle; }
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
class CatBoostModelImpl : public ICatBoostModel
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
CatBoostModelImpl(const CatBoostWrapperAPI * api, const std::string & model_path) : api(api)
|
||||||
|
{
|
||||||
|
auto handle_ = std::make_unique<CatBoostModelHolder>(api);
|
||||||
|
if (!handle_)
|
||||||
|
{
|
||||||
|
std::string msg = "Cannot create CatBoost model: ";
|
||||||
|
throw Exception(msg + api->GetErrorString(), ErrorCodes::CANNOT_LOAD_CATBOOST_MODEL);
|
||||||
|
}
|
||||||
|
if (!api->LoadFullModelFromFile(handle_->get(), model_path.c_str()))
|
||||||
|
{
|
||||||
|
std::string msg = "Cannot load CatBoost model: ";
|
||||||
|
throw Exception(msg + api->GetErrorString(), ErrorCodes::CANNOT_LOAD_CATBOOST_MODEL);
|
||||||
|
}
|
||||||
|
handle = std::move(handle_);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validates the arguments and evaluates the model.
/// The first `float_features_count` columns are float features and must be numeric;
/// the next `cat_features_count` columns are categorical features and must be numeric,
/// String or FixedString.
/// Throws BAD_ARGUMENTS for an empty column list, a column count mismatch or an unsupported column.
ColumnPtr evaluate(const ConstColumnPlainPtrs & columns,
                   size_t float_features_count, size_t cat_features_count) const override
{
    checkFeaturesCount(float_features_count, cat_features_count);

    if (columns.empty())
        throw Exception("Got empty columns list for CatBoost model.", ErrorCodes::BAD_ARGUMENTS);

    if (columns.size() != float_features_count + cat_features_count)
    {
        std::string msg;
        {
            /// The buffer is scoped so that it is destroyed before msg is read.
            WriteBufferFromString buffer(msg);
            buffer << "Number of columns is different with number of features: ";
            buffer << columns.size() << " vs " << float_features_count << " + " << cat_features_count;
        }
        throw Exception(msg, ErrorCodes::BAD_ARGUMENTS);
    }

    for (size_t i = 0; i < float_features_count; ++i)
    {
        if (!columns[i]->isNumeric())
        {
            std::string msg;
            {
                WriteBufferFromString buffer(msg);
                buffer << "Column " << i << " should be numeric to make float feature.";
            }
            throw Exception(msg, ErrorCodes::BAD_ARGUMENTS);
        }
    }

    /// A single numeric categorical column switches evaluation to the hashed-features code path.
    bool cat_features_are_strings = true;
    for (size_t i = float_features_count; i < float_features_count + cat_features_count; ++i)
    {
        auto column = columns[i];
        if (column->isNumeric())
            cat_features_are_strings = false;
        else if (!(typeid_cast<const ColumnString *>(column)
                   || typeid_cast<const ColumnFixedString *>(column)))
        {
            std::string msg;
            {
                WriteBufferFromString buffer(msg);
                buffer << "Column " << i << " should be numeric or string.";
            }
            throw Exception(msg, ErrorCodes::BAD_ARGUMENTS);
        }
    }

    return evalImpl(columns, float_features_count, cat_features_count, cat_features_are_strings);
}
|
||||||
|
|
||||||
|
void checkFeaturesCount(size_t float_features_count, size_t cat_features_count) const
|
||||||
|
{
|
||||||
|
if (api->GetFloatFeaturesCount)
|
||||||
|
{
|
||||||
|
size_t float_features_in_model = api->GetFloatFeaturesCount(handle->get());
|
||||||
|
if (float_features_count != float_features_in_model)
|
||||||
|
throw Exception("CatBoost model expected " + std::to_string(float_features_in_model) + " float features"
|
||||||
|
+ ", but " + std::to_string(float_features_count) + " was provided.");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (api->GetCatFeaturesCount)
|
||||||
|
{
|
||||||
|
size_t cat_features_in_model = api->GetCatFeaturesCount(handle->get());
|
||||||
|
if (cat_features_count != cat_features_in_model)
|
||||||
|
throw Exception("CatBoost model expected " + std::to_string(cat_features_in_model) + " cat features"
|
||||||
|
+ ", but " + std::to_string(cat_features_count) + " was provided.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::unique_ptr<CatBoostModelHolder> handle;
|
||||||
|
const CatBoostWrapperAPI * api;
|
||||||
|
|
||||||
|
/// Buffer should be allocated with features_count * column->size() elements.
|
||||||
|
/// Place column elements in positions buffer[0], buffer[features_count], ... , buffer[size * features_count]
|
||||||
|
template <typename T>
|
||||||
|
void placeColumnAsNumber(const IColumn * column, T * buffer, size_t features_count) const
|
||||||
|
{
|
||||||
|
size_t size = column->size();
|
||||||
|
FieldVisitorConvertToNumber<T> visitor;
|
||||||
|
for (size_t i = 0; i < size; ++i)
|
||||||
|
{
|
||||||
|
/// TODO: Replace with column visitor.
|
||||||
|
Field field;
|
||||||
|
column->get(i, field);
|
||||||
|
*buffer = applyVisitor(visitor, field);
|
||||||
|
buffer += features_count;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Buffer should be allocated with features_count * column->size() elements.
|
||||||
|
/// Place string pointers in positions buffer[0], buffer[features_count], ... , buffer[size * features_count]
|
||||||
|
void placeStringColumn(const ColumnString & column, const char ** buffer, size_t features_count) const
|
||||||
|
{
|
||||||
|
size_t size = column.size();
|
||||||
|
for (size_t i = 0; i < size; ++i)
|
||||||
|
{
|
||||||
|
*buffer = const_cast<char *>(column.getDataAtWithTerminatingZero(i).data);
|
||||||
|
buffer += features_count;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Buffer should be allocated with features_count * column->size() elements.
|
||||||
|
/// Place string pointers in positions buffer[0], buffer[features_count], ... , buffer[size * features_count]
|
||||||
|
/// Returns PODArray which holds data (because ColumnFixedString doesn't store terminating zero).
|
||||||
|
PODArray<char> placeFixedStringColumn(
|
||||||
|
const ColumnFixedString & column, const char ** buffer, size_t features_count) const
|
||||||
|
{
|
||||||
|
size_t size = column.size();
|
||||||
|
size_t str_size = column.getN();
|
||||||
|
PODArray<char> data(size * (str_size + 1));
|
||||||
|
char * data_ptr = data.data();
|
||||||
|
|
||||||
|
for (size_t i = 0; i < size; ++i)
|
||||||
|
{
|
||||||
|
auto ref = column.getDataAt(i);
|
||||||
|
memcpy(data_ptr, ref.data, ref.size);
|
||||||
|
data_ptr[ref.size] = 0;
|
||||||
|
*buffer = data_ptr;
|
||||||
|
data_ptr += ref.size + 1;
|
||||||
|
buffer += features_count;
|
||||||
|
}
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Place columns into buffer, returns column which holds placed data. Buffer should contains column->size() values.
|
||||||
|
template <typename T>
|
||||||
|
ColumnPtr placeNumericColumns(const ConstColumnPlainPtrs & columns,
|
||||||
|
size_t offset, size_t size, const T** buffer) const
|
||||||
|
{
|
||||||
|
if (size == 0)
|
||||||
|
return nullptr;
|
||||||
|
size_t column_size = columns[offset]->size();
|
||||||
|
auto data_column = std::make_shared<ColumnVector<T>>(size * column_size);
|
||||||
|
T* data = data_column->getData().data();
|
||||||
|
for (size_t i = 0; i < size; ++i)
|
||||||
|
{
|
||||||
|
auto column = columns[offset + i];
|
||||||
|
if (column->isNumeric())
|
||||||
|
placeColumnAsNumber(column, data + i, size);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 0; i < column_size; ++i)
|
||||||
|
{
|
||||||
|
*buffer = data;
|
||||||
|
++buffer;
|
||||||
|
data += size;
|
||||||
|
}
|
||||||
|
|
||||||
|
return data_column;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Place columns into buffer, returns data which was used for fixed string columns.
|
||||||
|
/// Buffer should contains column->size() values, each value contains size strings.
|
||||||
|
std::vector<PODArray<char>> placeStringColumns(
|
||||||
|
const ConstColumnPlainPtrs & columns, size_t offset, size_t size, const char ** buffer) const
|
||||||
|
{
|
||||||
|
if (size == 0)
|
||||||
|
return {};
|
||||||
|
|
||||||
|
std::vector<PODArray<char>> data;
|
||||||
|
for (size_t i = 0; i < size; ++i)
|
||||||
|
{
|
||||||
|
auto column = columns[offset + i];
|
||||||
|
if (auto column_string = typeid_cast<const ColumnString *>(column))
|
||||||
|
placeStringColumn(*column_string, buffer + i, size);
|
||||||
|
else if (auto column_fixed_string = typeid_cast<const ColumnFixedString *>(column))
|
||||||
|
data.push_back(placeFixedStringColumn(*column_fixed_string, buffer + i, size));
|
||||||
|
else
|
||||||
|
throw Exception("Cannot place string column.", ErrorCodes::LOGICAL_ERROR);
|
||||||
|
}
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Calc hash for string cat feature at ps positions.
|
||||||
|
template <typename Column>
|
||||||
|
void calcStringHashes(const Column * column, size_t ps, const int ** buffer) const
|
||||||
|
{
|
||||||
|
size_t column_size = column->size();
|
||||||
|
for (size_t j = 0; j < column_size; ++j)
|
||||||
|
{
|
||||||
|
auto ref = column->getDataAt(j);
|
||||||
|
const_cast<int *>(*buffer)[ps] = api->GetStringCatFeatureHash(ref.data, ref.size);
|
||||||
|
++buffer;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Calc hash for int cat feature at ps position. Buffer at positions ps should contains unhashed values.
|
||||||
|
void calcIntHashes(size_t column_size, size_t ps, const int ** buffer) const
|
||||||
|
{
|
||||||
|
for (size_t j = 0; j < column_size; ++j)
|
||||||
|
{
|
||||||
|
const_cast<int *>(*buffer)[ps] = api->GetIntegerCatFeatureHash((*buffer)[ps]);
|
||||||
|
++buffer;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// buffer contains column->size() rows and size columns.
|
||||||
|
/// For int cat features calc hash inplace.
|
||||||
|
/// For string cat features calc hash from column rows.
|
||||||
|
void calcHashes(const ConstColumnPlainPtrs & columns, size_t offset, size_t size, const int ** buffer) const
|
||||||
|
{
|
||||||
|
if (size == 0)
|
||||||
|
return;
|
||||||
|
size_t column_size = columns[offset]->size();
|
||||||
|
|
||||||
|
std::vector<PODArray<char>> data;
|
||||||
|
for (size_t i = 0; i < size; ++i)
|
||||||
|
{
|
||||||
|
auto column = columns[offset + i];
|
||||||
|
if (auto column_string = typeid_cast<const ColumnString *>(column))
|
||||||
|
calcStringHashes(column_string, i, buffer);
|
||||||
|
else if (auto column_fixed_string = typeid_cast<const ColumnFixedString *>(column))
|
||||||
|
calcStringHashes(column_fixed_string, i, buffer);
|
||||||
|
else
|
||||||
|
calcIntHashes(column_size, i, buffer);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// buffer[column_size * cat_features_count] -> char * => cat_features[column_size][cat_features_count] -> char *
|
||||||
|
void fillCatFeaturesBuffer(const char *** cat_features, const char ** buffer,
|
||||||
|
size_t column_size, size_t cat_features_count) const
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < column_size; ++i)
|
||||||
|
{
|
||||||
|
*cat_features = buffer;
|
||||||
|
++cat_features;
|
||||||
|
buffer += cat_features_count;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert values to row-oriented format and call evaluation function from CatBoost wrapper api.
|
||||||
|
/// * CalcModelPredictionFlat if no cat features
|
||||||
|
/// * CalcModelPrediction if all cat features are strings
|
||||||
|
/// * CalcModelPredictionWithHashedCatFeatures if has int cat features.
|
||||||
|
ColumnPtr evalImpl(const ConstColumnPlainPtrs & columns, size_t float_features_count, size_t cat_features_count,
|
||||||
|
bool cat_features_are_strings) const
|
||||||
|
{
|
||||||
|
std::string error_msg = "Error occurred while applying CatBoost model: ";
|
||||||
|
size_t column_size = columns.front()->size();
|
||||||
|
|
||||||
|
auto result= std::make_shared<ColumnFloat64>(column_size);
|
||||||
|
auto result_buf = result->getData().data();
|
||||||
|
|
||||||
|
/// Prepare float features.
|
||||||
|
PODArray<const float *> float_features(column_size);
|
||||||
|
auto float_features_buf = float_features.data();
|
||||||
|
/// Store all float data into single column. float_features is a list of pointers to it.
|
||||||
|
auto float_features_col = placeNumericColumns<float>(columns, 0, float_features_count, float_features_buf);
|
||||||
|
|
||||||
|
if (cat_features_count == 0)
|
||||||
|
{
|
||||||
|
if (!api->CalcModelPredictionFlat(handle->get(), column_size,
|
||||||
|
float_features_buf, float_features_count,
|
||||||
|
result_buf, column_size))
|
||||||
|
{
|
||||||
|
|
||||||
|
throw Exception(error_msg + api->GetErrorString(), ErrorCodes::CANNOT_APPLY_CATBOOST_MODEL);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Prepare cat features.
|
||||||
|
if (cat_features_are_strings)
|
||||||
|
{
|
||||||
|
/// cat_features_holder stores pointers to ColumnString data or fixed_strings_data.
|
||||||
|
PODArray<const char *> cat_features_holder(cat_features_count * column_size);
|
||||||
|
PODArray<const char **> cat_features(column_size);
|
||||||
|
auto cat_features_buf = cat_features.data();
|
||||||
|
|
||||||
|
fillCatFeaturesBuffer(cat_features_buf, cat_features_holder.data(), column_size, cat_features_count);
|
||||||
|
/// Fixed strings are stored without termination zero, so have to copy data into fixed_strings_data.
|
||||||
|
auto fixed_strings_data = placeStringColumns(columns, float_features_count,
|
||||||
|
cat_features_count, cat_features_holder.data());
|
||||||
|
|
||||||
|
if (!api->CalcModelPrediction(handle->get(), column_size,
|
||||||
|
float_features_buf, float_features_count,
|
||||||
|
cat_features_buf, cat_features_count,
|
||||||
|
result_buf, column_size))
|
||||||
|
{
|
||||||
|
throw Exception(error_msg + api->GetErrorString(), ErrorCodes::CANNOT_APPLY_CATBOOST_MODEL);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
PODArray<const int *> cat_features(column_size);
|
||||||
|
auto cat_features_buf = cat_features.data();
|
||||||
|
auto cat_features_col = placeNumericColumns<int>(columns, float_features_count,
|
||||||
|
cat_features_count, cat_features_buf);
|
||||||
|
calcHashes(columns, float_features_count, cat_features_count, cat_features_buf);
|
||||||
|
if (!api->CalcModelPredictionWithHashedCatFeatures(
|
||||||
|
handle->get(), column_size,
|
||||||
|
float_features_buf, float_features_count,
|
||||||
|
cat_features_buf, cat_features_count,
|
||||||
|
result_buf, column_size))
|
||||||
|
{
|
||||||
|
throw Exception(error_msg + api->GetErrorString(), ErrorCodes::CANNOT_APPLY_CATBOOST_MODEL);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
/// Holds CatBoost wrapper library and provides wrapper interface.
|
||||||
|
class CatBoostLibHolder: public CatBoostWrapperAPIProvider
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
explicit CatBoostLibHolder(const std::string & lib_path) : lib_path(lib_path), lib(lib_path) { initAPI(); }
|
||||||
|
|
||||||
|
const CatBoostWrapperAPI & getAPI() const override { return api; }
|
||||||
|
const std::string & getCurrentPath() const { return lib_path; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
CatBoostWrapperAPI api;
|
||||||
|
std::string lib_path;
|
||||||
|
SharedLibrary lib;
|
||||||
|
|
||||||
|
void initAPI();
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void load(T& func, const std::string & name) { func = lib.get<T>(name); }
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void tryLoad(T& func, const std::string & name) { func = lib.tryGet<T>(name); }
|
||||||
|
};
|
||||||
|
|
||||||
|
void CatBoostLibHolder::initAPI()
|
||||||
|
{
|
||||||
|
load(api.ModelCalcerCreate, "ModelCalcerCreate");
|
||||||
|
load(api.ModelCalcerDelete, "ModelCalcerDelete");
|
||||||
|
load(api.GetErrorString, "GetErrorString");
|
||||||
|
load(api.LoadFullModelFromFile, "LoadFullModelFromFile");
|
||||||
|
load(api.CalcModelPredictionFlat, "CalcModelPredictionFlat");
|
||||||
|
load(api.CalcModelPrediction, "CalcModelPrediction");
|
||||||
|
load(api.CalcModelPredictionWithHashedCatFeatures, "CalcModelPredictionWithHashedCatFeatures");
|
||||||
|
load(api.GetStringCatFeatureHash, "GetStringCatFeatureHash");
|
||||||
|
load(api.GetIntegerCatFeatureHash, "GetIntegerCatFeatureHash");
|
||||||
|
|
||||||
|
tryLoad(api.GetFloatFeaturesCount, "GetFloatFeaturesCount");
|
||||||
|
tryLoad(api.GetCatFeaturesCount, "GetCatFeaturesCount");
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<CatBoostLibHolder> getCatBoostWrapperHolder(const std::string & lib_path)
|
||||||
|
{
|
||||||
|
static std::weak_ptr<CatBoostLibHolder> ptr;
|
||||||
|
static std::mutex mutex;
|
||||||
|
|
||||||
|
std::lock_guard<std::mutex> lock(mutex);
|
||||||
|
auto result = ptr.lock();
|
||||||
|
|
||||||
|
if (!result || result->getCurrentPath() != lib_path)
|
||||||
|
{
|
||||||
|
result = std::make_shared<CatBoostLibHolder>(lib_path);
|
||||||
|
/// This assignment is not atomic, which prevents from creating lock only inside 'if'.
|
||||||
|
ptr = result;
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/// Stores the settings and tries to load the library and the model right away.
/// The constructor itself does not propagate a loading failure: the exception is captured
/// into creation_exception instead.
CatBoostModel::CatBoostModel(const std::string & name, const std::string & model_path, const std::string & lib_path,
                             const ExternalLoadableLifetime & lifetime,
                             size_t float_features_count, size_t cat_features_count)
    : name(name)
    , model_path(model_path)
    , lib_path(lib_path)
    , lifetime(lifetime)
    , float_features_count(float_features_count)
    , cat_features_count(cat_features_count)
{
    try
    {
        init(lib_path);
    }
    catch (...)
    {
        /// Remember the failure instead of throwing from the constructor.
        creation_exception = std::current_exception();
    }
}
|
||||||
|
|
||||||
|
void CatBoostModel::init(const std::string & lib_path)
|
||||||
|
{
|
||||||
|
api_provider = getCatBoostWrapperHolder(lib_path);
|
||||||
|
api = &api_provider->getAPI();
|
||||||
|
model = std::make_unique<CatBoostModelImpl>(api, model_path);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Lifetime settings passed to the constructor.
const ExternalLoadableLifetime & CatBoostModel::getLifetime() const
{
    const ExternalLoadableLifetime & stored_lifetime = lifetime;
    return stored_lifetime;
}
|
||||||
|
|
||||||
|
bool CatBoostModel::isModified() const
|
||||||
|
{
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<IExternalLoadable> CatBoostModel::clone() const
|
||||||
|
{
|
||||||
|
return std::make_unique<CatBoostModel>(name, model_path, lib_path, lifetime, float_features_count, cat_features_count);
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t CatBoostModel::getFloatFeaturesCount() const
|
||||||
|
{
|
||||||
|
return float_features_count;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t CatBoostModel::getCatFeaturesCount() const
|
||||||
|
{
|
||||||
|
return cat_features_count;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Evaluates the loaded model on `columns` using the configured feature counts.
/// Throws LOGICAL_ERROR if no model is loaded (e.g. loading failed in the constructor).
ColumnPtr CatBoostModel::evaluate(const ConstColumnPlainPtrs & columns) const
{
    if (model)
        return model->evaluate(columns, float_features_count, cat_features_count);

    throw Exception("CatBoost model was not loaded.", ErrorCodes::LOGICAL_ERROR);
}
|
||||||
|
|
||||||
|
}
|
79
dbms/src/Dictionaries/CatBoostModel.h
Normal file
79
dbms/src/Dictionaries/CatBoostModel.h
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
#pragma once
|
||||||
|
#include <Interpreters/IExternalLoadable.h>
|
||||||
|
#include <Columns/IColumn.h>
|
||||||
|
#include <Columns/ColumnsNumber.h>
|
||||||
|
|
||||||
|
|
||||||
|
namespace DB
|
||||||
|
{
|
||||||
|
|
||||||
|
/// Function table of the CatBoost wrapper library; only declared here.
struct CatBoostWrapperAPI;

/// Source of a CatBoostWrapperAPI function table.
class CatBoostWrapperAPIProvider
{
public:
    virtual ~CatBoostWrapperAPIProvider() = default;

    /// The returned reference refers to a table owned by the implementation.
    virtual const CatBoostWrapperAPI & getAPI() const = 0;
};
|
||||||
|
|
||||||
|
/// CatBoost model interface.
|
||||||
|
class ICatBoostModel
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
virtual ~ICatBoostModel() = default;
|
||||||
|
/// Evaluate model. Use first `float_features_count` columns as float features,
|
||||||
|
/// the others `cat_features_count` as categorical features.
|
||||||
|
virtual ColumnPtr evaluate(const ConstColumnPlainPtrs & columns, size_t float_features_count, size_t cat_features_count) const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// General ML model evaluator interface.
|
||||||
|
class IModel : public IExternalLoadable
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
virtual ColumnPtr evaluate(const ConstColumnPlainPtrs & columns) const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
class CatBoostModel : public IModel
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
CatBoostModel(const std::string & name, const std::string & model_path,
|
||||||
|
const std::string & lib_path, const ExternalLoadableLifetime & lifetime,
|
||||||
|
size_t float_features_count, size_t cat_features_count);
|
||||||
|
|
||||||
|
ColumnPtr evaluate(const ConstColumnPlainPtrs & columns) const override;
|
||||||
|
|
||||||
|
size_t getFloatFeaturesCount() const;
|
||||||
|
size_t getCatFeaturesCount() const;
|
||||||
|
|
||||||
|
/// IExternalLoadable interface.
|
||||||
|
|
||||||
|
const ExternalLoadableLifetime & getLifetime() const override;
|
||||||
|
|
||||||
|
std::string getName() const override { return name; }
|
||||||
|
|
||||||
|
bool supportUpdates() const override { return true; }
|
||||||
|
|
||||||
|
bool isModified() const override;
|
||||||
|
|
||||||
|
std::unique_ptr<IExternalLoadable> clone() const override;
|
||||||
|
|
||||||
|
std::exception_ptr getCreationException() const override { return creation_exception; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::string name;
|
||||||
|
std::string model_path;
|
||||||
|
std::string lib_path;
|
||||||
|
ExternalLoadableLifetime lifetime;
|
||||||
|
std::exception_ptr creation_exception;
|
||||||
|
std::shared_ptr<CatBoostWrapperAPIProvider> api_provider;
|
||||||
|
const CatBoostWrapperAPI * api;
|
||||||
|
|
||||||
|
std::unique_ptr<ICatBoostModel> model;
|
||||||
|
|
||||||
|
size_t float_features_count;
|
||||||
|
size_t cat_features_count;
|
||||||
|
|
||||||
|
void init(const std::string & lib_path);
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
@ -98,7 +98,7 @@ public:
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
DictionaryPtr clone() const override
|
std::unique_ptr<IExternalLoadable> clone() const override
|
||||||
{
|
{
|
||||||
return std::make_unique<ComplexKeyCacheDictionary>(*this);
|
return std::make_unique<ComplexKeyCacheDictionary>(*this);
|
||||||
}
|
}
|
||||||
|
@ -46,7 +46,7 @@ public:
|
|||||||
|
|
||||||
bool isCached() const override { return false; }
|
bool isCached() const override { return false; }
|
||||||
|
|
||||||
DictionaryPtr clone() const override { return std::make_unique<ComplexKeyHashedDictionary>(*this); }
|
std::unique_ptr<IExternalLoadable> clone() const override { return std::make_unique<ComplexKeyHashedDictionary>(*this); }
|
||||||
|
|
||||||
const IDictionarySource * getSource() const override { return source_ptr.get(); }
|
const IDictionarySource * getSource() const override { return source_ptr.get(); }
|
||||||
|
|
||||||
|
@ -13,7 +13,7 @@ class Context;
|
|||||||
class DictionaryFactory : public ext::singleton<DictionaryFactory>
|
class DictionaryFactory : public ext::singleton<DictionaryFactory>
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
DictionaryPtr create(const std::string & name, Poco::Util::AbstractConfiguration & config,
|
DictionaryPtr create(const std::string & name, const Poco::Util::AbstractConfiguration & config,
|
||||||
const std::string & config_prefix, Context & context) const;
|
const std::string & config_prefix, Context & context) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -90,7 +90,7 @@ DictionarySourceFactory::DictionarySourceFactory()
|
|||||||
|
|
||||||
|
|
||||||
DictionarySourcePtr DictionarySourceFactory::create(
|
DictionarySourcePtr DictionarySourceFactory::create(
|
||||||
const std::string & name, Poco::Util::AbstractConfiguration & config, const std::string & config_prefix,
|
const std::string & name, const Poco::Util::AbstractConfiguration & config, const std::string & config_prefix,
|
||||||
const DictionaryStructure & dict_struct, Context & context) const
|
const DictionaryStructure & dict_struct, Context & context) const
|
||||||
{
|
{
|
||||||
Poco::Util::AbstractConfiguration::Keys keys;
|
Poco::Util::AbstractConfiguration::Keys keys;
|
||||||
|
@ -25,7 +25,7 @@ public:
|
|||||||
DictionarySourceFactory();
|
DictionarySourceFactory();
|
||||||
|
|
||||||
DictionarySourcePtr create(
|
DictionarySourcePtr create(
|
||||||
const std::string & name, Poco::Util::AbstractConfiguration & config, const std::string & config_prefix,
|
const std::string & name, const Poco::Util::AbstractConfiguration & config, const std::string & config_prefix,
|
||||||
const DictionaryStructure & dict_struct, Context & context) const;
|
const DictionaryStructure & dict_struct, Context & context) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -105,16 +105,6 @@ std::string toString(const AttributeUnderlyingType type)
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
DictionaryLifetime::DictionaryLifetime(const Poco::Util::AbstractConfiguration & config, const std::string & config_prefix)
|
|
||||||
{
|
|
||||||
const auto & lifetime_min_key = config_prefix + ".min";
|
|
||||||
const auto has_min = config.has(lifetime_min_key);
|
|
||||||
|
|
||||||
this->min_sec = has_min ? config.getInt(lifetime_min_key) : config.getInt(config_prefix);
|
|
||||||
this->max_sec = has_min ? config.getInt(config_prefix + ".max") : this->min_sec;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
DictionarySpecialAttribute::DictionarySpecialAttribute(const Poco::Util::AbstractConfiguration & config, const std::string & config_prefix)
|
DictionarySpecialAttribute::DictionarySpecialAttribute(const Poco::Util::AbstractConfiguration & config, const std::string & config_prefix)
|
||||||
: name{config.getString(config_prefix + ".name", "")},
|
: name{config.getString(config_prefix + ".name", "")},
|
||||||
expression{config.getString(config_prefix + ".expression", "")}
|
expression{config.getString(config_prefix + ".expression", "")}
|
||||||
|
@ -4,6 +4,7 @@
|
|||||||
#include <IO/ReadBufferFromString.h>
|
#include <IO/ReadBufferFromString.h>
|
||||||
#include <IO/WriteBuffer.h>
|
#include <IO/WriteBuffer.h>
|
||||||
#include <IO/WriteHelpers.h>
|
#include <IO/WriteHelpers.h>
|
||||||
|
#include <Interpreters/IExternalLoadable.h>
|
||||||
#include <Poco/Util/AbstractConfiguration.h>
|
#include <Poco/Util/AbstractConfiguration.h>
|
||||||
#include <ext/range.h>
|
#include <ext/range.h>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
@ -42,13 +43,7 @@ std::string toString(const AttributeUnderlyingType type);
|
|||||||
|
|
||||||
|
|
||||||
/// Min and max lifetimes for a dictionary or it's entry
|
/// Min and max lifetimes for a dictionary or it's entry
|
||||||
struct DictionaryLifetime final
|
using DictionaryLifetime = ExternalLoadableLifetime;
|
||||||
{
|
|
||||||
UInt64 min_sec;
|
|
||||||
UInt64 max_sec;
|
|
||||||
|
|
||||||
DictionaryLifetime(const Poco::Util::AbstractConfiguration & config, const std::string & config_prefix);
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
/** Holds the description of a single dictionary attribute:
|
/** Holds the description of a single dictionary attribute:
|
||||||
|
@ -41,7 +41,7 @@ public:
|
|||||||
|
|
||||||
bool isCached() const override { return false; }
|
bool isCached() const override { return false; }
|
||||||
|
|
||||||
DictionaryPtr clone() const override { return std::make_unique<FlatDictionary>(*this); }
|
std::unique_ptr<IExternalLoadable> clone() const override { return std::make_unique<FlatDictionary>(*this); }
|
||||||
|
|
||||||
const IDictionarySource * getSource() const override { return source_ptr.get(); }
|
const IDictionarySource * getSource() const override { return source_ptr.get(); }
|
||||||
|
|
||||||
|
@ -40,7 +40,7 @@ public:
|
|||||||
|
|
||||||
bool isCached() const override { return false; }
|
bool isCached() const override { return false; }
|
||||||
|
|
||||||
DictionaryPtr clone() const override { return std::make_unique<HashedDictionary>(*this); }
|
std::unique_ptr<IExternalLoadable> clone() const override { return std::make_unique<HashedDictionary>(*this); }
|
||||||
|
|
||||||
const IDictionarySource * getSource() const override { return source_ptr.get(); }
|
const IDictionarySource * getSource() const override { return source_ptr.get(); }
|
||||||
|
|
||||||
|
@ -1,22 +1,21 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <Core/Field.h>
|
#include <Core/Field.h>
|
||||||
|
#include <Interpreters/IExternalLoadable.h>
|
||||||
#include <common/StringRef.h>
|
#include <common/StringRef.h>
|
||||||
#include <Core/Names.h>
|
#include <Core/Names.h>
|
||||||
#include <Poco/Util/XMLConfiguration.h>
|
#include <Poco/Util/XMLConfiguration.h>
|
||||||
#include <Common/PODArray.h>
|
#include <Common/PODArray.h>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <chrono>
|
#include <chrono>
|
||||||
|
#include <Dictionaries/IDictionarySource.h>
|
||||||
|
|
||||||
namespace DB
|
namespace DB
|
||||||
{
|
{
|
||||||
|
|
||||||
class IDictionarySource;
|
|
||||||
|
|
||||||
struct IDictionaryBase;
|
struct IDictionaryBase;
|
||||||
using DictionaryPtr = std::unique_ptr<IDictionaryBase>;
|
using DictionaryPtr = std::unique_ptr<IDictionaryBase>;
|
||||||
|
|
||||||
struct DictionaryLifetime;
|
|
||||||
struct DictionaryStructure;
|
struct DictionaryStructure;
|
||||||
class ColumnString;
|
class ColumnString;
|
||||||
|
|
||||||
@ -24,14 +23,10 @@ class IBlockInputStream;
|
|||||||
using BlockInputStreamPtr = std::shared_ptr<IBlockInputStream>;
|
using BlockInputStreamPtr = std::shared_ptr<IBlockInputStream>;
|
||||||
|
|
||||||
|
|
||||||
struct IDictionaryBase : public std::enable_shared_from_this<IDictionaryBase>
|
struct IDictionaryBase : public IExternalLoadable
|
||||||
{
|
{
|
||||||
using Key = UInt64;
|
using Key = UInt64;
|
||||||
|
|
||||||
virtual std::exception_ptr getCreationException() const = 0;
|
|
||||||
|
|
||||||
virtual std::string getName() const = 0;
|
|
||||||
|
|
||||||
virtual std::string getTypeName() const = 0;
|
virtual std::string getTypeName() const = 0;
|
||||||
|
|
||||||
virtual size_t getBytesAllocated() const = 0;
|
virtual size_t getBytesAllocated() const = 0;
|
||||||
@ -45,12 +40,9 @@ struct IDictionaryBase : public std::enable_shared_from_this<IDictionaryBase>
|
|||||||
virtual double getLoadFactor() const = 0;
|
virtual double getLoadFactor() const = 0;
|
||||||
|
|
||||||
virtual bool isCached() const = 0;
|
virtual bool isCached() const = 0;
|
||||||
virtual DictionaryPtr clone() const = 0;
|
|
||||||
|
|
||||||
virtual const IDictionarySource * getSource() const = 0;
|
virtual const IDictionarySource * getSource() const = 0;
|
||||||
|
|
||||||
virtual const DictionaryLifetime & getLifetime() const = 0;
|
|
||||||
|
|
||||||
virtual const DictionaryStructure & getStructure() const = 0;
|
virtual const DictionaryStructure & getStructure() const = 0;
|
||||||
|
|
||||||
virtual std::chrono::time_point<std::chrono::system_clock> getCreationTime() const = 0;
|
virtual std::chrono::time_point<std::chrono::system_clock> getCreationTime() const = 0;
|
||||||
@ -59,7 +51,23 @@ struct IDictionaryBase : public std::enable_shared_from_this<IDictionaryBase>
|
|||||||
|
|
||||||
virtual BlockInputStreamPtr getBlockInputStream(const Names & column_names, size_t max_block_size) const = 0;
|
virtual BlockInputStreamPtr getBlockInputStream(const Names & column_names, size_t max_block_size) const = 0;
|
||||||
|
|
||||||
virtual ~IDictionaryBase() = default;
|
bool supportUpdates() const override { return !isCached(); }
|
||||||
|
|
||||||
|
bool isModified() const override
|
||||||
|
{
|
||||||
|
auto source = getSource();
|
||||||
|
return source && source->isModified();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<IDictionaryBase> shared_from_this()
|
||||||
|
{
|
||||||
|
return std::static_pointer_cast<IDictionaryBase>(IExternalLoadable::shared_from_this());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<const IDictionaryBase> shared_from_this() const
|
||||||
|
{
|
||||||
|
return std::static_pointer_cast<const IDictionaryBase>(IExternalLoadable::shared_from_this());
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
@ -41,7 +41,7 @@ public:
|
|||||||
|
|
||||||
bool isCached() const override { return false; }
|
bool isCached() const override { return false; }
|
||||||
|
|
||||||
DictionaryPtr clone() const override { return std::make_unique<RangeHashedDictionary>(*this); }
|
std::unique_ptr<IExternalLoadable> clone() const override { return std::make_unique<RangeHashedDictionary>(*this); }
|
||||||
|
|
||||||
const IDictionarySource * getSource() const override { return source_ptr.get(); }
|
const IDictionarySource * getSource() const override { return source_ptr.get(); }
|
||||||
|
|
||||||
|
@ -50,7 +50,7 @@ public:
|
|||||||
|
|
||||||
bool isCached() const override { return false; }
|
bool isCached() const override { return false; }
|
||||||
|
|
||||||
DictionaryPtr clone() const override { return std::make_unique<TrieDictionary>(*this); }
|
std::unique_ptr<IExternalLoadable> clone() const override { return std::make_unique<TrieDictionary>(*this); }
|
||||||
|
|
||||||
const IDictionarySource * getSource() const override { return source_ptr.get(); }
|
const IDictionarySource * getSource() const override { return source_ptr.get(); }
|
||||||
|
|
||||||
|
62
dbms/src/Functions/FunctionsExternalModels.cpp
Normal file
62
dbms/src/Functions/FunctionsExternalModels.cpp
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
#include <Functions/FunctionsExternalModels.h>
|
||||||
|
#include <Functions/FunctionHelpers.h>
|
||||||
|
#include <Functions/FunctionFactory.h>
|
||||||
|
|
||||||
|
#include <Interpreters/Context.h>
|
||||||
|
#include <Interpreters/ExternalModels.h>
|
||||||
|
#include <DataTypes/DataTypeString.h>
|
||||||
|
#include <DataTypes/DataTypesNumber.h>
|
||||||
|
#include <Columns/ColumnString.h>
|
||||||
|
#include <ext/range.h>
|
||||||
|
|
||||||
|
namespace DB
|
||||||
|
{
|
||||||
|
|
||||||
|
/// Build the function object bound to the external models registry taken from `context`.
FunctionPtr FunctionModelEvaluate::create(const Context & context)
{
    const ExternalModels & external_models = context.getExternalModels();
    return std::make_shared<FunctionModelEvaluate>(external_models);
}
|
||||||
|
|
||||||
|
namespace ErrorCodes
|
||||||
|
{
|
||||||
|
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||||
|
extern const int TOO_LESS_ARGUMENTS_FOR_FUNCTION;
|
||||||
|
extern const int ILLEGAL_COLUMN;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validate argument types and report the result type.
/// Requires at least two arguments, the first of which must be a String
/// (the model name). The result type is always Float64.
/// Throws TOO_LESS_ARGUMENTS_FOR_FUNCTION or ILLEGAL_TYPE_OF_ARGUMENT otherwise.
DataTypePtr FunctionModelEvaluate::getReturnTypeImpl(const DataTypes & arguments) const
{
    const size_t min_arguments_count = 2;
    if (arguments.size() < min_arguments_count)
    {
        throw Exception("Function " + getName() + " expects at least 2 arguments",
                        ErrorCodes::TOO_LESS_ARGUMENTS_FOR_FUNCTION);
    }

    /// Safe to index: the size check above has already passed.
    const auto & model_name_type = arguments[0];
    if (!checkDataType<DataTypeString>(model_name_type.get()))
    {
        throw Exception("Illegal type " + model_name_type->getName() + " of first argument of function " + getName()
                        + ", expected a string.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
    }

    return std::make_shared<DataTypeFloat64>();
}
|
||||||
|
|
||||||
|
/// Evaluate the model named by the first argument on the remaining argument columns
/// and store the resulting column at position `result` of `block`.
/// Throws ILLEGAL_COLUMN if the first argument is not a constant string column.
void FunctionModelEvaluate::executeImpl(Block & block, const ColumnNumbers & arguments, size_t result)
{
    const auto model_name_column = checkAndGetColumnConst<ColumnString>(block.getByPosition(arguments[0]).column.get());
    if (!model_name_column)
    {
        throw Exception("First argument of function " + getName() + " must be a constant string",
                        ErrorCodes::ILLEGAL_COLUMN);
    }

    /// Look the model up before collecting columns, so a lookup failure is raised first.
    auto model = models.getModel(model_name_column->getValue<String>());

    /// Every argument after the model name is passed to the model, in order.
    ConstColumnPlainPtrs feature_columns;
    feature_columns.reserve(arguments.size());
    for (size_t arg_index = 1; arg_index < arguments.size(); ++arg_index)
    {
        const auto & argument = block.getByPosition(arguments[arg_index]);
        feature_columns.push_back(argument.column.get());
    }

    block.getByPosition(result).column = model->evaluate(feature_columns);
}
|
||||||
|
|
||||||
|
/// Register the functions that work with external models in `factory`.
/// Currently this is only FunctionModelEvaluate.
void registerFunctionsExternalModels(FunctionFactory & factory)
{
    using ModelEvaluate = FunctionModelEvaluate;
    factory.registerFunction<ModelEvaluate>();
}
|
||||||
|
|
||||||
|
}
|
36
dbms/src/Functions/FunctionsExternalModels.h
Normal file
36
dbms/src/Functions/FunctionsExternalModels.h
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
#pragma once
|
||||||
|
#include <Functions/IFunction.h>
|
||||||
|
|
||||||
|
namespace DB
|
||||||
|
{
|
||||||
|
|
||||||
|
class ExternalModels;
|
||||||
|
|
||||||
|
/// Evaluate external model.
|
||||||
|
/// First argument - model name, the others - model arguments.
|
||||||
|
/// * for CatBoost model - float features first, then categorical
|
||||||
|
/// Result - Float64.
|
||||||
|
class FunctionModelEvaluate final : public IFunction
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
static constexpr auto name = "modelEvaluate";
|
||||||
|
|
||||||
|
static FunctionPtr create(const Context & context);
|
||||||
|
|
||||||
|
explicit FunctionModelEvaluate(const ExternalModels & models) : models(models) {}
|
||||||
|
|
||||||
|
String getName() const override { return name; }
|
||||||
|
|
||||||
|
bool isVariadic() const override { return true; }
|
||||||
|
|
||||||
|
size_t getNumberOfArguments() const override { return 0; }
|
||||||
|
|
||||||
|
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override;
|
||||||
|
|
||||||
|
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const ExternalModels & models;
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
@ -20,6 +20,7 @@ void registerFunctionsConversion(FunctionFactory &);
|
|||||||
void registerFunctionsDateTime(FunctionFactory &);
|
void registerFunctionsDateTime(FunctionFactory &);
|
||||||
void registerFunctionsEmbeddedDictionaries(FunctionFactory &);
|
void registerFunctionsEmbeddedDictionaries(FunctionFactory &);
|
||||||
void registerFunctionsExternalDictionaries(FunctionFactory &);
|
void registerFunctionsExternalDictionaries(FunctionFactory &);
|
||||||
|
void registerFunctionsExternalModels(FunctionFactory &);
|
||||||
void registerFunctionsFormatting(FunctionFactory &);
|
void registerFunctionsFormatting(FunctionFactory &);
|
||||||
void registerFunctionsHashing(FunctionFactory &);
|
void registerFunctionsHashing(FunctionFactory &);
|
||||||
void registerFunctionsHigherOrder(FunctionFactory &);
|
void registerFunctionsHigherOrder(FunctionFactory &);
|
||||||
@ -54,6 +55,7 @@ void registerFunctions()
|
|||||||
registerFunctionsDateTime(factory);
|
registerFunctionsDateTime(factory);
|
||||||
registerFunctionsEmbeddedDictionaries(factory);
|
registerFunctionsEmbeddedDictionaries(factory);
|
||||||
registerFunctionsExternalDictionaries(factory);
|
registerFunctionsExternalDictionaries(factory);
|
||||||
|
registerFunctionsExternalModels(factory);
|
||||||
registerFunctionsFormatting(factory);
|
registerFunctionsFormatting(factory);
|
||||||
registerFunctionsHashing(factory);
|
registerFunctionsHashing(factory);
|
||||||
registerFunctionsHigherOrder(factory);
|
registerFunctionsHigherOrder(factory);
|
||||||
|
@ -29,6 +29,7 @@
|
|||||||
#include <Interpreters/Quota.h>
|
#include <Interpreters/Quota.h>
|
||||||
#include <Interpreters/EmbeddedDictionaries.h>
|
#include <Interpreters/EmbeddedDictionaries.h>
|
||||||
#include <Interpreters/ExternalDictionaries.h>
|
#include <Interpreters/ExternalDictionaries.h>
|
||||||
|
#include <Interpreters/ExternalModels.h>
|
||||||
#include <Interpreters/ProcessList.h>
|
#include <Interpreters/ProcessList.h>
|
||||||
#include <Interpreters/Cluster.h>
|
#include <Interpreters/Cluster.h>
|
||||||
#include <Interpreters/InterserverIOHandler.h>
|
#include <Interpreters/InterserverIOHandler.h>
|
||||||
@ -94,6 +95,7 @@ struct ContextShared
|
|||||||
/// Separate mutex for access of dictionaries. Separate mutex to avoid locks when server doing request to itself.
|
/// Separate mutex for access of dictionaries. Separate mutex to avoid locks when server doing request to itself.
|
||||||
mutable std::mutex embedded_dictionaries_mutex;
|
mutable std::mutex embedded_dictionaries_mutex;
|
||||||
mutable std::mutex external_dictionaries_mutex;
|
mutable std::mutex external_dictionaries_mutex;
|
||||||
|
mutable std::mutex external_models_mutex;
|
||||||
/// Separate mutex for re-initialization of zookeer session. This operation could take a long time and must not interfere with another operations.
|
/// Separate mutex for re-initialization of zookeer session. This operation could take a long time and must not interfere with another operations.
|
||||||
mutable std::mutex zookeeper_mutex;
|
mutable std::mutex zookeeper_mutex;
|
||||||
|
|
||||||
@ -111,6 +113,7 @@ struct ContextShared
|
|||||||
FormatFactory format_factory; /// Formats.
|
FormatFactory format_factory; /// Formats.
|
||||||
mutable std::shared_ptr<EmbeddedDictionaries> embedded_dictionaries; /// Metrica's dictionaeis. Have lazy initialization.
|
mutable std::shared_ptr<EmbeddedDictionaries> embedded_dictionaries; /// Metrica's dictionaeis. Have lazy initialization.
|
||||||
mutable std::shared_ptr<ExternalDictionaries> external_dictionaries;
|
mutable std::shared_ptr<ExternalDictionaries> external_dictionaries;
|
||||||
|
mutable std::shared_ptr<ExternalModels> external_models;
|
||||||
String default_profile_name; /// Default profile name used for default values.
|
String default_profile_name; /// Default profile name used for default values.
|
||||||
Users users; /// Known users.
|
Users users; /// Known users.
|
||||||
Quotas quotas; /// Known quotas for resource use.
|
Quotas quotas; /// Known quotas for resource use.
|
||||||
@ -1062,6 +1065,17 @@ ExternalDictionaries & Context::getExternalDictionaries()
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/// Read-only access to the external models registry.
/// Calls getExternalModelsImpl with throw_on_error = false.
const ExternalModels & Context::getExternalModels() const
{
    const bool throw_on_error = false;
    return getExternalModelsImpl(throw_on_error);
}
|
||||||
|
|
||||||
|
/// Mutable access to the external models registry.
/// Calls getExternalModelsImpl with throw_on_error = false.
ExternalModels & Context::getExternalModels()
{
    const bool throw_on_error = false;
    return getExternalModelsImpl(throw_on_error);
}
|
||||||
|
|
||||||
|
|
||||||
EmbeddedDictionaries & Context::getEmbeddedDictionariesImpl(const bool throw_on_error) const
|
EmbeddedDictionaries & Context::getEmbeddedDictionariesImpl(const bool throw_on_error) const
|
||||||
{
|
{
|
||||||
std::lock_guard<std::mutex> lock(shared->embedded_dictionaries_mutex);
|
std::lock_guard<std::mutex> lock(shared->embedded_dictionaries_mutex);
|
||||||
@ -1087,6 +1101,19 @@ ExternalDictionaries & Context::getExternalDictionariesImpl(const bool throw_on_
|
|||||||
return *shared->external_dictionaries;
|
return *shared->external_dictionaries;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Return the shared ExternalModels object, creating it on first use.
/// The whole call runs under `external_models_mutex`. `throw_on_error` is only
/// forwarded to the ExternalModels constructor, so it has no effect once the object exists.
/// Throws LOGICAL_ERROR if the object has to be created but there is no global context.
ExternalModels & Context::getExternalModelsImpl(bool throw_on_error) const
{
    std::lock_guard<std::mutex> lock(shared->external_models_mutex);

    /// Already created: nothing else to do.
    if (shared->external_models)
        return *shared->external_models;

    if (!this->global_context)
        throw Exception("Logical error: there is no global context", ErrorCodes::LOGICAL_ERROR);

    shared->external_models = std::make_shared<ExternalModels>(*this->global_context, throw_on_error);
    return *shared->external_models;
}
|
||||||
|
|
||||||
void Context::tryCreateEmbeddedDictionaries() const
|
void Context::tryCreateEmbeddedDictionaries() const
|
||||||
{
|
{
|
||||||
@ -1100,6 +1127,12 @@ void Context::tryCreateExternalDictionaries() const
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
void Context::tryCreateExternalModels() const
|
||||||
|
{
|
||||||
|
static_cast<void>(getExternalModelsImpl(true));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
void Context::setProgressCallback(ProgressCallback callback)
|
void Context::setProgressCallback(ProgressCallback callback)
|
||||||
{
|
{
|
||||||
/// Callback is set to a session or to a query. In the session, only one query is processed at a time. Therefore, the lock is not needed.
|
/// Callback is set to a session or to a query. In the session, only one query is processed at a time. Therefore, the lock is not needed.
|
||||||
|
@ -35,6 +35,7 @@ struct ContextShared;
|
|||||||
class QuotaForIntervals;
|
class QuotaForIntervals;
|
||||||
class EmbeddedDictionaries;
|
class EmbeddedDictionaries;
|
||||||
class ExternalDictionaries;
|
class ExternalDictionaries;
|
||||||
|
class ExternalModels;
|
||||||
class InterserverIOHandler;
|
class InterserverIOHandler;
|
||||||
class BackgroundProcessingPool;
|
class BackgroundProcessingPool;
|
||||||
class ReshardingWorker;
|
class ReshardingWorker;
|
||||||
@ -209,10 +210,13 @@ public:
|
|||||||
|
|
||||||
const EmbeddedDictionaries & getEmbeddedDictionaries() const;
|
const EmbeddedDictionaries & getEmbeddedDictionaries() const;
|
||||||
const ExternalDictionaries & getExternalDictionaries() const;
|
const ExternalDictionaries & getExternalDictionaries() const;
|
||||||
|
const ExternalModels & getExternalModels() const;
|
||||||
EmbeddedDictionaries & getEmbeddedDictionaries();
|
EmbeddedDictionaries & getEmbeddedDictionaries();
|
||||||
ExternalDictionaries & getExternalDictionaries();
|
ExternalDictionaries & getExternalDictionaries();
|
||||||
|
ExternalModels & getExternalModels();
|
||||||
void tryCreateEmbeddedDictionaries() const;
|
void tryCreateEmbeddedDictionaries() const;
|
||||||
void tryCreateExternalDictionaries() const;
|
void tryCreateExternalDictionaries() const;
|
||||||
|
void tryCreateExternalModels() const;
|
||||||
|
|
||||||
/// I/O formats.
|
/// I/O formats.
|
||||||
BlockInputStreamPtr getInputFormat(const String & name, ReadBuffer & buf, const Block & sample, size_t max_block_size) const;
|
BlockInputStreamPtr getInputFormat(const String & name, ReadBuffer & buf, const Block & sample, size_t max_block_size) const;
|
||||||
@ -362,6 +366,7 @@ private:
|
|||||||
|
|
||||||
EmbeddedDictionaries & getEmbeddedDictionariesImpl(bool throw_on_error) const;
|
EmbeddedDictionaries & getEmbeddedDictionariesImpl(bool throw_on_error) const;
|
||||||
ExternalDictionaries & getExternalDictionariesImpl(bool throw_on_error) const;
|
ExternalDictionaries & getExternalDictionariesImpl(bool throw_on_error) const;
|
||||||
|
ExternalModels & getExternalModelsImpl(bool throw_on_error) const;
|
||||||
|
|
||||||
StoragePtr getTableImpl(const String & database_name, const String & table_name, Exception * exception) const;
|
StoragePtr getTableImpl(const String & database_name, const String & table_name, Exception * exception) const;
|
||||||
|
|
||||||
|
@ -24,7 +24,7 @@ namespace ErrorCodes
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
DictionaryPtr DictionaryFactory::create(const std::string & name, Poco::Util::AbstractConfiguration & config,
|
DictionaryPtr DictionaryFactory::create(const std::string & name, const Poco::Util::AbstractConfiguration & config,
|
||||||
const std::string & config_prefix, Context & context) const
|
const std::string & config_prefix, Context & context) const
|
||||||
{
|
{
|
||||||
Poco::Util::AbstractConfiguration::Keys keys;
|
Poco::Util::AbstractConfiguration::Keys keys;
|
||||||
|
@ -1,427 +1,46 @@
|
|||||||
#include <Interpreters/ExternalDictionaries.h>
|
#include <Interpreters/ExternalDictionaries.h>
|
||||||
#include <Interpreters/Context.h>
|
#include <Interpreters/Context.h>
|
||||||
#include <Dictionaries/DictionaryFactory.h>
|
#include <Dictionaries/DictionaryFactory.h>
|
||||||
#include <Dictionaries/DictionaryStructure.h>
|
|
||||||
#include <Dictionaries/IDictionarySource.h>
|
|
||||||
#include <Common/StringUtils.h>
|
|
||||||
#include <Common/MemoryTracker.h>
|
|
||||||
#include <Common/getMultipleKeysFromConfig.h>
|
|
||||||
#include <ext/scope_guard.h>
|
|
||||||
#include <Poco/Util/Application.h>
|
|
||||||
#include <Poco/Glob.h>
|
|
||||||
#include <Poco/File.h>
|
|
||||||
|
|
||||||
|
|
||||||
namespace
|
|
||||||
{
|
|
||||||
const auto check_period_sec = 5;
|
|
||||||
const auto backoff_initial_sec = 5;
|
|
||||||
/// 10 minutes
|
|
||||||
const auto backoff_max_sec = 10 * 60;
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace DB
|
namespace DB
|
||||||
{
|
{
|
||||||
|
|
||||||
namespace ErrorCodes
|
|
||||||
{
|
|
||||||
extern const int LOGICAL_ERROR;
|
|
||||||
extern const int BAD_ARGUMENTS;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
void ExternalDictionaries::reloadPeriodically()
|
|
||||||
{
|
|
||||||
setThreadName("ExterDictReload");
|
|
||||||
|
|
||||||
while (true)
|
|
||||||
{
|
|
||||||
if (destroy.tryWait(check_period_sec * 1000))
|
|
||||||
return;
|
|
||||||
|
|
||||||
reloadAndUpdate();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
ExternalDictionaries::ExternalDictionaries(Context & context, const bool throw_on_error)
|
|
||||||
: context(context), log(&Logger::get("ExternalDictionaries"))
|
|
||||||
{
|
|
||||||
{
|
|
||||||
/** During synchronous loading of external dictionaries at moment of query execution,
|
|
||||||
* we should not use per query memory limit.
|
|
||||||
*/
|
|
||||||
TemporarilyDisableMemoryTracker temporarily_disable_memory_tracker;
|
|
||||||
|
|
||||||
reloadAndUpdate(throw_on_error);
|
|
||||||
}
|
|
||||||
|
|
||||||
reloading_thread = std::thread{&ExternalDictionaries::reloadPeriodically, this};
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
ExternalDictionaries::~ExternalDictionaries()
|
|
||||||
{
|
|
||||||
destroy.set();
|
|
||||||
reloading_thread.join();
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace
|
namespace
|
||||||
{
|
{
|
||||||
std::set<std::string> getDictionariesConfigPaths(const Poco::Util::AbstractConfiguration & config)
|
const ExternalLoaderUpdateSettings externalDictionariesUpdateSettings;
|
||||||
{
|
|
||||||
std::set<std::string> files;
|
const ExternalLoaderConfigSettings & getExternalDictionariesConfigSettings()
|
||||||
auto patterns = getMultipleValuesFromConfig(config, "", "dictionaries_config");
|
|
||||||
for (auto & pattern : patterns)
|
|
||||||
{
|
{
|
||||||
if (pattern.empty())
|
static ExternalLoaderConfigSettings settings;
|
||||||
continue;
|
static std::once_flag flag;
|
||||||
|
|
||||||
if (pattern[0] != '/')
|
std::call_once(flag, [] {
|
||||||
{
|
settings.external_config = "dictionary";
|
||||||
const auto app_config_path = config.getString("config-file", "config.xml");
|
settings.external_name = "name";
|
||||||
const auto config_dir = Poco::Path{app_config_path}.parent().toString();
|
|
||||||
const auto absolute_path = config_dir + pattern;
|
|
||||||
Poco::Glob::glob(absolute_path, files, 0);
|
|
||||||
if (!files.empty())
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
Poco::Glob::glob(pattern, files, 0);
|
settings.path_setting_name = "dictionaries_config";
|
||||||
}
|
});
|
||||||
|
|
||||||
return files;
|
return settings;
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void ExternalDictionaries::reloadAndUpdate(bool throw_on_error)
|
|
||||||
{
|
|
||||||
reloadFromConfigFiles(throw_on_error);
|
|
||||||
|
|
||||||
/// list of recreated dictionaries to perform delayed removal from unordered_map
|
|
||||||
std::list<std::string> recreated_failed_dictionaries;
|
|
||||||
|
|
||||||
std::unique_lock<std::mutex> all_lock(all_mutex);
|
|
||||||
|
|
||||||
/// retry loading failed dictionaries
|
|
||||||
for (auto & failed_dictionary : failed_dictionaries)
|
|
||||||
{
|
|
||||||
if (std::chrono::system_clock::now() < failed_dictionary.second.next_attempt_time)
|
|
||||||
continue;
|
|
||||||
|
|
||||||
const auto & name = failed_dictionary.first;
|
|
||||||
|
|
||||||
try
|
|
||||||
{
|
|
||||||
auto dict_ptr = failed_dictionary.second.dict->clone();
|
|
||||||
if (const auto exception_ptr = dict_ptr->getCreationException())
|
|
||||||
{
|
|
||||||
/// recalculate next attempt time
|
|
||||||
std::uniform_int_distribution<UInt64> distribution(
|
|
||||||
0, static_cast<UInt64>(std::exp2(failed_dictionary.second.error_count)));
|
|
||||||
|
|
||||||
failed_dictionary.second.next_attempt_time = std::chrono::system_clock::now() +
|
|
||||||
std::chrono::seconds{
|
|
||||||
std::min<UInt64>(backoff_max_sec, backoff_initial_sec + distribution(rnd_engine))};
|
|
||||||
|
|
||||||
++failed_dictionary.second.error_count;
|
|
||||||
|
|
||||||
std::rethrow_exception(exception_ptr);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
const std::lock_guard<std::mutex> lock{dictionaries_mutex};
|
|
||||||
|
|
||||||
const auto & lifetime = dict_ptr->getLifetime();
|
|
||||||
std::uniform_int_distribution<UInt64> distribution{lifetime.min_sec, lifetime.max_sec};
|
|
||||||
update_times[name] = std::chrono::system_clock::now() + std::chrono::seconds{distribution(rnd_engine)};
|
|
||||||
|
|
||||||
const auto dict_it = dictionaries.find(name);
|
|
||||||
if (dict_it->second.dict)
|
|
||||||
dict_it->second.dict->set(dict_ptr.release());
|
|
||||||
else
|
|
||||||
dict_it->second.dict = std::make_shared<MultiVersion<IDictionaryBase>>(dict_ptr.release());
|
|
||||||
|
|
||||||
/// clear stored exception on success
|
|
||||||
dict_it->second.exception = std::exception_ptr{};
|
|
||||||
|
|
||||||
recreated_failed_dictionaries.push_back(name);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
catch (...)
|
|
||||||
{
|
|
||||||
tryLogCurrentException(log, "Failed reloading '" + name + "' dictionary");
|
|
||||||
|
|
||||||
if (throw_on_error)
|
|
||||||
throw;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// do not undertake further attempts to recreate these dictionaries
|
|
||||||
for (const auto & name : recreated_failed_dictionaries)
|
|
||||||
failed_dictionaries.erase(name);
|
|
||||||
|
|
||||||
/// periodic update
|
|
||||||
for (auto & dictionary : dictionaries)
|
|
||||||
{
|
|
||||||
const auto & name = dictionary.first;
|
|
||||||
|
|
||||||
try
|
|
||||||
{
|
|
||||||
/// If the dictionary failed to load or even failed to initialize from the config.
|
|
||||||
if (!dictionary.second.dict)
|
|
||||||
continue;
|
|
||||||
|
|
||||||
auto current = dictionary.second.dict->get();
|
|
||||||
const auto & lifetime = current->getLifetime();
|
|
||||||
|
|
||||||
/// do not update dictionaries with zero as lifetime
|
|
||||||
if (lifetime.min_sec == 0 || lifetime.max_sec == 0)
|
|
||||||
continue;
|
|
||||||
|
|
||||||
/// update only non-cached dictionaries
|
|
||||||
if (!current->isCached())
|
|
||||||
{
|
|
||||||
auto & update_time = update_times[current->getName()];
|
|
||||||
|
|
||||||
/// check that timeout has passed
|
|
||||||
if (std::chrono::system_clock::now() < update_time)
|
|
||||||
continue;
|
|
||||||
|
|
||||||
SCOPE_EXIT({
|
|
||||||
/// calculate next update time
|
|
||||||
std::uniform_int_distribution<UInt64> distribution{lifetime.min_sec, lifetime.max_sec};
|
|
||||||
update_time = std::chrono::system_clock::now() + std::chrono::seconds{distribution(rnd_engine)};
|
|
||||||
});
|
|
||||||
|
|
||||||
/// check source modified
|
|
||||||
if (current->getSource()->isModified())
|
|
||||||
{
|
|
||||||
/// create new version of dictionary
|
|
||||||
auto new_version = current->clone();
|
|
||||||
|
|
||||||
if (const auto exception_ptr = new_version->getCreationException())
|
|
||||||
std::rethrow_exception(exception_ptr);
|
|
||||||
|
|
||||||
dictionary.second.dict->set(new_version.release());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// erase stored exception on success
|
|
||||||
dictionary.second.exception = std::exception_ptr{};
|
|
||||||
}
|
|
||||||
catch (...)
|
|
||||||
{
|
|
||||||
dictionary.second.exception = std::current_exception();
|
|
||||||
|
|
||||||
tryLogCurrentException(log, "Cannot update external dictionary '" + name + "', leaving old version");
|
|
||||||
|
|
||||||
if (throw_on_error)
|
|
||||||
throw;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void ExternalDictionaries::reloadFromConfigFiles(const bool throw_on_error, const bool force_reload, const std::string & only_dictionary)
|
|
||||||
|
ExternalDictionaries::ExternalDictionaries(Context & context, bool throw_on_error)
|
||||||
|
: ExternalLoader(context.getConfigRef(),
|
||||||
|
externalDictionariesUpdateSettings,
|
||||||
|
getExternalDictionariesConfigSettings(),
|
||||||
|
&Logger::get("ExternalDictionaries"),
|
||||||
|
"external dictionary"),
|
||||||
|
context(context)
|
||||||
{
|
{
|
||||||
const auto config_paths = getDictionariesConfigPaths(context.getConfigRef());
|
init(throw_on_error);
|
||||||
|
|
||||||
for (const auto & config_path : config_paths)
|
|
||||||
{
|
|
||||||
try
|
|
||||||
{
|
|
||||||
reloadFromConfigFile(config_path, throw_on_error, force_reload, only_dictionary);
|
|
||||||
}
|
|
||||||
catch (...)
|
|
||||||
{
|
|
||||||
tryLogCurrentException(log, "reloadFromConfigFile has thrown while reading from " + config_path);
|
|
||||||
|
|
||||||
if (throw_on_error)
|
|
||||||
throw;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void ExternalDictionaries::reloadFromConfigFile(const std::string & config_path, const bool throw_on_error, const bool force_reload,
|
std::unique_ptr<IExternalLoadable> ExternalDictionaries::create(
|
||||||
const std::string & only_dictionary)
|
const std::string & name, const Configuration & config, const std::string & config_prefix)
|
||||||
{
|
{
|
||||||
const Poco::File config_file{config_path};
|
return DictionaryFactory::instance().create(name, config, config_prefix, context);
|
||||||
|
|
||||||
if (config_path.empty() || !config_file.exists())
|
|
||||||
{
|
|
||||||
LOG_WARNING(log, "config file '" + config_path + "' does not exist");
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
std::unique_lock<std::mutex> all_lock(all_mutex);
|
|
||||||
|
|
||||||
auto modification_time_it = last_modification_times.find(config_path);
|
|
||||||
if (modification_time_it == std::end(last_modification_times))
|
|
||||||
modification_time_it = last_modification_times.emplace(config_path, Poco::Timestamp{0}).first;
|
|
||||||
auto & config_last_modified = modification_time_it->second;
|
|
||||||
|
|
||||||
const auto last_modified = config_file.getLastModified();
|
|
||||||
if (force_reload || last_modified > config_last_modified)
|
|
||||||
{
|
|
||||||
Poco::AutoPtr<Poco::Util::XMLConfiguration> config = new Poco::Util::XMLConfiguration(config_path);
|
|
||||||
|
|
||||||
/// Definitions of dictionaries may have changed, recreate all of them
|
|
||||||
|
|
||||||
/// If we need update only one dictionary, don't update modification time: might be other dictionaries in the config file
|
|
||||||
if (only_dictionary.empty())
|
|
||||||
config_last_modified = last_modified;
|
|
||||||
|
|
||||||
/// get all dictionaries' definitions
|
|
||||||
Poco::Util::AbstractConfiguration::Keys keys;
|
|
||||||
config->keys(keys);
|
|
||||||
|
|
||||||
/// for each dictionary defined in xml config
|
|
||||||
for (const auto & key : keys)
|
|
||||||
{
|
|
||||||
std::string name;
|
|
||||||
|
|
||||||
if (!startsWith(key, "dictionary"))
|
|
||||||
{
|
|
||||||
if (!startsWith(key.data(), "comment"))
|
|
||||||
LOG_WARNING(log,
|
|
||||||
config_path << ": unknown node in dictionaries file: '" << key + "', 'dictionary'");
|
|
||||||
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
try
|
|
||||||
{
|
|
||||||
name = config->getString(key + ".name");
|
|
||||||
if (name.empty())
|
|
||||||
{
|
|
||||||
LOG_WARNING(log, config_path << ": dictionary name cannot be empty");
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!only_dictionary.empty() && name != only_dictionary)
|
|
||||||
continue;
|
|
||||||
|
|
||||||
decltype(dictionaries.begin()) dict_it;
|
|
||||||
{
|
|
||||||
std::lock_guard<std::mutex> lock{dictionaries_mutex};
|
|
||||||
dict_it = dictionaries.find(name);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (dict_it != std::end(dictionaries) && dict_it->second.origin != config_path)
|
|
||||||
throw std::runtime_error{"Overriding dictionary from file " + dict_it->second.origin};
|
|
||||||
|
|
||||||
auto dict_ptr = DictionaryFactory::instance().create(name, *config, key, context);
|
|
||||||
|
|
||||||
/// If the dictionary could not be loaded.
|
|
||||||
if (const auto exception_ptr = dict_ptr->getCreationException())
|
|
||||||
{
|
|
||||||
const auto failed_dict_it = failed_dictionaries.find(name);
|
|
||||||
if (failed_dict_it != std::end(failed_dictionaries))
|
|
||||||
{
|
|
||||||
failed_dict_it->second = FailedDictionaryInfo{
|
|
||||||
std::move(dict_ptr),
|
|
||||||
std::chrono::system_clock::now() + std::chrono::seconds{backoff_initial_sec}};
|
|
||||||
}
|
|
||||||
else
|
|
||||||
failed_dictionaries.emplace(name, FailedDictionaryInfo{
|
|
||||||
std::move(dict_ptr),
|
|
||||||
std::chrono::system_clock::now() + std::chrono::seconds{backoff_initial_sec}});
|
|
||||||
|
|
||||||
std::rethrow_exception(exception_ptr);
|
|
||||||
}
|
|
||||||
else if (!dict_ptr->isCached())
|
|
||||||
{
|
|
||||||
const auto & lifetime = dict_ptr->getLifetime();
|
|
||||||
if (lifetime.min_sec != 0 && lifetime.max_sec != 0)
|
|
||||||
{
|
|
||||||
std::uniform_int_distribution<UInt64> distribution{
|
|
||||||
lifetime.min_sec,
|
|
||||||
lifetime.max_sec
|
|
||||||
};
|
|
||||||
update_times[name] = std::chrono::system_clock::now() +
|
|
||||||
std::chrono::seconds{distribution(rnd_engine)};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const std::lock_guard<std::mutex> lock{dictionaries_mutex};
|
|
||||||
|
|
||||||
/// add new dictionary or update an existing version
|
|
||||||
if (dict_it == std::end(dictionaries))
|
|
||||||
dictionaries.emplace(name, DictionaryInfo{
|
|
||||||
std::make_shared<MultiVersion<IDictionaryBase>>(dict_ptr.release()),
|
|
||||||
config_path
|
|
||||||
});
|
|
||||||
else
|
|
||||||
{
|
|
||||||
if (dict_it->second.dict)
|
|
||||||
dict_it->second.dict->set(dict_ptr.release());
|
|
||||||
else
|
|
||||||
dict_it->second.dict = std::make_shared<MultiVersion<IDictionaryBase>>(dict_ptr.release());
|
|
||||||
|
|
||||||
/// erase stored exception on success
|
|
||||||
dict_it->second.exception = std::exception_ptr{};
|
|
||||||
failed_dictionaries.erase(name);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
catch (...)
|
|
||||||
{
|
|
||||||
if (!name.empty())
|
|
||||||
{
|
|
||||||
/// If the dictionary could not load data or even failed to initialize from the config.
|
|
||||||
/// - all the same we insert information into the `dictionaries`, with the zero pointer `dict`.
|
|
||||||
|
|
||||||
const std::lock_guard<std::mutex> lock{dictionaries_mutex};
|
|
||||||
|
|
||||||
const auto exception_ptr = std::current_exception();
|
|
||||||
const auto dict_it = dictionaries.find(name);
|
|
||||||
if (dict_it == std::end(dictionaries))
|
|
||||||
dictionaries.emplace(name, DictionaryInfo{nullptr, config_path, exception_ptr});
|
|
||||||
else
|
|
||||||
dict_it->second.exception = exception_ptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
tryLogCurrentException(log, "Cannot create external dictionary '" + name + "' from config path " + config_path);
|
|
||||||
|
|
||||||
/// propagate exception
|
|
||||||
if (throw_on_error)
|
|
||||||
throw;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void ExternalDictionaries::reload()
|
|
||||||
{
|
|
||||||
reloadFromConfigFiles(true, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
void ExternalDictionaries::reloadDictionary(const std::string & name)
|
|
||||||
{
|
|
||||||
reloadFromConfigFiles(true, true, name);
|
|
||||||
|
|
||||||
/// Check that specified dict was loaded
|
|
||||||
const std::lock_guard<std::mutex> lock{dictionaries_mutex};
|
|
||||||
if (!dictionaries.count(name))
|
|
||||||
throw Exception("Dictionary " + name + " wasn't loaded during the reload process", ErrorCodes::BAD_ARGUMENTS);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the current version of the dictionary called `name`.
/// Throws BAD_ARGUMENTS for an unknown name. If the entry exists but has no dictionary,
/// the stored exception is rethrown (or LOGICAL_ERROR is thrown when none is stored).
MultiVersion<IDictionaryBase>::Version ExternalDictionaries::getDictionary(const std::string & name) const
{
    const std::lock_guard<std::mutex> lock{dictionaries_mutex};

    const auto it = dictionaries.find(name);
    if (it == std::end(dictionaries))
        throw Exception{"No such dictionary: " + name, ErrorCodes::BAD_ARGUMENTS};

    const auto & info = it->second;
    if (!info.dict)
    {
        if (info.exception)
            std::rethrow_exception(info.exception);

        throw Exception{"No dictionary", ErrorCodes::LOGICAL_ERROR};
    }

    return info.dict->get();
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -1,19 +1,9 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <Dictionaries/IDictionary.h>
|
#include <Dictionaries/IDictionary.h>
|
||||||
#include <Common/Exception.h>
|
#include <Interpreters/ExternalLoader.h>
|
||||||
#include <Common/setThreadName.h>
|
|
||||||
#include <Common/randomSeed.h>
|
|
||||||
#include <common/MultiVersion.h>
|
|
||||||
#include <common/logger_useful.h>
|
#include <common/logger_useful.h>
|
||||||
#include <Poco/Event.h>
|
#include <memory>
|
||||||
#include <unistd.h>
|
|
||||||
#include <time.h>
|
|
||||||
#include <mutex>
|
|
||||||
#include <thread>
|
|
||||||
#include <unordered_map>
|
|
||||||
#include <chrono>
|
|
||||||
#include <pcg_random.hpp>
|
|
||||||
|
|
||||||
|
|
||||||
namespace DB
|
namespace DB
|
||||||
@ -21,93 +11,36 @@ namespace DB
|
|||||||
|
|
||||||
class Context;
|
class Context;
|
||||||
|
|
||||||
/** Manages user-defined dictionaries.
|
/// Manages user-defined dictionaries.
|
||||||
* Monitors configuration file and automatically reloads dictionaries in a separate thread.
|
class ExternalDictionaries : public ExternalLoader
|
||||||
* The monitoring thread wakes up every @check_period_sec seconds and checks
|
|
||||||
* modification time of dictionaries' configuration file. If said time is greater than
|
|
||||||
* @config_last_modified, the dictionaries are created from scratch using configuration file,
|
|
||||||
* possibly overriding currently existing dictionaries with the same name (previous versions of
|
|
||||||
* overridden dictionaries will live as long as there are any users retaining them).
|
|
||||||
*
|
|
||||||
* Apart from checking configuration file for modifications, each non-cached dictionary
|
|
||||||
* has a lifetime of its own and may be updated if it's source reports that it has been
|
|
||||||
* modified. The time of next update is calculated by choosing uniformly a random number
|
|
||||||
* distributed between lifetime.min_sec and lifetime.max_sec.
|
|
||||||
* If either of lifetime.min_sec and lifetime.max_sec is zero, such dictionary is never updated.
|
|
||||||
*/
|
|
||||||
class ExternalDictionaries
|
|
||||||
{
|
{
|
||||||
private:
|
public:
|
||||||
|
using DictPtr = std::shared_ptr<IDictionaryBase>;
|
||||||
|
|
||||||
|
/// Dictionaries will be loaded immediately and then will be updated in separate thread, each 'reload_period' seconds.
|
||||||
|
ExternalDictionaries(Context & context, bool throw_on_error);
|
||||||
|
|
||||||
|
/// Forcibly reloads specified dictionary.
|
||||||
|
void reloadDictionary(const std::string & name) { reload(name); }
|
||||||
|
|
||||||
|
DictPtr getDictionary(const std::string & name) const
|
||||||
|
{
|
||||||
|
return std::static_pointer_cast<IDictionaryBase>(getLoadable(name));
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
|
||||||
|
std::unique_ptr<IExternalLoadable> create(const std::string & name, const Configuration & config,
|
||||||
|
const std::string & config_prefix) override;
|
||||||
|
|
||||||
|
using ExternalLoader::getObjectsMap;
|
||||||
|
|
||||||
friend class StorageSystemDictionaries;
|
friend class StorageSystemDictionaries;
|
||||||
friend class DatabaseDictionary;
|
friend class DatabaseDictionary;
|
||||||
|
|
||||||
/// Protects only dictionaries map.
|
private:
|
||||||
mutable std::mutex dictionaries_mutex;
|
|
||||||
|
|
||||||
/// Protects all data, currently used to avoid races between updating thread and SYSTEM queries
|
|
||||||
mutable std::mutex all_mutex;
|
|
||||||
|
|
||||||
using DictionaryPtr = std::shared_ptr<MultiVersion<IDictionaryBase>>;
|
|
||||||
struct DictionaryInfo final
|
|
||||||
{
|
|
||||||
DictionaryPtr dict;
|
|
||||||
std::string origin;
|
|
||||||
std::exception_ptr exception;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct FailedDictionaryInfo final
|
|
||||||
{
|
|
||||||
std::unique_ptr<IDictionaryBase> dict;
|
|
||||||
std::chrono::system_clock::time_point next_attempt_time;
|
|
||||||
UInt64 error_count;
|
|
||||||
};
|
|
||||||
|
|
||||||
/** name -> dictionary.
|
|
||||||
*/
|
|
||||||
std::unordered_map<std::string, DictionaryInfo> dictionaries;
|
|
||||||
|
|
||||||
/** Here are dictionaries, that has been never loaded successfully.
|
|
||||||
* They are also in 'dictionaries', but with nullptr as 'dict'.
|
|
||||||
*/
|
|
||||||
std::unordered_map<std::string, FailedDictionaryInfo> failed_dictionaries;
|
|
||||||
|
|
||||||
/** Both for dictionaries and failed_dictionaries.
|
|
||||||
*/
|
|
||||||
std::unordered_map<std::string, std::chrono::system_clock::time_point> update_times;
|
|
||||||
|
|
||||||
pcg64 rnd_engine{randomSeed()};
|
|
||||||
|
|
||||||
Context & context;
|
Context & context;
|
||||||
|
|
||||||
std::thread reloading_thread;
|
|
||||||
Poco::Event destroy;
|
|
||||||
|
|
||||||
Logger * log;
|
|
||||||
|
|
||||||
std::unordered_map<std::string, Poco::Timestamp> last_modification_times;
|
|
||||||
|
|
||||||
/// Check dictionaries definitions in config files and reload or/and add new ones if the definition is changed
|
|
||||||
void reloadFromConfigFiles(const bool throw_on_error, const bool force_reload = false, const std::string & only_dictionary = "");
|
|
||||||
void reloadFromConfigFile(const std::string & config_path, const bool throw_on_error, const bool force_reload,
|
|
||||||
const std::string & only_dictionary);
|
|
||||||
|
|
||||||
/// Check config files and update expired dictionaries
|
|
||||||
void reloadAndUpdate(bool throw_on_error = false);
|
|
||||||
|
|
||||||
void reloadPeriodically();
|
|
||||||
|
|
||||||
public:
|
|
||||||
/// Dictionaries will be loaded immediately and then will be updated in separate thread, each 'reload_period' seconds.
|
|
||||||
ExternalDictionaries(Context & context, const bool throw_on_error);
|
|
||||||
~ExternalDictionaries();
|
|
||||||
|
|
||||||
/// Forcibly reloads all dictionaries.
|
|
||||||
void reload();
|
|
||||||
|
|
||||||
/// Forcibly reloads specified dictionary.
|
|
||||||
void reloadDictionary(const std::string & name);
|
|
||||||
|
|
||||||
MultiVersion<IDictionaryBase>::Version getDictionary(const std::string & name) const;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
433
dbms/src/Interpreters/ExternalLoader.cpp
Normal file
433
dbms/src/Interpreters/ExternalLoader.cpp
Normal file
@ -0,0 +1,433 @@
|
|||||||
|
#include <Interpreters/ExternalLoader.h>
|
||||||
|
#include <Common/StringUtils.h>
|
||||||
|
#include <Common/MemoryTracker.h>
|
||||||
|
#include <Common/Exception.h>
|
||||||
|
#include <Common/getMultipleKeysFromConfig.h>
|
||||||
|
#include <Common/setThreadName.h>
|
||||||
|
#include <ext/scope_guard.h>
|
||||||
|
#include <Poco/Util/Application.h>
|
||||||
|
#include <Poco/Glob.h>
|
||||||
|
#include <Poco/File.h>
|
||||||
|
#include <cmath>
|
||||||
|
#include <Poco/Util/XMLConfiguration.h>
|
||||||
|
|
||||||
|
namespace DB
|
||||||
|
{
|
||||||
|
|
||||||
|
namespace ErrorCodes
|
||||||
|
{
|
||||||
|
extern const int LOGICAL_ERROR;
|
||||||
|
extern const int BAD_ARGUMENTS;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/// Reads the lifetime from the config.
/// If `<config_prefix>.min` exists, the range is taken from `<config_prefix>.min` and `<config_prefix>.max`;
/// otherwise the single value at `<config_prefix>` is used for both bounds.
ExternalLoadableLifetime::ExternalLoadableLifetime(const Poco::Util::AbstractConfiguration & config,
                                                   const std::string & config_prefix)
{
    const std::string min_key = config_prefix + ".min";

    if (config.has(min_key))
    {
        min_sec = config.getUInt64(min_key);
        max_sec = config.getUInt64(config_prefix + ".max");
    }
    else
    {
        min_sec = config.getUInt64(config_prefix);
        max_sec = min_sec;
    }
}
|
||||||
|
|
||||||
|
void ExternalLoader::reloadPeriodically()
|
||||||
|
{
|
||||||
|
setThreadName("ExterLdrReload");
|
||||||
|
|
||||||
|
while (true)
|
||||||
|
{
|
||||||
|
if (destroy.tryWait(update_settings.check_period_sec * 1000))
|
||||||
|
return;
|
||||||
|
|
||||||
|
reloadAndUpdate();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/// Only stores the configuration, settings, logger and the human-readable object name.
/// The constructor itself loads nothing and starts no thread; see init().
ExternalLoader::ExternalLoader(const Poco::Util::AbstractConfiguration & config,
                               const ExternalLoaderUpdateSettings & update_settings,
                               const ExternalLoaderConfigSettings & config_settings,
                               Logger * log, const std::string & loadable_object_name)
    : config(config)
    , update_settings(update_settings)
    , config_settings(config_settings)
    , log(log)
    , object_name(loadable_object_name)
{
}
|
||||||
|
|
||||||
|
void ExternalLoader::init(bool throw_on_error)
|
||||||
|
{
|
||||||
|
if (is_initialized)
|
||||||
|
return;
|
||||||
|
|
||||||
|
is_initialized = true;
|
||||||
|
|
||||||
|
{
|
||||||
|
/// During synchronous loading of external dictionaries at moment of query execution,
|
||||||
|
/// we should not use per query memory limit.
|
||||||
|
TemporarilyDisableMemoryTracker temporarily_disable_memory_tracker;
|
||||||
|
|
||||||
|
reloadAndUpdate(throw_on_error);
|
||||||
|
}
|
||||||
|
|
||||||
|
reloading_thread = std::thread{&ExternalLoader::reloadPeriodically, this};
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/// Signals the background thread to stop and waits for it to finish.
ExternalLoader::~ExternalLoader()
{
    destroy.set();

    /// The thread is started only as the last step of init(). If init() was never called,
    /// or the initial load threw before the thread was created, the thread is not joinable
    /// and join() would throw std::system_error out of the destructor.
    if (reloading_thread.joinable())
        reloading_thread.join();
}
|
||||||
|
|
||||||
|
namespace
|
||||||
|
{
|
||||||
|
std::set<std::string> getConfigPaths(const Poco::Util::AbstractConfiguration & config,
|
||||||
|
const std::string & external_config_paths_setting)
|
||||||
|
{
|
||||||
|
std::set<std::string> files;
|
||||||
|
auto patterns = getMultipleValuesFromConfig(config, "", external_config_paths_setting);
|
||||||
|
for (auto & pattern : patterns)
|
||||||
|
{
|
||||||
|
if (pattern.empty())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
if (pattern[0] != '/')
|
||||||
|
{
|
||||||
|
const auto app_config_path = config.getString("config-file", "config.xml");
|
||||||
|
const auto config_dir = Poco::Path{app_config_path}.parent().toString();
|
||||||
|
const auto absolute_path = config_dir + pattern;
|
||||||
|
Poco::Glob::glob(absolute_path, files, 0);
|
||||||
|
if (!files.empty())
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
Poco::Glob::glob(pattern, files, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
return files;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ExternalLoader::reloadAndUpdate(bool throw_on_error)
|
||||||
|
{
|
||||||
|
reloadFromConfigFiles(throw_on_error);
|
||||||
|
|
||||||
|
/// list of recreated loadable objects to perform delayed removal from unordered_map
|
||||||
|
std::list<std::string> recreated_failed_loadable_objects;
|
||||||
|
|
||||||
|
std::unique_lock<std::mutex> all_lock(all_mutex);
|
||||||
|
|
||||||
|
/// retry loading failed loadable objects
|
||||||
|
for (auto & failed_loadable_object : failed_loadable_objects)
|
||||||
|
{
|
||||||
|
if (std::chrono::system_clock::now() < failed_loadable_object.second.next_attempt_time)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
const auto & name = failed_loadable_object.first;
|
||||||
|
|
||||||
|
try
|
||||||
|
{
|
||||||
|
auto loadable_ptr = failed_loadable_object.second.loadable->clone();
|
||||||
|
if (const auto exception_ptr = loadable_ptr->getCreationException())
|
||||||
|
{
|
||||||
|
/// recalculate next attempt time
|
||||||
|
std::uniform_int_distribution<UInt64> distribution(
|
||||||
|
0, static_cast<UInt64>(std::exp2(failed_loadable_object.second.error_count)));
|
||||||
|
|
||||||
|
std::chrono::seconds delay(std::min<UInt64>(
|
||||||
|
update_settings.backoff_max_sec,
|
||||||
|
update_settings.backoff_initial_sec + distribution(rnd_engine)));
|
||||||
|
failed_loadable_object.second.next_attempt_time = std::chrono::system_clock::now() + delay;
|
||||||
|
|
||||||
|
++failed_loadable_object.second.error_count;
|
||||||
|
|
||||||
|
std::rethrow_exception(exception_ptr);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
const std::lock_guard<std::mutex> lock{map_mutex};
|
||||||
|
|
||||||
|
const auto & lifetime = loadable_ptr->getLifetime();
|
||||||
|
std::uniform_int_distribution<UInt64> distribution{lifetime.min_sec, lifetime.max_sec};
|
||||||
|
update_times[name] = std::chrono::system_clock::now() + std::chrono::seconds{distribution(rnd_engine)};
|
||||||
|
|
||||||
|
const auto dict_it = loadable_objects.find(name);
|
||||||
|
|
||||||
|
dict_it->second.loadable.reset();
|
||||||
|
dict_it->second.loadable = std::move(loadable_ptr);
|
||||||
|
|
||||||
|
/// clear stored exception on success
|
||||||
|
dict_it->second.exception = std::exception_ptr{};
|
||||||
|
|
||||||
|
recreated_failed_loadable_objects.push_back(name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
catch (...)
|
||||||
|
{
|
||||||
|
tryLogCurrentException(log, "Failed reloading '" + name + "' " + object_name);
|
||||||
|
|
||||||
|
if (throw_on_error)
|
||||||
|
throw;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// do not undertake further attempts to recreate these loadable objects
|
||||||
|
for (const auto & name : recreated_failed_loadable_objects)
|
||||||
|
failed_loadable_objects.erase(name);
|
||||||
|
|
||||||
|
/// periodic update
|
||||||
|
for (auto & loadable_object : loadable_objects)
|
||||||
|
{
|
||||||
|
const auto & name = loadable_object.first;
|
||||||
|
|
||||||
|
try
|
||||||
|
{
|
||||||
|
/// If the loadable objects failed to load or even failed to initialize from the config.
|
||||||
|
if (!loadable_object.second.loadable)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
auto current = loadable_object.second.loadable;
|
||||||
|
const auto & lifetime = current->getLifetime();
|
||||||
|
|
||||||
|
/// do not update loadable objects with zero as lifetime
|
||||||
|
if (lifetime.min_sec == 0 || lifetime.max_sec == 0)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
if (current->supportUpdates())
|
||||||
|
{
|
||||||
|
auto & update_time = update_times[current->getName()];
|
||||||
|
|
||||||
|
/// check that timeout has passed
|
||||||
|
if (std::chrono::system_clock::now() < update_time)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
SCOPE_EXIT({
|
||||||
|
/// calculate next update time
|
||||||
|
std::uniform_int_distribution<UInt64> distribution{lifetime.min_sec, lifetime.max_sec};
|
||||||
|
update_time = std::chrono::system_clock::now() + std::chrono::seconds{distribution(rnd_engine)};
|
||||||
|
});
|
||||||
|
|
||||||
|
/// check source modified
|
||||||
|
if (current->isModified())
|
||||||
|
{
|
||||||
|
/// create new version of loadable object
|
||||||
|
auto new_version = current->clone();
|
||||||
|
|
||||||
|
if (const auto exception_ptr = new_version->getCreationException())
|
||||||
|
std::rethrow_exception(exception_ptr);
|
||||||
|
|
||||||
|
loadable_object.second.loadable.reset();
|
||||||
|
loadable_object.second.loadable = std::move(new_version);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// erase stored exception on success
|
||||||
|
loadable_object.second.exception = std::exception_ptr{};
|
||||||
|
}
|
||||||
|
catch (...)
|
||||||
|
{
|
||||||
|
loadable_object.second.exception = std::current_exception();
|
||||||
|
|
||||||
|
tryLogCurrentException(log, "Cannot update " + object_name + " '" + name + "', leaving old version");
|
||||||
|
|
||||||
|
if (throw_on_error)
|
||||||
|
throw;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ExternalLoader::reloadFromConfigFiles(const bool throw_on_error, const bool force_reload, const std::string & only_dictionary)
|
||||||
|
{
|
||||||
|
const auto config_paths = getConfigPaths(config, config_settings.path_setting_name);
|
||||||
|
|
||||||
|
for (const auto & config_path : config_paths)
|
||||||
|
{
|
||||||
|
try
|
||||||
|
{
|
||||||
|
reloadFromConfigFile(config_path, throw_on_error, force_reload, only_dictionary);
|
||||||
|
}
|
||||||
|
catch (...)
|
||||||
|
{
|
||||||
|
tryLogCurrentException(log, "reloadFromConfigFile has thrown while reading from " + config_path);
|
||||||
|
|
||||||
|
if (throw_on_error)
|
||||||
|
throw;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ExternalLoader::reloadFromConfigFile(const std::string & config_path, const bool throw_on_error,
|
||||||
|
const bool force_reload, const std::string & loadable_name)
|
||||||
|
{
|
||||||
|
const Poco::File config_file{config_path};
|
||||||
|
|
||||||
|
if (config_path.empty() || !config_file.exists())
|
||||||
|
{
|
||||||
|
LOG_WARNING(log, "config file '" + config_path + "' does not exist");
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
std::unique_lock<std::mutex> all_lock(all_mutex);
|
||||||
|
|
||||||
|
auto modification_time_it = last_modification_times.find(config_path);
|
||||||
|
if (modification_time_it == std::end(last_modification_times))
|
||||||
|
modification_time_it = last_modification_times.emplace(config_path, Poco::Timestamp{0}).first;
|
||||||
|
auto & config_last_modified = modification_time_it->second;
|
||||||
|
|
||||||
|
const auto last_modified = config_file.getLastModified();
|
||||||
|
if (force_reload || last_modified > config_last_modified)
|
||||||
|
{
|
||||||
|
Poco::AutoPtr<Poco::Util::XMLConfiguration> config = new Poco::Util::XMLConfiguration(config_path);
|
||||||
|
|
||||||
|
/// Definitions of loadable objects may have changed, recreate all of them
|
||||||
|
|
||||||
|
/// If we need update only one object, don't update modification time: might be other objects in the config file
|
||||||
|
if (loadable_name.empty())
|
||||||
|
config_last_modified = last_modified;
|
||||||
|
|
||||||
|
/// get all objects' definitions
|
||||||
|
Poco::Util::AbstractConfiguration::Keys keys;
|
||||||
|
config->keys(keys);
|
||||||
|
|
||||||
|
/// for each loadable object defined in xml config
|
||||||
|
for (const auto & key : keys)
|
||||||
|
{
|
||||||
|
std::string name;
|
||||||
|
|
||||||
|
if (!startsWith(key, config_settings.external_config))
|
||||||
|
{
|
||||||
|
if (!startsWith(key, "comment"))
|
||||||
|
LOG_WARNING(log, config_path << ": unknown node in file: '" << key
|
||||||
|
<< "', expected '" << config_settings.external_config << "'");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
try
|
||||||
|
{
|
||||||
|
name = config->getString(key + "." + config_settings.external_name);
|
||||||
|
if (name.empty())
|
||||||
|
{
|
||||||
|
LOG_WARNING(log, config_path << ": " + config_settings.external_name + " name cannot be empty");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!loadable_name.empty() && name != loadable_name)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
decltype(loadable_objects.begin()) object_it;
|
||||||
|
{
|
||||||
|
std::lock_guard<std::mutex> lock{map_mutex};
|
||||||
|
object_it = loadable_objects.find(name);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (object_it != std::end(loadable_objects) && object_it->second.origin != config_path)
|
||||||
|
throw std::runtime_error{"Overriding " + object_name + " from file " + object_it->second.origin};
|
||||||
|
|
||||||
|
auto object_ptr = create(name, *config, key);
|
||||||
|
|
||||||
|
/// If the object could not be loaded.
|
||||||
|
if (const auto exception_ptr = object_ptr->getCreationException())
|
||||||
|
{
|
||||||
|
std::chrono::seconds delay(update_settings.backoff_initial_sec);
|
||||||
|
const auto failed_dict_it = failed_loadable_objects.find(name);
|
||||||
|
FailedLoadableInfo info{std::move(object_ptr), std::chrono::system_clock::now() + delay, 0};
|
||||||
|
if (failed_dict_it != std::end(failed_loadable_objects))
|
||||||
|
(*failed_dict_it).second = std::move(info);
|
||||||
|
else
|
||||||
|
failed_loadable_objects.emplace(name, std::move(info));
|
||||||
|
|
||||||
|
std::rethrow_exception(exception_ptr);
|
||||||
|
}
|
||||||
|
else if (object_ptr->supportUpdates())
|
||||||
|
{
|
||||||
|
const auto & lifetime = object_ptr->getLifetime();
|
||||||
|
if (lifetime.min_sec != 0 && lifetime.max_sec != 0)
|
||||||
|
{
|
||||||
|
std::uniform_int_distribution<UInt64> distribution(lifetime.min_sec, lifetime.max_sec);
|
||||||
|
|
||||||
|
update_times[name] = std::chrono::system_clock::now() +
|
||||||
|
std::chrono::seconds{distribution(rnd_engine)};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::lock_guard<std::mutex> lock{map_mutex};
|
||||||
|
|
||||||
|
/// add new loadable object or update an existing version
|
||||||
|
if (object_it == std::end(loadable_objects))
|
||||||
|
loadable_objects.emplace(name, LoadableInfo{std::move(object_ptr), config_path});
|
||||||
|
else
|
||||||
|
{
|
||||||
|
if (object_it->second.loadable)
|
||||||
|
object_it->second.loadable.reset();
|
||||||
|
object_it->second.loadable = std::move(object_ptr);
|
||||||
|
|
||||||
|
/// erase stored exception on success
|
||||||
|
object_it->second.exception = std::exception_ptr{};
|
||||||
|
failed_loadable_objects.erase(name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
catch (...)
|
||||||
|
{
|
||||||
|
if (!name.empty())
|
||||||
|
{
|
||||||
|
/// If the loadable object could not load data or even failed to initialize from the config.
|
||||||
|
/// - all the same we insert information into the `loadable_objects`, with the zero pointer `loadable`.
|
||||||
|
|
||||||
|
const std::lock_guard<std::mutex> lock{map_mutex};
|
||||||
|
|
||||||
|
const auto exception_ptr = std::current_exception();
|
||||||
|
const auto loadable_it = loadable_objects.find(name);
|
||||||
|
if (loadable_it == std::end(loadable_objects))
|
||||||
|
loadable_objects.emplace(name, LoadableInfo{nullptr, config_path, exception_ptr});
|
||||||
|
else
|
||||||
|
loadable_it->second.exception = exception_ptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
tryLogCurrentException(log, "Cannot create " + object_name + " '"
|
||||||
|
+ name + "' from config path " + config_path);
|
||||||
|
|
||||||
|
/// propagate exception
|
||||||
|
if (throw_on_error)
|
||||||
|
throw;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ExternalLoader::reload()
|
||||||
|
{
|
||||||
|
reloadFromConfigFiles(true, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ExternalLoader::reload(const std::string & name)
|
||||||
|
{
|
||||||
|
reloadFromConfigFiles(true, true, name);
|
||||||
|
|
||||||
|
/// Check that specified object was loaded
|
||||||
|
const std::lock_guard<std::mutex> lock{map_mutex};
|
||||||
|
if (!loadable_objects.count(name))
|
||||||
|
throw Exception("Failed to load " + object_name + " '" + name + "' during the reload process", ErrorCodes::BAD_ARGUMENTS);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the currently loaded object registered under `name`.
/// Throws BAD_ARGUMENTS when the name is unknown. When the entry exists but holds no object,
/// rethrows the exception stored for it, or LOGICAL_ERROR if none was stored.
ExternalLoader::LoadablePtr ExternalLoader::getLoadable(const std::string & name) const
{
    const std::lock_guard<std::mutex> guard{map_mutex};

    const auto found = loadable_objects.find(name);
    if (found == std::end(loadable_objects))
        throw Exception("No such " + object_name + ": " + name, ErrorCodes::BAD_ARGUMENTS);

    const LoadableInfo & info = found->second;
    if (info.loadable)
        return info.loadable;

    /// The entry is present but empty: surface the recorded failure if there is one.
    if (info.exception)
        std::rethrow_exception(info.exception);

    throw Exception{object_name + " '" + name + "' is not loaded", ErrorCodes::LOGICAL_ERROR};
}
|
||||||
|
|
||||||
|
/// Returns a read-only view of the objects map bundled with a lock on `map_mutex`;
/// the mutex stays locked for as long as the returned handle is alive.
ExternalLoader::LockedObjectsMap ExternalLoader::getObjectsMap() const
{
    return {map_mutex, loadable_objects};
}
|
||||||
|
|
||||||
|
}
|
173
dbms/src/Interpreters/ExternalLoader.h
Normal file
173
dbms/src/Interpreters/ExternalLoader.h
Normal file
@ -0,0 +1,173 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <common/logger_useful.h>
|
||||||
|
#include <Poco/Event.h>
|
||||||
|
#include <mutex>
|
||||||
|
#include <thread>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <chrono>
|
||||||
|
#include <tuple>
|
||||||
|
#include <Interpreters/IExternalLoadable.h>
|
||||||
|
#include <Core/Types.h>
|
||||||
|
#include <pcg_random.hpp>
|
||||||
|
#include <Common/randomSeed.h>
|
||||||
|
|
||||||
|
|
||||||
|
namespace DB
|
||||||
|
{
|
||||||
|
|
||||||
|
class Context;
|
||||||
|
|
||||||
|
struct ExternalLoaderUpdateSettings
|
||||||
|
{
|
||||||
|
UInt64 check_period_sec = 5;
|
||||||
|
UInt64 backoff_initial_sec = 5;
|
||||||
|
/// 10 minutes
|
||||||
|
UInt64 backoff_max_sec = 10 * 60;
|
||||||
|
|
||||||
|
ExternalLoaderUpdateSettings() = default;
|
||||||
|
ExternalLoaderUpdateSettings(UInt64 check_period_sec, UInt64 backoff_initial_sec, UInt64 backoff_max_sec)
|
||||||
|
: check_period_sec(check_period_sec),
|
||||||
|
backoff_initial_sec(backoff_initial_sec),
|
||||||
|
backoff_max_sec(backoff_max_sec) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
/* External configuration structure.
|
||||||
|
*
|
||||||
|
* <external_group>
|
||||||
|
* <external_config>
|
||||||
|
* <external_name>name</external_name>
|
||||||
|
* ....
|
||||||
|
* </external_config>
|
||||||
|
* </external_group>
|
||||||
|
*/
|
||||||
|
/// Names used to locate loadable objects inside configuration files.
struct ExternalLoaderConfigSettings
{
    /// Tag of the element describing one object (<external_config> in the layout sketched above).
    std::string external_config;

    /// Tag, nested in that element, which carries the object's name (<external_name>).
    std::string external_name;

    /// NOTE(review): presumably the main-config key holding the path pattern of the
    /// object config files (ExternalModels sets it to "models_config") - confirm in ExternalLoader.cpp.
    std::string path_setting_name;
};
|
||||||
|
|
||||||
|
/** Manages user-defined objects.
|
||||||
|
* Monitors configuration file and automatically reloads objects in a separate thread.
|
||||||
|
* The monitoring thread wakes up every @check_period_sec seconds and checks
|
||||||
|
* modification time of objects' configuration file. If said time is greater than
|
||||||
|
* @config_last_modified, the objects are created from scratch using configuration file,
|
||||||
|
* possibly overriding currently existing objects with the same name (previous versions of
|
||||||
|
* overridden objects will live as long as there are any users retaining them).
|
||||||
|
*
|
||||||
|
* Apart from checking configuration file for modifications, each object
|
||||||
|
* has a lifetime of its own and may be updated if it supportUpdates.
|
||||||
|
* The time of next update is calculated by choosing uniformly a random number
|
||||||
|
* distributed between lifetime.min_sec and lifetime.max_sec.
|
||||||
|
* If either of lifetime.min_sec and lifetime.max_sec is zero, such object is never updated.
|
||||||
|
*/
|
||||||
|
class ExternalLoader
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
using LoadablePtr = std::shared_ptr<IExternalLoadable>;
|
||||||
|
|
||||||
|
private:
|
||||||
|
struct LoadableInfo final
|
||||||
|
{
|
||||||
|
LoadablePtr loadable;
|
||||||
|
std::string origin;
|
||||||
|
std::exception_ptr exception;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct FailedLoadableInfo final
|
||||||
|
{
|
||||||
|
std::unique_ptr<IExternalLoadable> loadable;
|
||||||
|
std::chrono::system_clock::time_point next_attempt_time;
|
||||||
|
UInt64 error_count;
|
||||||
|
};
|
||||||
|
|
||||||
|
public:
|
||||||
|
using Configuration = Poco::Util::AbstractConfiguration;
|
||||||
|
using ObjectsMap = std::unordered_map<std::string, LoadableInfo>;
|
||||||
|
|
||||||
|
/// Objects will be loaded immediately and then will be updated in separate thread, each 'reload_period' seconds.
|
||||||
|
ExternalLoader(const Configuration & config,
|
||||||
|
const ExternalLoaderUpdateSettings & update_settings,
|
||||||
|
const ExternalLoaderConfigSettings & config_settings,
|
||||||
|
Logger * log, const std::string & loadable_object_name);
|
||||||
|
virtual ~ExternalLoader();
|
||||||
|
|
||||||
|
/// Forcibly reloads all loadable objects.
|
||||||
|
void reload();
|
||||||
|
|
||||||
|
/// Forcibly reloads specified loadable object.
|
||||||
|
void reload(const std::string & name);
|
||||||
|
|
||||||
|
LoadablePtr getLoadable(const std::string & name) const;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
virtual std::unique_ptr<IExternalLoadable> create(const std::string & name, const Configuration & config,
|
||||||
|
const std::string & config_prefix) = 0;
|
||||||
|
|
||||||
|
class LockedObjectsMap
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
LockedObjectsMap(std::mutex & mutex, const ObjectsMap & objectsMap) : lock(mutex), objectsMap(objectsMap) {}
|
||||||
|
const ObjectsMap & get() { return objectsMap; }
|
||||||
|
private:
|
||||||
|
std::unique_lock<std::mutex> lock;
|
||||||
|
const ObjectsMap & objectsMap;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Direct access to objects.
|
||||||
|
LockedObjectsMap getObjectsMap() const;
|
||||||
|
|
||||||
|
/// Should be called in derived constructor (to avoid pure virtual call).
|
||||||
|
void init(bool throw_on_error);
|
||||||
|
|
||||||
|
private:
|
||||||
|
|
||||||
|
bool is_initialized = false;
|
||||||
|
|
||||||
|
/// Protects only objects map.
|
||||||
|
mutable std::mutex map_mutex;
|
||||||
|
|
||||||
|
/// Protects all data, currently used to avoid races between updating thread and SYSTEM queries
|
||||||
|
mutable std::mutex all_mutex;
|
||||||
|
|
||||||
|
/// name -> loadable.
|
||||||
|
ObjectsMap loadable_objects;
|
||||||
|
|
||||||
|
/// Here are loadable objects, that has been never loaded successfully.
|
||||||
|
/// They are also in 'loadable_objects', but with nullptr as 'loadable'.
|
||||||
|
std::unordered_map<std::string, FailedLoadableInfo> failed_loadable_objects;
|
||||||
|
|
||||||
|
/// Both for loadable_objects and failed_loadable_objects.
|
||||||
|
std::unordered_map<std::string, std::chrono::system_clock::time_point> update_times;
|
||||||
|
|
||||||
|
pcg64 rnd_engine{randomSeed()};
|
||||||
|
|
||||||
|
const Configuration & config;
|
||||||
|
const ExternalLoaderUpdateSettings & update_settings;
|
||||||
|
const ExternalLoaderConfigSettings & config_settings;
|
||||||
|
|
||||||
|
std::thread reloading_thread;
|
||||||
|
Poco::Event destroy;
|
||||||
|
|
||||||
|
Logger * log;
|
||||||
|
/// Loadable object name to use in log messages.
|
||||||
|
std::string object_name;
|
||||||
|
|
||||||
|
std::unordered_map<std::string, Poco::Timestamp> last_modification_times;
|
||||||
|
|
||||||
|
/// Check objects definitions in config files and reload or/and add new ones if the definition is changed
|
||||||
|
/// If loadable_name is not empty, load only loadable object with name loadable_name
|
||||||
|
void reloadFromConfigFiles(bool throw_on_error, bool force_reload = false, const std::string & loadable_name = "");
|
||||||
|
void reloadFromConfigFile(const std::string & config_path, bool throw_on_error, bool force_reload,
|
||||||
|
const std::string & loadable_name);
|
||||||
|
|
||||||
|
/// Check config files and update expired loadable objects
|
||||||
|
void reloadAndUpdate(bool throw_on_error = false);
|
||||||
|
|
||||||
|
void reloadPeriodically();
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
67
dbms/src/Interpreters/ExternalModels.cpp
Normal file
67
dbms/src/Interpreters/ExternalModels.cpp
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
#include <Interpreters/ExternalModels.h>
|
||||||
|
#include <Interpreters/Context.h>
|
||||||
|
|
||||||
|
namespace DB
|
||||||
|
{
|
||||||
|
|
||||||
|
namespace ErrorCodes
|
||||||
|
{
|
||||||
|
extern const int INVALID_CONFIG_PARAMETER;
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace
|
||||||
|
{
|
||||||
|
|
||||||
|
const ExternalLoaderUpdateSettings externalModelsUpdateSettings;
|
||||||
|
|
||||||
|
const ExternalLoaderConfigSettings & getExternalModelsConfigSettings()
|
||||||
|
{
|
||||||
|
static ExternalLoaderConfigSettings settings;
|
||||||
|
static std::once_flag flag;
|
||||||
|
|
||||||
|
std::call_once(flag, [] {
|
||||||
|
settings.external_config = "model";
|
||||||
|
settings.external_name = "name";
|
||||||
|
|
||||||
|
settings.path_setting_name = "models_config";
|
||||||
|
});
|
||||||
|
|
||||||
|
return settings;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/// Wires the generic loader to the server config and the models-specific settings.
/// init() is invoked here, not in the base class, because it must run in the derived
/// constructor to avoid a pure virtual call.
ExternalModels::ExternalModels(Context & context, bool throw_on_error)
    : ExternalLoader(
          context.getConfigRef(),
          externalModelsUpdateSettings,
          getExternalModelsConfigSettings(),
          &Logger::get("ExternalModels"),
          "external model")
    , context(context)
{
    init(throw_on_error);
}
|
||||||
|
|
||||||
|
std::unique_ptr<IExternalLoadable> ExternalModels::create(
|
||||||
|
const std::string & name, const Configuration & config, const std::string & config_prefix)
|
||||||
|
{
|
||||||
|
String type = config.getString(config_prefix + ".type");
|
||||||
|
ExternalLoadableLifetime lifetime(config, config_prefix + ".lifetime");
|
||||||
|
|
||||||
|
/// TODO: add models factory.
|
||||||
|
if (type == "catboost")
|
||||||
|
{
|
||||||
|
return std::make_unique<CatBoostModel>(
|
||||||
|
name, config.getString(config_prefix + ".path"),
|
||||||
|
context.getConfigRef().getString("catboost_dynamic_library_path"),
|
||||||
|
lifetime, config.getUInt(config_prefix + ".float_features_count"),
|
||||||
|
config.getUInt(config_prefix + ".cat_features_count")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
throw Exception("Unknown model type: " + type, ErrorCodes::INVALID_CONFIG_PARAMETER);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
41
dbms/src/Interpreters/ExternalModels.h
Normal file
41
dbms/src/Interpreters/ExternalModels.h
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <Dictionaries/CatBoostModel.h>
|
||||||
|
#include <Interpreters/ExternalLoader.h>
|
||||||
|
#include <common/logger_useful.h>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
|
||||||
|
namespace DB
|
||||||
|
{
|
||||||
|
|
||||||
|
class Context;
|
||||||
|
|
||||||
|
/// Manages user-defined models.
|
||||||
|
class ExternalModels : public ExternalLoader
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
using ModelPtr = std::shared_ptr<IModel>;
|
||||||
|
|
||||||
|
/// Models will be loaded immediately and then will be updated in separate thread, each 'reload_period' seconds.
|
||||||
|
ExternalModels(Context & context, bool throw_on_error);
|
||||||
|
|
||||||
|
/// Forcibly reloads specified model.
|
||||||
|
void reloadModel(const std::string & name) { reload(name); }
|
||||||
|
|
||||||
|
ModelPtr getModel(const std::string & name) const
|
||||||
|
{
|
||||||
|
return std::static_pointer_cast<IModel>(getLoadable(name));
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
|
||||||
|
std::unique_ptr<IExternalLoadable> create(const std::string & name, const Configuration & config,
|
||||||
|
const std::string & config_prefix) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
|
||||||
|
Context & context;
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
45
dbms/src/Interpreters/IExternalLoadable.h
Normal file
45
dbms/src/Interpreters/IExternalLoadable.h
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
#pragma once
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include <Core/Types.h>
|
||||||
|
|
||||||
|
|
||||||
|
namespace Poco::Util
|
||||||
|
{
|
||||||
|
class AbstractConfiguration;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
namespace DB
|
||||||
|
{
|
||||||
|
|
||||||
|
/// Min and max lifetimes for a loadable object or its entry
|
||||||
|
struct ExternalLoadableLifetime final
|
||||||
|
{
|
||||||
|
UInt64 min_sec;
|
||||||
|
UInt64 max_sec;
|
||||||
|
|
||||||
|
ExternalLoadableLifetime(const Poco::Util::AbstractConfiguration & config, const std::string & config_prefix);
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
/// Basic interface for external loadable objects. Is used in ExternalLoader.
|
||||||
|
class IExternalLoadable : public std::enable_shared_from_this<IExternalLoadable>
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
virtual ~IExternalLoadable() = default;
|
||||||
|
|
||||||
|
virtual const ExternalLoadableLifetime & getLifetime() const = 0;
|
||||||
|
|
||||||
|
virtual std::string getName() const = 0;
|
||||||
|
/// True if object can be updated when lifetime exceeded.
|
||||||
|
virtual bool supportUpdates() const = 0;
|
||||||
|
/// If lifetime exceeded and isModified() ExternalLoader replace current object with the result of clone().
|
||||||
|
virtual bool isModified() const = 0;
|
||||||
|
/// Returns new object with the same configuration. Is used to update modified object when lifetime exceeded.
|
||||||
|
virtual std::unique_ptr<IExternalLoadable> clone() const = 0;
|
||||||
|
|
||||||
|
virtual std::exception_ptr getCreationException() const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
@ -76,16 +76,17 @@ BlockInputStreams StorageSystemDictionaries::read(
|
|||||||
ColumnWithTypeAndName col_source{std::make_shared<ColumnString>(), std::make_shared<DataTypeString>(), "source"};
|
ColumnWithTypeAndName col_source{std::make_shared<ColumnString>(), std::make_shared<DataTypeString>(), "source"};
|
||||||
|
|
||||||
const auto & external_dictionaries = context.getExternalDictionaries();
|
const auto & external_dictionaries = context.getExternalDictionaries();
|
||||||
const std::lock_guard<std::mutex> lock{external_dictionaries.dictionaries_mutex};
|
auto objects_map = external_dictionaries.getObjectsMap();
|
||||||
|
const auto & dictionaries = objects_map.get();
|
||||||
|
|
||||||
for (const auto & dict_info : external_dictionaries.dictionaries)
|
for (const auto & dict_info : dictionaries)
|
||||||
{
|
{
|
||||||
col_name.column->insert(dict_info.first);
|
col_name.column->insert(dict_info.first);
|
||||||
col_origin.column->insert(dict_info.second.origin);
|
col_origin.column->insert(dict_info.second.origin);
|
||||||
|
|
||||||
if (dict_info.second.dict)
|
if (dict_info.second.loadable)
|
||||||
{
|
{
|
||||||
const auto dict_ptr = dict_info.second.dict->get();
|
const auto dict_ptr = std::static_pointer_cast<IDictionaryBase>(dict_info.second.loadable);
|
||||||
|
|
||||||
col_type.column->insert(dict_ptr->getTypeName());
|
col_type.column->insert(dict_ptr->getTypeName());
|
||||||
|
|
||||||
|
17
dbms/tests/external_models/catboost/data/build_catboost.sh
Executable file
17
dbms/tests/external_models/catboost/data/build_catboost.sh
Executable file
@ -0,0 +1,17 @@
|
|||||||
|
#!/usr/bin/env bash
# Clones CatBoost next to this script, builds the model-evaluation shared library
# and the Python package, and symlinks both into this directory.

# Abort on the first failing step instead of continuing with a half-built tree.
set -euo pipefail

DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"

cd "$DIR"
# Allow re-running the script: reuse an existing checkout instead of failing on clone.
if [ ! -d "${DIR}/catboost/.git" ]; then
    git clone https://github.com/catboost/catboost.git
fi

cd "${DIR}/catboost/catboost/libs/model_interface"
../../../ya make -r -o "${DIR}/build/lib" -j4
cd "$DIR"
# -n: replace an existing symlink rather than creating a link inside its target.
ln -sfn "${DIR}/build/lib/catboost/libs/model_interface/libcatboostmodel.so" libcatboostmodel.so

cd "${DIR}/catboost/catboost/python-package/catboost"
../../../ya make -r -DUSE_ARCADIA_PYTHON=no -DPYTHON_CONFIG=python2-config -j4
cd "$DIR"
ln -sfn "${DIR}/catboost/catboost/python-package" python-package
|
42
dbms/tests/external_models/catboost/helpers/client.py
Normal file
42
dbms/tests/external_models/catboost/helpers/client.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
import subprocess
|
||||||
|
import threading
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
class ClickHouseClient:
    """Runs queries through the `clickhouse client` command-line tool."""

    def __init__(self, binary_path, port):
        self.binary_path = binary_path
        self.port = port

    def query(self, query, timeout=10, pipe=None):
        """Execute `query` and return the client's stdout.

        :param query: query text passed to the client with `-q`.
        :param timeout: seconds to wait before the client process is killed.
        :param pipe: optional file object used as the client's stdin.
        :return: captured stdout, or None if nothing was captured.
        :raises Exception: on timeout, or when the client wrote anything to stderr.
        """
        result = []
        process = []
        # Exceptions raised in the worker thread do not reach the caller on their own;
        # they are collected here and re-raised after the join.
        errors = []

        def run(path, port, text, result, in_pipe, process):
            try:
                if in_pipe is None:
                    in_pipe = subprocess.PIPE

                pipe = subprocess.Popen([path, 'client', '--port', str(port), '-q', text],
                                        stdin=in_pipe, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                                        close_fds=True)
                process.append(pipe)
                stdout_data, stderr_data = pipe.communicate()

                if stderr_data:
                    raise Exception('Error while executing query: {}\nstdout:\n{}\nstderr:\n{}'
                                    .format(text, stdout_data, stderr_data))

                result.append(stdout_data)
            except Exception as exc:
                errors.append(exc)

        thread = threading.Thread(target=run, args=(self.binary_path, self.port, query, result, pipe, process))
        thread.start()
        thread.join(timeout)
        # is_alive() exists on Python 2.6+ and 3; isAlive() was removed in Python 3.9.
        if thread.is_alive():
            if len(process):
                process[0].kill()
            thread.join()
            raise Exception('timeout exceed for query: ' + query)

        if errors:
            raise errors[0]

        if len(result):
            return result[0]
|
15
dbms/tests/external_models/catboost/helpers/generate.py
Normal file
15
dbms/tests/external_models/catboost/helpers/generate.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def generate_uniform_int_column(size, low, high, seed=0):
    """Return `size` integers drawn uniformly from [low, high).

    Reseeds numpy's global generator with `seed`, so equal arguments give equal columns.
    """
    np.random.seed(seed)
    column = np.random.randint(low, high, size=size)
    return column
|
||||||
|
|
||||||
|
|
||||||
|
def generate_uniform_float_column(size, low, high, seed=0):
    """Return `size` floats drawn uniformly from [low, high).

    Reseeds numpy's global generator with `seed`, so equal arguments give equal columns.
    """
    np.random.seed(seed)
    unit_samples = np.random.random(size)
    span = high - low
    return unit_samples * span + low
|
||||||
|
|
||||||
|
|
||||||
|
def generate_uniform_string_column(size, samples, seed):
    """Return `size` values picked uniformly (with repetition) from `samples`."""
    choices = np.array(samples)
    picked = generate_uniform_int_column(size, 0, len(samples), seed)
    return choices[picked]
|
67
dbms/tests/external_models/catboost/helpers/server.py
Normal file
67
dbms/tests/external_models/catboost/helpers/server.py
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
import subprocess
|
||||||
|
import threading
|
||||||
|
import socket
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
class ClickHouseServer:
    """Starts and stops a `clickhouse server` process; usable as a context manager."""

    def __init__(self, binary_path, config_path, stdout_file=None, stderr_file=None, shutdown_timeout=10):
        self.binary_path = binary_path
        self.config_path = config_path
        self.pipe = None
        self.stdout_file = stdout_file
        self.stderr_file = stderr_file
        self.shutdown_timeout = shutdown_timeout
        # Log files opened by start(); closed again in shutdown().
        self._log_files = []

    def _close_log_files(self):
        """Close every log file opened by start()."""
        while self._log_files:
            self._log_files.pop().close()

    def start(self):
        """Launch the server, redirecting stdout/stderr to the configured files if any."""
        cmd = [self.binary_path, 'server', '--config', self.config_path]
        out_pipe = None
        err_pipe = None
        try:
            if self.stdout_file is not None:
                out_pipe = open(self.stdout_file, 'w')
                self._log_files.append(out_pipe)
            if self.stderr_file is not None:
                err_pipe = open(self.stderr_file, 'w')
                self._log_files.append(err_pipe)
            self.pipe = subprocess.Popen(cmd, stdout=out_pipe, stderr=err_pipe)
        except Exception:
            # Do not leak the descriptors when the process could not be started.
            self._close_log_files()
            raise

    def wait_for_request(self, port, timeout=1):
        """Poll 127.0.0.1:`port` until it accepts a TCP connection.

        Retries every 10 ms for about `timeout` seconds, then makes one final attempt
        whose socket.error is printed and re-raised.
        """
        address = ('127.0.0.1', port)
        step = 0.01
        try:
            for _ in range(int(timeout / step)):
                # A fresh socket per attempt: reusing one after a failed connect is not
                # portable, and each probe socket is closed instead of being leaked.
                s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                try:
                    if s.connect_ex(address) == 0:
                        return
                finally:
                    s.close()
                time.sleep(step)

            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            try:
                s.connect(address)
            finally:
                s.close()
        except socket.error as socketerror:
            # Single-argument print: same output on Python 2, valid syntax on Python 3.
            print('Error:  {}'.format(socketerror))
            raise

    def shutdown(self, timeout=10):
        """Terminate the server, killing it if it is still running after `timeout` seconds."""

        def wait(pipe):
            pipe.wait()

        if self.pipe is not None:
            self.pipe.terminate()
            thread = threading.Thread(target=wait, args=(self.pipe,))
            thread.start()
            thread.join(timeout)
            # is_alive() exists on Python 2.6+ and 3; isAlive() was removed in Python 3.9.
            if thread.is_alive():
                self.pipe.kill()
                thread.join()

            if self.pipe.stdout is not None:
                self.pipe.stdout.close()
            if self.pipe.stderr is not None:
                self.pipe.stderr.close()

        self._close_log_files()

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, type, value, traceback):
        self.shutdown(self.shutdown_timeout)
|
@ -0,0 +1,166 @@
|
|||||||
|
from server import ClickHouseServer
|
||||||
|
from client import ClickHouseClient
|
||||||
|
from table import ClickHouseTable
|
||||||
|
import os
|
||||||
|
import errno
|
||||||
|
from shutil import rmtree
|
||||||
|
|
||||||
|
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
|
||||||
|
CATBOOST_ROOT = os.path.dirname(SCRIPT_DIR)
|
||||||
|
|
||||||
|
CLICKHOUSE_CONFIG = \
|
||||||
|
'''
|
||||||
|
<yandex>
|
||||||
|
<timezone>Europe/Moscow</timezone>
|
||||||
|
<listen_host>::</listen_host>
|
||||||
|
<path>{path}</path>
|
||||||
|
<tmp_path>{tmp_path}</tmp_path>
|
||||||
|
<models_config>{models_config}</models_config>
|
||||||
|
<mark_cache_size>5368709120</mark_cache_size>
|
||||||
|
<users_config>users.xml</users_config>
|
||||||
|
<tcp_port>{tcp_port}</tcp_port>
|
||||||
|
<catboost_dynamic_library_path>{catboost_dynamic_library_path}</catboost_dynamic_library_path>
|
||||||
|
</yandex>
|
||||||
|
'''
|
||||||
|
|
||||||
|
CLICKHOUSE_USERS = \
|
||||||
|
'''
|
||||||
|
<yandex>
|
||||||
|
<profiles>
|
||||||
|
<default>
|
||||||
|
</default>
|
||||||
|
<readonly>
|
||||||
|
<readonly>1</readonly>
|
||||||
|
</readonly>
|
||||||
|
</profiles>
|
||||||
|
|
||||||
|
<users>
|
||||||
|
<readonly>
|
||||||
|
<password></password>
|
||||||
|
<profile>readonly</profile>
|
||||||
|
<quota>default</quota>
|
||||||
|
</readonly>
|
||||||
|
|
||||||
|
<default>
|
||||||
|
<password></password>
|
||||||
|
<profile>default</profile>
|
||||||
|
<quota>default</quota>
|
||||||
|
<networks incl="networks" replace="replace">
|
||||||
|
<ip>::1</ip>
|
||||||
|
<ip>127.0.0.1</ip>
|
||||||
|
</networks>
|
||||||
|
|
||||||
|
</default>
|
||||||
|
</users>
|
||||||
|
|
||||||
|
<quotas>
|
||||||
|
<default>
|
||||||
|
</default>
|
||||||
|
</quotas>
|
||||||
|
</yandex>
|
||||||
|
'''
|
||||||
|
|
||||||
|
CATBOOST_MODEL_CONFIG = \
|
||||||
|
'''
|
||||||
|
<models>
|
||||||
|
<model>
|
||||||
|
<type>catboost</type>
|
||||||
|
<name>{name}</name>
|
||||||
|
<path>{path}</path>
|
||||||
|
<float_features_count>{float_features_count}</float_features_count>
|
||||||
|
<cat_features_count>{cat_features_count}</cat_features_count>
|
||||||
|
<lifetime>0</lifetime>
|
||||||
|
</model>
|
||||||
|
</models>
|
||||||
|
'''
|
||||||
|
|
||||||
|
|
||||||
|
class ClickHouseServerWithCatboostModels:
    """Context manager that writes server and model configs, then runs a ClickHouse server.

    Everything is placed under <CATBOOST_ROOT>/data/servers; the directory is removed on
    exit when `clean_folder` is set.
    """

    def __init__(self, name, binary_path, port, shutdown_timeout=10, clean_folder=False):
        self.models = {}
        self.name = name
        self.binary_path = binary_path
        self.port = port
        self.shutdown_timeout = shutdown_timeout
        self.clean_folder = clean_folder

        root = os.path.join(CATBOOST_ROOT, 'data', 'servers')
        self.root = root
        self.config_path = os.path.join(root, 'config.xml')
        self.users_path = os.path.join(root, 'users.xml')
        self.models_dir = os.path.join(root, 'models')
        self.server = None

    def _get_server(self):
        """Build (but do not start) the server, logging into the root directory."""
        return ClickHouseServer(
            self.binary_path,
            self.config_path,
            os.path.join(self.root, 'server_stdout.txt'),
            os.path.join(self.root, 'server_stderr.txt'),
            self.shutdown_timeout)

    def add_model(self, model_name, model, float_features_count, cat_features_count):
        """Register a trained model to be saved and configured on __enter__."""
        self.models[model_name] = (float_features_count, cat_features_count, model)

    def apply_model(self, name, df, cat_feature_names):
        """Upload `df` as table `name` and evaluate model `name` on it."""
        column_names = list(df)
        float_feature_names = tuple(column for column in column_names if column not in cat_feature_names)
        with ClickHouseTable(self.server, self.port, name, df) as table:
            # NOTE(review): ClickHouseTable.apply_model names its parameters
            # (model_name, float_columns, cat_columns) but receives (cat, float) here - confirm intended.
            return table.apply_model(name, cat_feature_names, float_feature_names)

    def _create_root(self):
        """Create the root directory; an already existing directory is fine."""
        try:
            os.makedirs(self.root)
        except OSError as exc:  # Python >2.5
            already_exists = exc.errno == errno.EEXIST and os.path.isdir(self.root)
            if not already_exists:
                raise

    def _clean_root(self):
        rmtree(self.root)

    def _save_config(self):
        """Render config.xml and users.xml into the root directory."""
        data_path = os.path.join(self.root, 'clickhouse')
        config = CLICKHOUSE_CONFIG.format(
            tcp_port=self.port,
            path=data_path,
            tmp_path=os.path.join(self.root, 'clickhouse', 'tmp'),
            models_config=os.path.join(self.models_dir, '*_model.xml'),
            catboost_dynamic_library_path=os.path.join(CATBOOST_ROOT, 'data', 'libcatboostmodel.so'))

        with open(self.config_path, 'w') as config_file:
            config_file.write(config)

        with open(self.users_path, 'w') as users_file:
            users_file.write(CLICKHOUSE_USERS)

    def _save_models(self):
        """Write one <name>_model.xml config and one <name>.cbm file per registered model."""
        if not os.path.exists(self.models_dir):
            os.makedirs(self.models_dir)

        for name, (float_features_count, cat_features_count, model) in self.models.items():
            model_path = os.path.join(self.models_dir, name + '.cbm')
            config_path = os.path.join(self.models_dir, name + '_model.xml')
            config = CATBOOST_MODEL_CONFIG.format(
                name=name,
                path=model_path,
                float_features_count=float_features_count,
                cat_features_count=cat_features_count)
            with open(config_path, 'w') as model_config_file:
                model_config_file.write(config)

            model.save_model(model_path)

    def __enter__(self):
        self._create_root()
        self._save_config()
        self._save_models()
        self.server = self._get_server().__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        result = self.server.__exit__(exc_type, exc_val, exc_tb)
        if self.clean_folder:
            self._clean_root()
        return result
|
||||||
|
|
69
dbms/tests/external_models/catboost/helpers/table.py
Normal file
69
dbms/tests/external_models/catboost/helpers/table.py
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
from server import ClickHouseServer
|
||||||
|
from client import ClickHouseClient
|
||||||
|
from pandas import DataFrame
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
|
||||||
|
class ClickHouseTable:
    """Temporary table `test.<table_name>` filled from a pandas DataFrame.

    Used as a context manager: the table is created on enter and dropped on exit.
    """

    def __init__(self, server, port, table_name, df):
        self.server = server
        self.port = port
        self.table_name = table_name
        self.df = df

        if not isinstance(self.server, ClickHouseServer):
            raise Exception('Expected ClickHouseServer, got ' + repr(self.server))
        if not isinstance(self.df, DataFrame):
            raise Exception('Expected DataFrame, got ' + repr(self.df))

        # Block until the server accepts connections before creating the client.
        self.server.wait_for_request(port)
        self.client = ClickHouseClient(server.binary_path, port)

    def _convert(self, name):
        """Map a pandas dtype name to a ClickHouse type; unknown dtypes become String."""
        types_map = {
            'float64': 'Float64',
            'int64': 'Int64',
            'float32': 'Float32',
            'int32': 'Int32'
        }
        return types_map.get(name, 'String')

    def _create_table_from_df(self):
        """(Re)create the table with a schema derived from the DataFrame and insert its rows as CSV."""
        self.client.query('create database if not exists test')
        self.client.query('drop table if exists test.{}'.format(self.table_name))

        column_types = list(self.df.dtypes)
        column_names = list(self.df)
        schema = ', '.join((name + ' ' + self._convert(str(t)) for name, t in zip(column_names, column_types)))
        # Single-argument print: same output on Python 2, valid syntax on Python 3.
        print('schema: ' + schema)

        create_query = 'create table test.{} (date Date DEFAULT today(), {}) engine = MergeTree(date, (date), 8192)'
        self.client.query(create_query.format(self.table_name, schema))

        insert_query = 'insert into test.{} ({}) format CSV'

        with tempfile.TemporaryFile() as tmp_file:
            self.df.to_csv(tmp_file, header=False, index=False)
            tmp_file.seek(0)
            self.client.query(insert_query.format(self.table_name, ', '.join(column_names)), pipe=tmp_file)

    def apply_model(self, model_name, float_columns, cat_columns):
        """Evaluate `model_name` over the table and return the predictions as a tuple of floats.

        Columns are passed to modelEvaluate as `float_columns` followed by `cat_columns`.
        """
        columns = ', '.join(list(float_columns) + list(cat_columns))
        query = "select modelEvaluate('{}', {}) from test.{} format TSV"
        result = self.client.query(query.format(model_name, columns, self.table_name))
        # split() already strips whitespace and drops empty tokens, and unlike
        # str.strip it works whether the client returned str or bytes.
        return tuple(float(token) for token in result.split())

    def _drop_table(self):
        self.client.query('drop table test.{}'.format(self.table_name))

    def __enter__(self):
        self._create_table_from_df()
        return self

    def __exit__(self, type, value, traceback):
        self._drop_table()
|
28
dbms/tests/external_models/catboost/helpers/train.py
Normal file
28
dbms/tests/external_models/catboost/helpers/train.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pandas import DataFrame
|
||||||
|
|
||||||
|
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
|
||||||
|
CATBOOST_ROOT = os.path.dirname(SCRIPT_DIR)
|
||||||
|
CATBOOST_PYTHON_DIR = os.path.join(CATBOOST_ROOT, 'data', 'python-package')
|
||||||
|
|
||||||
|
if CATBOOST_PYTHON_DIR not in sys.path:
|
||||||
|
sys.path.append(CATBOOST_PYTHON_DIR)
|
||||||
|
|
||||||
|
|
||||||
|
import catboost
|
||||||
|
from catboost import CatBoostClassifier
|
||||||
|
|
||||||
|
|
||||||
|
def train_catboost_model(df, target, cat_features, params, verbose=True):
    """Fit a ``CatBoostClassifier`` on a pandas DataFrame.

    :param df: pandas DataFrame holding the feature columns.
    :param target: target values, passed to ``fit`` unchanged.
    :param cat_features: names of the columns of ``df`` to treat as categorical.
    :param params: keyword arguments forwarded to ``CatBoostClassifier``.
    :param verbose: forwarded to ``fit``.
    :return: the fitted model.
    :raises Exception: if ``df`` is not a DataFrame.
    """
    if not isinstance(df, DataFrame):
        raise Exception('DataFrame object expected, but got ' + repr(df))

    # Single-argument print(...) emits the same text under Python 2 and 3,
    # unlike the print statement, which is a syntax error on Python 3.
    print('features: {}'.format(df.columns.tolist()))

    # fit() is given the positional indexes of the categorical columns.
    cat_features_index = [df.columns.get_loc(feature) for feature in cat_features]
    print('cat features: {}'.format(cat_features_index))

    model = CatBoostClassifier(**params)
    model.fit(df, target, cat_features=cat_features_index, verbose=verbose)
    return model
|
3
dbms/tests/external_models/catboost/pytest.ini
Normal file
3
dbms/tests/external_models/catboost/pytest.ini
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
[pytest]
|
||||||
|
python_files = test.py
|
||||||
|
norecursedirs=data
|
@ -0,0 +1,236 @@
|
|||||||
|
from helpers.server_with_models import ClickHouseServerWithCatboostModels
|
||||||
|
from helpers.generate import generate_uniform_string_column, generate_uniform_float_column, generate_uniform_int_column
|
||||||
|
from helpers.train import train_catboost_model
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
from pandas import DataFrame
|
||||||
|
|
||||||
|
|
||||||
|
PORT = int(os.environ.get('CLICKHOUSE_TESTS_PORT', '9000'))
|
||||||
|
CLICKHOUSE_TESTS_SERVER_BIN_PATH = os.environ.get('CLICKHOUSE_TESTS_SERVER_BIN_PATH', '/usr/bin/clickhouse')
|
||||||
|
|
||||||
|
|
||||||
|
def add_noise_to_target(target, seed, threshold=0.05):
    """Flip entries of a 0/1 target wherever a generated column is below ``threshold``.

    :param target: 0/1 target values (array-like supporting arithmetic).
    :param seed: base seed; ``seed + 1`` is handed to the column generator.
    :param threshold: entries whose generated value is below this are flipped.
    :return: target with the selected entries replaced by ``1 - target``.
    """
    # presumably uniform samples in [0, 1] -- confirm against helpers.generate
    noise = generate_uniform_float_column(len(target), 0., 1., seed + 1)
    flip_mask = noise < threshold
    untouched = target * (1 - flip_mask)
    flipped = (1 - target) * flip_mask
    return untouched + flipped
|
||||||
|
|
||||||
|
|
||||||
|
def check_predictions(test_name, target, pred_python, pred_ch, acc_threshold):
    """Verify ClickHouse predictions match the Python ones and are accurate enough.

    :param test_name: label used in the printed report.
    :param target: expected 0/1 classes.
    :param pred_python: numpy array of classes predicted by the Python model.
    :param pred_ch: numpy array of classes predicted through ClickHouse.
    :param acc_threshold: minimal acceptable accuracy against ``target``.
    :raises Exception: if the two prediction arrays differ after casting to int.
    :raises AssertionError: if the accuracy is below ``acc_threshold``.
    """
    ch_class = pred_ch.astype(int)
    python_class = pred_python.astype(int)
    if not np.array_equal(ch_class, python_class):
        raise Exception('Got different results:\npython:\n' + str(python_class) + '\nClickHouse:\n' + str(ch_class))

    mismatches = np.sum(np.abs(ch_class - np.array(target)))
    acc = 1 - mismatches / float(len(target))
    # Report before asserting so the measured accuracy is visible on failure too.
    # Single-argument print(...) behaves identically on Python 2 and 3.
    print('{} accuracy: {:.10f}'.format(test_name, acc))
    assert acc >= acc_threshold, \
        '{} accuracy {:.10f} is below threshold {}'.format(test_name, acc, acc_threshold)
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_float_features_only():
    """Train on three float features and compare Python and ClickHouse predictions."""
    name = 'test_apply_float_features_only'

    train_size = 10000
    test_size = 10000

    def gen_data(size, seed):
        data = {
            'a': generate_uniform_float_column(size, 0., 1., seed + 1),
            'b': generate_uniform_float_column(size, 0., 1., seed + 2),
            'c': generate_uniform_float_column(size, 0., 1., seed + 3)
        }
        return DataFrame.from_dict(data)

    def get_target(df):
        def target_filter(row):
            return 1 if (row['a'] > .3 and row['b'] > .3) or (row['c'] < .4 and row['a'] * row['b'] > 0.1) else 0
        # .values replaces as_matrix(), which was removed in pandas 1.0.
        return df.apply(target_filter, axis=1).values

    train_df = gen_data(train_size, 42)
    test_df = gen_data(test_size, 43)

    train_target = get_target(train_df)
    test_target = get_target(test_df)

    # Single-argument print(...) emits the same text on Python 2 and 3.
    print('')
    print('train target {}'.format(train_target))
    print('test target {}'.format(test_target))

    params = {
        'iterations': 4,
        'depth': 2,
        'learning_rate': 1,
        'loss_function': 'Logloss'
    }

    model = train_catboost_model(train_df, train_target, [], params)
    pred_python = model.predict(test_df)

    server = ClickHouseServerWithCatboostModels(name, CLICKHOUSE_TESTS_SERVER_BIN_PATH, PORT)
    # presumably (float feature count, cat feature count) -- confirm against add_model
    server.add_model(name, model, 3, 0)
    with server:
        # A positive raw model output is mapped to class 1.
        pred_ch = (np.array(server.apply_model(name, test_df, [])) > 0).astype(int)

    print('python predictions {}'.format(pred_python))
    print('clickhouse predictions {}'.format(pred_ch))

    check_predictions(name, test_target, pred_python, pred_ch, 0.9)
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_float_features_with_string_cat_features():
    """Train on two float and two string categorical features and compare predictions."""
    name = 'test_apply_float_features_with_string_cat_features'

    train_size = 10000
    test_size = 10000

    def gen_data(size, seed):
        data = {
            'a': generate_uniform_float_column(size, 0., 1., seed + 1),
            'b': generate_uniform_float_column(size, 0., 1., seed + 2),
            'c': generate_uniform_string_column(size, ['a', 'b', 'c'], seed + 3),
            'd': generate_uniform_string_column(size, ['e', 'f', 'g'], seed + 4)
        }
        return DataFrame.from_dict(data)

    def get_target(df):
        def target_filter(row):
            return 1 if (row['a'] > .3 and row['b'] > .3 and row['c'] != 'a') \
                or (row['a'] * row['b'] > 0.1 and row['c'] != 'b' and row['d'] != 'e') else 0
        # .values replaces as_matrix(), which was removed in pandas 1.0.
        return df.apply(target_filter, axis=1).values

    train_df = gen_data(train_size, 42)
    test_df = gen_data(test_size, 43)

    train_target = get_target(train_df)
    test_target = get_target(test_df)

    # Single-argument print(...) emits the same text on Python 2 and 3.
    print('')
    print('train target {}'.format(train_target))
    print('test target {}'.format(test_target))

    params = {
        'iterations': 6,
        'depth': 2,
        'learning_rate': 1,
        'loss_function': 'Logloss'
    }

    model = train_catboost_model(train_df, train_target, ['c', 'd'], params)
    pred_python = model.predict(test_df)

    server = ClickHouseServerWithCatboostModels(name, CLICKHOUSE_TESTS_SERVER_BIN_PATH, PORT)
    # presumably (float feature count, cat feature count) -- confirm against add_model
    server.add_model(name, model, 2, 2)
    with server:
        # NOTE(review): the cat feature list is passed empty although the model
        # was trained with ['c', 'd'] -- confirm this is intended.
        pred_ch = (np.array(server.apply_model(name, test_df, [])) > 0).astype(int)

    print('python predictions {}'.format(pred_python))
    print('clickhouse predictions {}'.format(pred_ch))

    check_predictions(name, test_target, pred_python, pred_ch, 0.9)
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_float_features_with_int_cat_features():
    """Train on two float and two integer categorical features and compare predictions."""
    name = 'test_apply_float_features_with_int_cat_features'

    train_size = 10000
    test_size = 10000

    def gen_data(size, seed):
        data = {
            'a': generate_uniform_float_column(size, 0., 1., seed + 1),
            'b': generate_uniform_float_column(size, 0., 1., seed + 2),
            'c': generate_uniform_int_column(size, 1, 4, seed + 3),
            'd': generate_uniform_int_column(size, 1, 4, seed + 4)
        }
        return DataFrame.from_dict(data)

    def get_target(df):
        def target_filter(row):
            return 1 if (row['a'] > .3 and row['b'] > .3 and row['c'] != 1) \
                or (row['a'] * row['b'] > 0.1 and row['c'] != 2 and row['d'] != 3) else 0
        # .values replaces as_matrix(), which was removed in pandas 1.0.
        return df.apply(target_filter, axis=1).values

    train_df = gen_data(train_size, 42)
    test_df = gen_data(test_size, 43)

    train_target = get_target(train_df)
    test_target = get_target(test_df)

    # Single-argument print(...) emits the same text on Python 2 and 3.
    print('')
    print('train target {}'.format(train_target))
    print('test target {}'.format(test_target))

    params = {
        'iterations': 6,
        'depth': 4,
        'learning_rate': 1,
        'loss_function': 'Logloss'
    }

    model = train_catboost_model(train_df, train_target, ['c', 'd'], params)
    pred_python = model.predict(test_df)

    server = ClickHouseServerWithCatboostModels(name, CLICKHOUSE_TESTS_SERVER_BIN_PATH, PORT)
    # presumably (float feature count, cat feature count) -- confirm against add_model
    server.add_model(name, model, 2, 2)
    with server:
        # NOTE(review): the cat feature list is passed empty although the model
        # was trained with ['c', 'd'] -- confirm this is intended.
        pred_ch = (np.array(server.apply_model(name, test_df, [])) > 0).astype(int)

    print('python predictions {}'.format(pred_python))
    print('clickhouse predictions {}'.format(pred_ch))

    check_predictions(name, test_target, pred_python, pred_ch, 0.9)
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_float_features_with_mixed_cat_features():
    """Train on two float features plus one string and one integer categorical feature."""
    name = 'test_apply_float_features_with_mixed_cat_features'

    train_size = 10000
    test_size = 10000

    def gen_data(size, seed):
        data = {
            'a': generate_uniform_float_column(size, 0., 1., seed + 1),
            'b': generate_uniform_float_column(size, 0., 1., seed + 2),
            'c': generate_uniform_string_column(size, ['a', 'b', 'c'], seed + 3),
            'd': generate_uniform_int_column(size, 1, 4, seed + 4)
        }
        return DataFrame.from_dict(data)

    def get_target(df):
        def target_filter(row):
            return 1 if (row['a'] > .3 and row['b'] > .3 and row['c'] != 'a') \
                or (row['a'] * row['b'] > 0.1 and row['c'] != 'b' and row['d'] != 2) else 0
        # .values replaces as_matrix(), which was removed in pandas 1.0.
        return df.apply(target_filter, axis=1).values

    train_df = gen_data(train_size, 42)
    test_df = gen_data(test_size, 43)

    train_target = get_target(train_df)
    test_target = get_target(test_df)

    # Single-argument print(...) emits the same text on Python 2 and 3.
    print('')
    print('train target {}'.format(train_target))
    print('test target {}'.format(test_target))

    params = {
        'iterations': 6,
        'depth': 4,
        'learning_rate': 1,
        'loss_function': 'Logloss'
    }

    model = train_catboost_model(train_df, train_target, ['c', 'd'], params)
    pred_python = model.predict(test_df)

    server = ClickHouseServerWithCatboostModels(name, CLICKHOUSE_TESTS_SERVER_BIN_PATH, PORT)
    # presumably (float feature count, cat feature count) -- confirm against add_model
    server.add_model(name, model, 2, 2)
    with server:
        # NOTE(review): the cat feature list is passed empty although the model
        # was trained with ['c', 'd'] -- confirm this is intended.
        pred_ch = (np.array(server.apply_model(name, test_df, [])) > 0).astype(int)

    print('python predictions {}'.format(pred_python))
    print('clickhouse predictions {}'.format(pred_ch))

    check_predictions(name, test_target, pred_python, pred_ch, 0.9)
|
Loading…
Reference in New Issue
Block a user