mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-25 00:52:02 +00:00
Merge remote-tracking branch 'upstream/master' into fix4
This commit is contained in:
commit
723cbb310d
@ -384,6 +384,8 @@ namespace ErrorCodes
|
||||
extern const int UNKNOWN_STATUS_OF_DISTRIBUTED_DDL_TASK = 379;
|
||||
extern const int CANNOT_KILL = 380;
|
||||
extern const int HTTP_LENGTH_REQUIRED = 381;
|
||||
extern const int CANNOT_LOAD_CATBOOST_MODEL = 382;
|
||||
extern const int CANNOT_APPLY_CATBOOST_MODEL = 383;
|
||||
|
||||
extern const int KEEPER_EXCEPTION = 999;
|
||||
extern const int POCO_EXCEPTION = 1000;
|
||||
|
@ -26,18 +26,19 @@ void DatabaseDictionary::loadTables(Context & context, ThreadPool * thread_pool,
|
||||
|
||||
Tables DatabaseDictionary::loadTables()
|
||||
{
|
||||
const std::lock_guard<std::mutex> lock_dictionaries {external_dictionaries.dictionaries_mutex};
|
||||
auto objects_map = external_dictionaries.getObjectsMap();
|
||||
const auto & dictionaries = objects_map.get();
|
||||
|
||||
Tables tables;
|
||||
for (const auto & pair : external_dictionaries.dictionaries)
|
||||
for (const auto & pair : dictionaries)
|
||||
{
|
||||
const std::string & name = pair.first;
|
||||
if (deleted_tables.count(name))
|
||||
continue;
|
||||
auto dict_ptr = pair.second.dict;
|
||||
auto dict_ptr = std::static_pointer_cast<IDictionaryBase>(pair.second.loadable);
|
||||
if (dict_ptr)
|
||||
{
|
||||
const DictionaryStructure & dictionary_structure = dict_ptr->get()->getStructure();
|
||||
const DictionaryStructure & dictionary_structure = dict_ptr->getStructure();
|
||||
auto columns = StorageDictionary::getNamesAndTypes(dictionary_structure);
|
||||
tables[name] = StorageDictionary::create(name, columns, {}, {}, {}, dictionary_structure, name);
|
||||
}
|
||||
@ -50,26 +51,28 @@ bool DatabaseDictionary::isTableExist(
|
||||
const Context & context,
|
||||
const String & table_name) const
|
||||
{
|
||||
const std::lock_guard<std::mutex> lock_dictionaries {external_dictionaries.dictionaries_mutex};
|
||||
return external_dictionaries.dictionaries.count(table_name) && !deleted_tables.count(table_name);
|
||||
auto objects_map = external_dictionaries.getObjectsMap();
|
||||
const auto & dictionaries = objects_map.get();
|
||||
return dictionaries.count(table_name) && !deleted_tables.count(table_name);
|
||||
}
|
||||
|
||||
StoragePtr DatabaseDictionary::tryGetTable(
|
||||
const Context & context,
|
||||
const String & table_name)
|
||||
{
|
||||
const std::lock_guard<std::mutex> lock_dictionaries {external_dictionaries.dictionaries_mutex};
|
||||
auto objects_map = external_dictionaries.getObjectsMap();
|
||||
const auto & dictionaries = objects_map.get();
|
||||
|
||||
if (deleted_tables.count(table_name))
|
||||
return {};
|
||||
{
|
||||
auto it = external_dictionaries.dictionaries.find(table_name);
|
||||
if (it != external_dictionaries.dictionaries.end())
|
||||
auto it = dictionaries.find(table_name);
|
||||
if (it != dictionaries.end())
|
||||
{
|
||||
const auto & dict_ptr = it->second.dict;
|
||||
const auto & dict_ptr = std::static_pointer_cast<IDictionaryBase>(it->second.loadable);
|
||||
if (dict_ptr)
|
||||
{
|
||||
const DictionaryStructure & dictionary_structure = dict_ptr->get()->getStructure();
|
||||
const DictionaryStructure & dictionary_structure = dict_ptr->getStructure();
|
||||
auto columns = StorageDictionary::getNamesAndTypes(dictionary_structure);
|
||||
return StorageDictionary::create(table_name, columns, {}, {}, {}, dictionary_structure, table_name);
|
||||
}
|
||||
@ -86,9 +89,10 @@ DatabaseIteratorPtr DatabaseDictionary::getIterator(const Context & context)
|
||||
|
||||
bool DatabaseDictionary::empty(const Context & context) const
|
||||
{
|
||||
const std::lock_guard<std::mutex> lock_dictionaries {external_dictionaries.dictionaries_mutex};
|
||||
for (const auto & pair : external_dictionaries.dictionaries)
|
||||
if (pair.second.dict && !deleted_tables.count(pair.first))
|
||||
auto objects_map = external_dictionaries.getObjectsMap();
|
||||
const auto & dictionaries = objects_map.get();
|
||||
for (const auto & pair : dictionaries)
|
||||
if (pair.second.loadable && !deleted_tables.count(pair.first))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
@ -119,7 +123,7 @@ void DatabaseDictionary::removeTable(
|
||||
if (!isTableExist(context, table_name))
|
||||
throw Exception("Table " + name + "." + table_name + " doesn't exist.", ErrorCodes::UNKNOWN_TABLE);
|
||||
|
||||
const std::lock_guard<std::mutex> lock_dictionaries {external_dictionaries.dictionaries_mutex};
|
||||
auto objects_map = external_dictionaries.getObjectsMap();
|
||||
deleted_tables.insert(table_name);
|
||||
}
|
||||
|
||||
@ -156,7 +160,6 @@ ASTPtr DatabaseDictionary::getCreateQuery(
|
||||
const String & table_name) const
|
||||
{
|
||||
throw Exception("DatabaseDictionary: getCreateQuery() is not supported", ErrorCodes::NOT_IMPLEMENTED);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void DatabaseDictionary::shutdown()
|
||||
|
@ -54,7 +54,7 @@ public:
|
||||
|
||||
bool isCached() const override { return true; }
|
||||
|
||||
DictionaryPtr clone() const override { return std::make_unique<CacheDictionary>(*this); }
|
||||
std::unique_ptr<IExternalLoadable> clone() const override { return std::make_unique<CacheDictionary>(*this); }
|
||||
|
||||
const IDictionarySource * getSource() const override { return source_ptr.get(); }
|
||||
|
||||
|
523
dbms/src/Dictionaries/CatBoostModel.cpp
Normal file
523
dbms/src/Dictionaries/CatBoostModel.cpp
Normal file
@ -0,0 +1,523 @@
|
||||
#include <Dictionaries/CatBoostModel.h>
|
||||
#include <Core/FieldVisitors.h>
|
||||
#include <mutex>
|
||||
#include <Columns/ColumnString.h>
|
||||
#include <Columns/ColumnFixedString.h>
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <Common/typeid_cast.h>
|
||||
#include <IO/WriteBufferFromString.h>
|
||||
#include <IO/Operators.h>
|
||||
#include <Common/PODArray.h>
|
||||
#include <Common/SharedLibrary.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int LOGICAL_ERROR;
|
||||
extern const int BAD_ARGUMENTS;
|
||||
extern const int CANNOT_LOAD_CATBOOST_MODEL;
|
||||
extern const int CANNOT_APPLY_CATBOOST_MODEL;
|
||||
}
|
||||
|
||||
|
||||
/// CatBoost wrapper interface functions.
|
||||
struct CatBoostWrapperAPI
|
||||
{
|
||||
typedef void ModelCalcerHandle;
|
||||
|
||||
ModelCalcerHandle * (* ModelCalcerCreate)();
|
||||
|
||||
void (* ModelCalcerDelete)(ModelCalcerHandle * calcer);
|
||||
|
||||
const char * (* GetErrorString)();
|
||||
|
||||
bool (* LoadFullModelFromFile)(ModelCalcerHandle * calcer, const char * filename);
|
||||
|
||||
bool (* CalcModelPredictionFlat)(ModelCalcerHandle * calcer, size_t docCount,
|
||||
const float ** floatFeatures, size_t floatFeaturesSize,
|
||||
double * result, size_t resultSize);
|
||||
|
||||
bool (* CalcModelPrediction)(ModelCalcerHandle * calcer, size_t docCount,
|
||||
const float ** floatFeatures, size_t floatFeaturesSize,
|
||||
const char *** catFeatures, size_t catFeaturesSize,
|
||||
double * result, size_t resultSize);
|
||||
|
||||
bool (* CalcModelPredictionWithHashedCatFeatures)(ModelCalcerHandle * calcer, size_t docCount,
|
||||
const float ** floatFeatures, size_t floatFeaturesSize,
|
||||
const int ** catFeatures, size_t catFeaturesSize,
|
||||
double * result, size_t resultSize);
|
||||
|
||||
int (* GetStringCatFeatureHash)(const char * data, size_t size);
|
||||
|
||||
int (* GetIntegerCatFeatureHash)(long long val);
|
||||
|
||||
size_t (* GetFloatFeaturesCount)(ModelCalcerHandle* calcer);
|
||||
|
||||
size_t (* GetCatFeaturesCount)(ModelCalcerHandle* calcer);
|
||||
};
|
||||
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
class CatBoostModelHolder
|
||||
{
|
||||
private:
|
||||
CatBoostWrapperAPI::ModelCalcerHandle * handle;
|
||||
const CatBoostWrapperAPI * api;
|
||||
public:
|
||||
explicit CatBoostModelHolder(const CatBoostWrapperAPI * api) : api(api) { handle = api->ModelCalcerCreate(); }
|
||||
~CatBoostModelHolder() { api->ModelCalcerDelete(handle); }
|
||||
|
||||
CatBoostWrapperAPI::ModelCalcerHandle * get() { return handle; }
|
||||
explicit operator CatBoostWrapperAPI::ModelCalcerHandle * () { return handle; }
|
||||
};
|
||||
|
||||
|
||||
class CatBoostModelImpl : public ICatBoostModel
|
||||
{
|
||||
public:
|
||||
CatBoostModelImpl(const CatBoostWrapperAPI * api, const std::string & model_path) : api(api)
|
||||
{
|
||||
auto handle_ = std::make_unique<CatBoostModelHolder>(api);
|
||||
if (!handle_)
|
||||
{
|
||||
std::string msg = "Cannot create CatBoost model: ";
|
||||
throw Exception(msg + api->GetErrorString(), ErrorCodes::CANNOT_LOAD_CATBOOST_MODEL);
|
||||
}
|
||||
if (!api->LoadFullModelFromFile(handle_->get(), model_path.c_str()))
|
||||
{
|
||||
std::string msg = "Cannot load CatBoost model: ";
|
||||
throw Exception(msg + api->GetErrorString(), ErrorCodes::CANNOT_LOAD_CATBOOST_MODEL);
|
||||
}
|
||||
handle = std::move(handle_);
|
||||
}
|
||||
|
||||
ColumnPtr evaluate(const ConstColumnPlainPtrs & columns,
|
||||
size_t float_features_count, size_t cat_features_count) const override
|
||||
{
|
||||
checkFeaturesCount(float_features_count, cat_features_count);
|
||||
|
||||
if (columns.empty())
|
||||
throw Exception("Got empty columns list for CatBoost model.", ErrorCodes::BAD_ARGUMENTS);
|
||||
|
||||
if (columns.size() != float_features_count + cat_features_count)
|
||||
{
|
||||
std::string msg;
|
||||
{
|
||||
WriteBufferFromString buffer(msg);
|
||||
buffer << "Number of columns is different with number of features: ";
|
||||
buffer << columns.size() << " vs " << float_features_count << " + " << cat_features_count;
|
||||
}
|
||||
throw Exception(msg, ErrorCodes::BAD_ARGUMENTS);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < float_features_count; ++i)
|
||||
{
|
||||
if (!columns[i]->isNumeric())
|
||||
{
|
||||
std::string msg;
|
||||
{
|
||||
WriteBufferFromString buffer(msg);
|
||||
buffer << "Column " << i << "should be numeric to make float feature.";
|
||||
}
|
||||
throw Exception(msg, ErrorCodes::BAD_ARGUMENTS);
|
||||
}
|
||||
}
|
||||
|
||||
bool cat_features_are_strings = true;
|
||||
for (size_t i = float_features_count; i < float_features_count + cat_features_count; ++i)
|
||||
{
|
||||
auto column = columns[i];
|
||||
if (column->isNumeric())
|
||||
cat_features_are_strings = false;
|
||||
else if (!(typeid_cast<const ColumnString *>(column)
|
||||
|| typeid_cast<const ColumnFixedString *>(column)))
|
||||
{
|
||||
std::string msg;
|
||||
{
|
||||
WriteBufferFromString buffer(msg);
|
||||
buffer << "Column " << i << "should be numeric or string.";
|
||||
}
|
||||
throw Exception(msg, ErrorCodes::BAD_ARGUMENTS);
|
||||
}
|
||||
}
|
||||
|
||||
return evalImpl(columns, float_features_count, cat_features_count, cat_features_are_strings);
|
||||
}
|
||||
|
||||
void checkFeaturesCount(size_t float_features_count, size_t cat_features_count) const
|
||||
{
|
||||
if (api->GetFloatFeaturesCount)
|
||||
{
|
||||
size_t float_features_in_model = api->GetFloatFeaturesCount(handle->get());
|
||||
if (float_features_count != float_features_in_model)
|
||||
throw Exception("CatBoost model expected " + std::to_string(float_features_in_model) + " float features"
|
||||
+ ", but " + std::to_string(float_features_count) + " was provided.");
|
||||
}
|
||||
|
||||
if (api->GetCatFeaturesCount)
|
||||
{
|
||||
size_t cat_features_in_model = api->GetCatFeaturesCount(handle->get());
|
||||
if (cat_features_count != cat_features_in_model)
|
||||
throw Exception("CatBoost model expected " + std::to_string(cat_features_in_model) + " cat features"
|
||||
+ ", but " + std::to_string(cat_features_count) + " was provided.");
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<CatBoostModelHolder> handle;
|
||||
const CatBoostWrapperAPI * api;
|
||||
|
||||
/// Buffer should be allocated with features_count * column->size() elements.
|
||||
/// Place column elements in positions buffer[0], buffer[features_count], ... , buffer[size * features_count]
|
||||
template <typename T>
|
||||
void placeColumnAsNumber(const IColumn * column, T * buffer, size_t features_count) const
|
||||
{
|
||||
size_t size = column->size();
|
||||
FieldVisitorConvertToNumber<T> visitor;
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
/// TODO: Replace with column visitor.
|
||||
Field field;
|
||||
column->get(i, field);
|
||||
*buffer = applyVisitor(visitor, field);
|
||||
buffer += features_count;
|
||||
}
|
||||
}
|
||||
|
||||
/// Buffer should be allocated with features_count * column->size() elements.
|
||||
/// Place string pointers in positions buffer[0], buffer[features_count], ... , buffer[size * features_count]
|
||||
void placeStringColumn(const ColumnString & column, const char ** buffer, size_t features_count) const
|
||||
{
|
||||
size_t size = column.size();
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
*buffer = const_cast<char *>(column.getDataAtWithTerminatingZero(i).data);
|
||||
buffer += features_count;
|
||||
}
|
||||
}
|
||||
|
||||
/// Buffer should be allocated with features_count * column->size() elements.
|
||||
/// Place string pointers in positions buffer[0], buffer[features_count], ... , buffer[size * features_count]
|
||||
/// Returns PODArray which holds data (because ColumnFixedString doesn't store terminating zero).
|
||||
PODArray<char> placeFixedStringColumn(
|
||||
const ColumnFixedString & column, const char ** buffer, size_t features_count) const
|
||||
{
|
||||
size_t size = column.size();
|
||||
size_t str_size = column.getN();
|
||||
PODArray<char> data(size * (str_size + 1));
|
||||
char * data_ptr = data.data();
|
||||
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
auto ref = column.getDataAt(i);
|
||||
memcpy(data_ptr, ref.data, ref.size);
|
||||
data_ptr[ref.size] = 0;
|
||||
*buffer = data_ptr;
|
||||
data_ptr += ref.size + 1;
|
||||
buffer += features_count;
|
||||
}
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
/// Place columns into buffer, returns column which holds placed data. Buffer should contains column->size() values.
|
||||
template <typename T>
|
||||
ColumnPtr placeNumericColumns(const ConstColumnPlainPtrs & columns,
|
||||
size_t offset, size_t size, const T** buffer) const
|
||||
{
|
||||
if (size == 0)
|
||||
return nullptr;
|
||||
size_t column_size = columns[offset]->size();
|
||||
auto data_column = std::make_shared<ColumnVector<T>>(size * column_size);
|
||||
T* data = data_column->getData().data();
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
auto column = columns[offset + i];
|
||||
if (column->isNumeric())
|
||||
placeColumnAsNumber(column, data + i, size);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < column_size; ++i)
|
||||
{
|
||||
*buffer = data;
|
||||
++buffer;
|
||||
data += size;
|
||||
}
|
||||
|
||||
return data_column;
|
||||
}
|
||||
|
||||
/// Place columns into buffer, returns data which was used for fixed string columns.
|
||||
/// Buffer should contains column->size() values, each value contains size strings.
|
||||
std::vector<PODArray<char>> placeStringColumns(
|
||||
const ConstColumnPlainPtrs & columns, size_t offset, size_t size, const char ** buffer) const
|
||||
{
|
||||
if (size == 0)
|
||||
return {};
|
||||
|
||||
std::vector<PODArray<char>> data;
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
auto column = columns[offset + i];
|
||||
if (auto column_string = typeid_cast<const ColumnString *>(column))
|
||||
placeStringColumn(*column_string, buffer + i, size);
|
||||
else if (auto column_fixed_string = typeid_cast<const ColumnFixedString *>(column))
|
||||
data.push_back(placeFixedStringColumn(*column_fixed_string, buffer + i, size));
|
||||
else
|
||||
throw Exception("Cannot place string column.", ErrorCodes::LOGICAL_ERROR);
|
||||
}
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
/// Calc hash for string cat feature at ps positions.
|
||||
template <typename Column>
|
||||
void calcStringHashes(const Column * column, size_t ps, const int ** buffer) const
|
||||
{
|
||||
size_t column_size = column->size();
|
||||
for (size_t j = 0; j < column_size; ++j)
|
||||
{
|
||||
auto ref = column->getDataAt(j);
|
||||
const_cast<int *>(*buffer)[ps] = api->GetStringCatFeatureHash(ref.data, ref.size);
|
||||
++buffer;
|
||||
}
|
||||
}
|
||||
|
||||
/// Calc hash for int cat feature at ps position. Buffer at positions ps should contains unhashed values.
|
||||
void calcIntHashes(size_t column_size, size_t ps, const int ** buffer) const
|
||||
{
|
||||
for (size_t j = 0; j < column_size; ++j)
|
||||
{
|
||||
const_cast<int *>(*buffer)[ps] = api->GetIntegerCatFeatureHash((*buffer)[ps]);
|
||||
++buffer;
|
||||
}
|
||||
}
|
||||
|
||||
/// buffer contains column->size() rows and size columns.
|
||||
/// For int cat features calc hash inplace.
|
||||
/// For string cat features calc hash from column rows.
|
||||
void calcHashes(const ConstColumnPlainPtrs & columns, size_t offset, size_t size, const int ** buffer) const
|
||||
{
|
||||
if (size == 0)
|
||||
return;
|
||||
size_t column_size = columns[offset]->size();
|
||||
|
||||
std::vector<PODArray<char>> data;
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
auto column = columns[offset + i];
|
||||
if (auto column_string = typeid_cast<const ColumnString *>(column))
|
||||
calcStringHashes(column_string, i, buffer);
|
||||
else if (auto column_fixed_string = typeid_cast<const ColumnFixedString *>(column))
|
||||
calcStringHashes(column_fixed_string, i, buffer);
|
||||
else
|
||||
calcIntHashes(column_size, i, buffer);
|
||||
}
|
||||
}
|
||||
|
||||
/// buffer[column_size * cat_features_count] -> char * => cat_features[column_size][cat_features_count] -> char *
|
||||
void fillCatFeaturesBuffer(const char *** cat_features, const char ** buffer,
|
||||
size_t column_size, size_t cat_features_count) const
|
||||
{
|
||||
for (size_t i = 0; i < column_size; ++i)
|
||||
{
|
||||
*cat_features = buffer;
|
||||
++cat_features;
|
||||
buffer += cat_features_count;
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert values to row-oriented format and call evaluation function from CatBoost wrapper api.
|
||||
/// * CalcModelPredictionFlat if no cat features
|
||||
/// * CalcModelPrediction if all cat features are strings
|
||||
/// * CalcModelPredictionWithHashedCatFeatures if has int cat features.
|
||||
ColumnPtr evalImpl(const ConstColumnPlainPtrs & columns, size_t float_features_count, size_t cat_features_count,
|
||||
bool cat_features_are_strings) const
|
||||
{
|
||||
std::string error_msg = "Error occurred while applying CatBoost model: ";
|
||||
size_t column_size = columns.front()->size();
|
||||
|
||||
auto result= std::make_shared<ColumnFloat64>(column_size);
|
||||
auto result_buf = result->getData().data();
|
||||
|
||||
/// Prepare float features.
|
||||
PODArray<const float *> float_features(column_size);
|
||||
auto float_features_buf = float_features.data();
|
||||
/// Store all float data into single column. float_features is a list of pointers to it.
|
||||
auto float_features_col = placeNumericColumns<float>(columns, 0, float_features_count, float_features_buf);
|
||||
|
||||
if (cat_features_count == 0)
|
||||
{
|
||||
if (!api->CalcModelPredictionFlat(handle->get(), column_size,
|
||||
float_features_buf, float_features_count,
|
||||
result_buf, column_size))
|
||||
{
|
||||
|
||||
throw Exception(error_msg + api->GetErrorString(), ErrorCodes::CANNOT_APPLY_CATBOOST_MODEL);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Prepare cat features.
|
||||
if (cat_features_are_strings)
|
||||
{
|
||||
/// cat_features_holder stores pointers to ColumnString data or fixed_strings_data.
|
||||
PODArray<const char *> cat_features_holder(cat_features_count * column_size);
|
||||
PODArray<const char **> cat_features(column_size);
|
||||
auto cat_features_buf = cat_features.data();
|
||||
|
||||
fillCatFeaturesBuffer(cat_features_buf, cat_features_holder.data(), column_size, cat_features_count);
|
||||
/// Fixed strings are stored without termination zero, so have to copy data into fixed_strings_data.
|
||||
auto fixed_strings_data = placeStringColumns(columns, float_features_count,
|
||||
cat_features_count, cat_features_holder.data());
|
||||
|
||||
if (!api->CalcModelPrediction(handle->get(), column_size,
|
||||
float_features_buf, float_features_count,
|
||||
cat_features_buf, cat_features_count,
|
||||
result_buf, column_size))
|
||||
{
|
||||
throw Exception(error_msg + api->GetErrorString(), ErrorCodes::CANNOT_APPLY_CATBOOST_MODEL);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
PODArray<const int *> cat_features(column_size);
|
||||
auto cat_features_buf = cat_features.data();
|
||||
auto cat_features_col = placeNumericColumns<int>(columns, float_features_count,
|
||||
cat_features_count, cat_features_buf);
|
||||
calcHashes(columns, float_features_count, cat_features_count, cat_features_buf);
|
||||
if (!api->CalcModelPredictionWithHashedCatFeatures(
|
||||
handle->get(), column_size,
|
||||
float_features_buf, float_features_count,
|
||||
cat_features_buf, cat_features_count,
|
||||
result_buf, column_size))
|
||||
{
|
||||
throw Exception(error_msg + api->GetErrorString(), ErrorCodes::CANNOT_APPLY_CATBOOST_MODEL);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/// Holds CatBoost wrapper library and provides wrapper interface.
|
||||
class CatBoostLibHolder: public CatBoostWrapperAPIProvider
|
||||
{
|
||||
public:
|
||||
explicit CatBoostLibHolder(const std::string & lib_path) : lib_path(lib_path), lib(lib_path) { initAPI(); }
|
||||
|
||||
const CatBoostWrapperAPI & getAPI() const override { return api; }
|
||||
const std::string & getCurrentPath() const { return lib_path; }
|
||||
|
||||
private:
|
||||
CatBoostWrapperAPI api;
|
||||
std::string lib_path;
|
||||
SharedLibrary lib;
|
||||
|
||||
void initAPI();
|
||||
|
||||
template <typename T>
|
||||
void load(T& func, const std::string & name) { func = lib.get<T>(name); }
|
||||
|
||||
template <typename T>
|
||||
void tryLoad(T& func, const std::string & name) { func = lib.tryGet<T>(name); }
|
||||
};
|
||||
|
||||
void CatBoostLibHolder::initAPI()
|
||||
{
|
||||
load(api.ModelCalcerCreate, "ModelCalcerCreate");
|
||||
load(api.ModelCalcerDelete, "ModelCalcerDelete");
|
||||
load(api.GetErrorString, "GetErrorString");
|
||||
load(api.LoadFullModelFromFile, "LoadFullModelFromFile");
|
||||
load(api.CalcModelPredictionFlat, "CalcModelPredictionFlat");
|
||||
load(api.CalcModelPrediction, "CalcModelPrediction");
|
||||
load(api.CalcModelPredictionWithHashedCatFeatures, "CalcModelPredictionWithHashedCatFeatures");
|
||||
load(api.GetStringCatFeatureHash, "GetStringCatFeatureHash");
|
||||
load(api.GetIntegerCatFeatureHash, "GetIntegerCatFeatureHash");
|
||||
|
||||
tryLoad(api.GetFloatFeaturesCount, "GetFloatFeaturesCount");
|
||||
tryLoad(api.GetCatFeaturesCount, "GetCatFeaturesCount");
|
||||
}
|
||||
|
||||
std::shared_ptr<CatBoostLibHolder> getCatBoostWrapperHolder(const std::string & lib_path)
|
||||
{
|
||||
static std::weak_ptr<CatBoostLibHolder> ptr;
|
||||
static std::mutex mutex;
|
||||
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
auto result = ptr.lock();
|
||||
|
||||
if (!result || result->getCurrentPath() != lib_path)
|
||||
{
|
||||
result = std::make_shared<CatBoostLibHolder>(lib_path);
|
||||
/// This assignment is not atomic, which prevents from creating lock only inside 'if'.
|
||||
ptr = result;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
CatBoostModel::CatBoostModel(const std::string & name, const std::string & model_path, const std::string & lib_path,
|
||||
const ExternalLoadableLifetime & lifetime,
|
||||
size_t float_features_count, size_t cat_features_count)
|
||||
: name(name), model_path(model_path), lib_path(lib_path), lifetime(lifetime),
|
||||
float_features_count(float_features_count), cat_features_count(cat_features_count)
|
||||
{
|
||||
try
|
||||
{
|
||||
init(lib_path);
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
creation_exception = std::current_exception();
|
||||
}
|
||||
}
|
||||
|
||||
void CatBoostModel::init(const std::string & lib_path)
|
||||
{
|
||||
api_provider = getCatBoostWrapperHolder(lib_path);
|
||||
api = &api_provider->getAPI();
|
||||
model = std::make_unique<CatBoostModelImpl>(api, model_path);
|
||||
}
|
||||
|
||||
const ExternalLoadableLifetime & CatBoostModel::getLifetime() const
|
||||
{
|
||||
return lifetime;
|
||||
}
|
||||
|
||||
bool CatBoostModel::isModified() const
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
std::unique_ptr<IExternalLoadable> CatBoostModel::clone() const
|
||||
{
|
||||
return std::make_unique<CatBoostModel>(name, model_path, lib_path, lifetime, float_features_count, cat_features_count);
|
||||
}
|
||||
|
||||
size_t CatBoostModel::getFloatFeaturesCount() const
|
||||
{
|
||||
return float_features_count;
|
||||
}
|
||||
|
||||
size_t CatBoostModel::getCatFeaturesCount() const
|
||||
{
|
||||
return cat_features_count;
|
||||
}
|
||||
|
||||
ColumnPtr CatBoostModel::evaluate(const ConstColumnPlainPtrs & columns) const
|
||||
{
|
||||
if (!model)
|
||||
throw Exception("CatBoost model was not loaded.", ErrorCodes::LOGICAL_ERROR);
|
||||
return model->evaluate(columns, float_features_count, cat_features_count);
|
||||
}
|
||||
|
||||
}
|
79
dbms/src/Dictionaries/CatBoostModel.h
Normal file
79
dbms/src/Dictionaries/CatBoostModel.h
Normal file
@ -0,0 +1,79 @@
|
||||
#pragma once
|
||||
#include <Interpreters/IExternalLoadable.h>
|
||||
#include <Columns/IColumn.h>
|
||||
#include <Columns/ColumnsNumber.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
/// CatBoost wrapper interface functions.
|
||||
struct CatBoostWrapperAPI;
|
||||
class CatBoostWrapperAPIProvider
|
||||
{
|
||||
public:
|
||||
virtual ~CatBoostWrapperAPIProvider() = default;
|
||||
virtual const CatBoostWrapperAPI & getAPI() const = 0;
|
||||
};
|
||||
|
||||
/// CatBoost model interface.
|
||||
class ICatBoostModel
|
||||
{
|
||||
public:
|
||||
virtual ~ICatBoostModel() = default;
|
||||
/// Evaluate model. Use first `float_features_count` columns as float features,
|
||||
/// the others `cat_features_count` as categorical features.
|
||||
virtual ColumnPtr evaluate(const ConstColumnPlainPtrs & columns, size_t float_features_count, size_t cat_features_count) const = 0;
|
||||
};
|
||||
|
||||
/// General ML model evaluator interface.
|
||||
class IModel : public IExternalLoadable
|
||||
{
|
||||
public:
|
||||
virtual ColumnPtr evaluate(const ConstColumnPlainPtrs & columns) const = 0;
|
||||
};
|
||||
|
||||
class CatBoostModel : public IModel
|
||||
{
|
||||
public:
|
||||
CatBoostModel(const std::string & name, const std::string & model_path,
|
||||
const std::string & lib_path, const ExternalLoadableLifetime & lifetime,
|
||||
size_t float_features_count, size_t cat_features_count);
|
||||
|
||||
ColumnPtr evaluate(const ConstColumnPlainPtrs & columns) const override;
|
||||
|
||||
size_t getFloatFeaturesCount() const;
|
||||
size_t getCatFeaturesCount() const;
|
||||
|
||||
/// IExternalLoadable interface.
|
||||
|
||||
const ExternalLoadableLifetime & getLifetime() const override;
|
||||
|
||||
std::string getName() const override { return name; }
|
||||
|
||||
bool supportUpdates() const override { return true; }
|
||||
|
||||
bool isModified() const override;
|
||||
|
||||
std::unique_ptr<IExternalLoadable> clone() const override;
|
||||
|
||||
std::exception_ptr getCreationException() const override { return creation_exception; }
|
||||
|
||||
private:
|
||||
std::string name;
|
||||
std::string model_path;
|
||||
std::string lib_path;
|
||||
ExternalLoadableLifetime lifetime;
|
||||
std::exception_ptr creation_exception;
|
||||
std::shared_ptr<CatBoostWrapperAPIProvider> api_provider;
|
||||
const CatBoostWrapperAPI * api;
|
||||
|
||||
std::unique_ptr<ICatBoostModel> model;
|
||||
|
||||
size_t float_features_count;
|
||||
size_t cat_features_count;
|
||||
|
||||
void init(const std::string & lib_path);
|
||||
};
|
||||
|
||||
}
|
@ -98,7 +98,7 @@ public:
|
||||
return true;
|
||||
}
|
||||
|
||||
DictionaryPtr clone() const override
|
||||
std::unique_ptr<IExternalLoadable> clone() const override
|
||||
{
|
||||
return std::make_unique<ComplexKeyCacheDictionary>(*this);
|
||||
}
|
||||
|
@ -46,7 +46,7 @@ public:
|
||||
|
||||
bool isCached() const override { return false; }
|
||||
|
||||
DictionaryPtr clone() const override { return std::make_unique<ComplexKeyHashedDictionary>(*this); }
|
||||
std::unique_ptr<IExternalLoadable> clone() const override { return std::make_unique<ComplexKeyHashedDictionary>(*this); }
|
||||
|
||||
const IDictionarySource * getSource() const override { return source_ptr.get(); }
|
||||
|
||||
|
@ -13,7 +13,7 @@ class Context;
|
||||
class DictionaryFactory : public ext::singleton<DictionaryFactory>
|
||||
{
|
||||
public:
|
||||
DictionaryPtr create(const std::string & name, Poco::Util::AbstractConfiguration & config,
|
||||
DictionaryPtr create(const std::string & name, const Poco::Util::AbstractConfiguration & config,
|
||||
const std::string & config_prefix, Context & context) const;
|
||||
};
|
||||
|
||||
|
@ -90,7 +90,7 @@ DictionarySourceFactory::DictionarySourceFactory()
|
||||
|
||||
|
||||
DictionarySourcePtr DictionarySourceFactory::create(
|
||||
const std::string & name, Poco::Util::AbstractConfiguration & config, const std::string & config_prefix,
|
||||
const std::string & name, const Poco::Util::AbstractConfiguration & config, const std::string & config_prefix,
|
||||
const DictionaryStructure & dict_struct, Context & context) const
|
||||
{
|
||||
Poco::Util::AbstractConfiguration::Keys keys;
|
||||
|
@ -25,7 +25,7 @@ public:
|
||||
DictionarySourceFactory();
|
||||
|
||||
DictionarySourcePtr create(
|
||||
const std::string & name, Poco::Util::AbstractConfiguration & config, const std::string & config_prefix,
|
||||
const std::string & name, const Poco::Util::AbstractConfiguration & config, const std::string & config_prefix,
|
||||
const DictionaryStructure & dict_struct, Context & context) const;
|
||||
};
|
||||
|
||||
|
@ -105,16 +105,6 @@ std::string toString(const AttributeUnderlyingType type)
|
||||
}
|
||||
|
||||
|
||||
DictionaryLifetime::DictionaryLifetime(const Poco::Util::AbstractConfiguration & config, const std::string & config_prefix)
|
||||
{
|
||||
const auto & lifetime_min_key = config_prefix + ".min";
|
||||
const auto has_min = config.has(lifetime_min_key);
|
||||
|
||||
this->min_sec = has_min ? config.getInt(lifetime_min_key) : config.getInt(config_prefix);
|
||||
this->max_sec = has_min ? config.getInt(config_prefix + ".max") : this->min_sec;
|
||||
}
|
||||
|
||||
|
||||
DictionarySpecialAttribute::DictionarySpecialAttribute(const Poco::Util::AbstractConfiguration & config, const std::string & config_prefix)
|
||||
: name{config.getString(config_prefix + ".name", "")},
|
||||
expression{config.getString(config_prefix + ".expression", "")}
|
||||
|
@ -4,6 +4,7 @@
|
||||
#include <IO/ReadBufferFromString.h>
|
||||
#include <IO/WriteBuffer.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <Interpreters/IExternalLoadable.h>
|
||||
#include <Poco/Util/AbstractConfiguration.h>
|
||||
#include <ext/range.h>
|
||||
#include <numeric>
|
||||
@ -42,13 +43,7 @@ std::string toString(const AttributeUnderlyingType type);
|
||||
|
||||
|
||||
/// Min and max lifetimes for a dictionary or it's entry
|
||||
struct DictionaryLifetime final
|
||||
{
|
||||
UInt64 min_sec;
|
||||
UInt64 max_sec;
|
||||
|
||||
DictionaryLifetime(const Poco::Util::AbstractConfiguration & config, const std::string & config_prefix);
|
||||
};
|
||||
using DictionaryLifetime = ExternalLoadableLifetime;
|
||||
|
||||
|
||||
/** Holds the description of a single dictionary attribute:
|
||||
|
@ -41,7 +41,7 @@ public:
|
||||
|
||||
bool isCached() const override { return false; }
|
||||
|
||||
DictionaryPtr clone() const override { return std::make_unique<FlatDictionary>(*this); }
|
||||
std::unique_ptr<IExternalLoadable> clone() const override { return std::make_unique<FlatDictionary>(*this); }
|
||||
|
||||
const IDictionarySource * getSource() const override { return source_ptr.get(); }
|
||||
|
||||
|
@ -40,7 +40,7 @@ public:
|
||||
|
||||
bool isCached() const override { return false; }
|
||||
|
||||
DictionaryPtr clone() const override { return std::make_unique<HashedDictionary>(*this); }
|
||||
std::unique_ptr<IExternalLoadable> clone() const override { return std::make_unique<HashedDictionary>(*this); }
|
||||
|
||||
const IDictionarySource * getSource() const override { return source_ptr.get(); }
|
||||
|
||||
|
@ -1,22 +1,21 @@
|
||||
#pragma once
|
||||
|
||||
#include <Core/Field.h>
|
||||
#include <Interpreters/IExternalLoadable.h>
|
||||
#include <common/StringRef.h>
|
||||
#include <Core/Names.h>
|
||||
#include <Poco/Util/XMLConfiguration.h>
|
||||
#include <Common/PODArray.h>
|
||||
#include <memory>
|
||||
#include <chrono>
|
||||
#include <Dictionaries/IDictionarySource.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
class IDictionarySource;
|
||||
|
||||
struct IDictionaryBase;
|
||||
using DictionaryPtr = std::unique_ptr<IDictionaryBase>;
|
||||
|
||||
struct DictionaryLifetime;
|
||||
struct DictionaryStructure;
|
||||
class ColumnString;
|
||||
|
||||
@ -24,14 +23,10 @@ class IBlockInputStream;
|
||||
using BlockInputStreamPtr = std::shared_ptr<IBlockInputStream>;
|
||||
|
||||
|
||||
struct IDictionaryBase : public std::enable_shared_from_this<IDictionaryBase>
|
||||
struct IDictionaryBase : public IExternalLoadable
|
||||
{
|
||||
using Key = UInt64;
|
||||
|
||||
virtual std::exception_ptr getCreationException() const = 0;
|
||||
|
||||
virtual std::string getName() const = 0;
|
||||
|
||||
virtual std::string getTypeName() const = 0;
|
||||
|
||||
virtual size_t getBytesAllocated() const = 0;
|
||||
@ -45,12 +40,9 @@ struct IDictionaryBase : public std::enable_shared_from_this<IDictionaryBase>
|
||||
virtual double getLoadFactor() const = 0;
|
||||
|
||||
virtual bool isCached() const = 0;
|
||||
virtual DictionaryPtr clone() const = 0;
|
||||
|
||||
virtual const IDictionarySource * getSource() const = 0;
|
||||
|
||||
virtual const DictionaryLifetime & getLifetime() const = 0;
|
||||
|
||||
virtual const DictionaryStructure & getStructure() const = 0;
|
||||
|
||||
virtual std::chrono::time_point<std::chrono::system_clock> getCreationTime() const = 0;
|
||||
@ -59,7 +51,23 @@ struct IDictionaryBase : public std::enable_shared_from_this<IDictionaryBase>
|
||||
|
||||
virtual BlockInputStreamPtr getBlockInputStream(const Names & column_names, size_t max_block_size) const = 0;
|
||||
|
||||
virtual ~IDictionaryBase() = default;
|
||||
bool supportUpdates() const override { return !isCached(); }
|
||||
|
||||
bool isModified() const override
|
||||
{
|
||||
auto source = getSource();
|
||||
return source && source->isModified();
|
||||
}
|
||||
|
||||
std::shared_ptr<IDictionaryBase> shared_from_this()
|
||||
{
|
||||
return std::static_pointer_cast<IDictionaryBase>(IExternalLoadable::shared_from_this());
|
||||
}
|
||||
|
||||
std::shared_ptr<const IDictionaryBase> shared_from_this() const
|
||||
{
|
||||
return std::static_pointer_cast<const IDictionaryBase>(IExternalLoadable::shared_from_this());
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
@ -41,7 +41,7 @@ public:
|
||||
|
||||
bool isCached() const override { return false; }
|
||||
|
||||
DictionaryPtr clone() const override { return std::make_unique<RangeHashedDictionary>(*this); }
|
||||
std::unique_ptr<IExternalLoadable> clone() const override { return std::make_unique<RangeHashedDictionary>(*this); }
|
||||
|
||||
const IDictionarySource * getSource() const override { return source_ptr.get(); }
|
||||
|
||||
|
@ -50,7 +50,7 @@ public:
|
||||
|
||||
bool isCached() const override { return false; }
|
||||
|
||||
DictionaryPtr clone() const override { return std::make_unique<TrieDictionary>(*this); }
|
||||
std::unique_ptr<IExternalLoadable> clone() const override { return std::make_unique<TrieDictionary>(*this); }
|
||||
|
||||
const IDictionarySource * getSource() const override { return source_ptr.get(); }
|
||||
|
||||
|
62
dbms/src/Functions/FunctionsExternalModels.cpp
Normal file
62
dbms/src/Functions/FunctionsExternalModels.cpp
Normal file
@ -0,0 +1,62 @@
|
||||
#include <Functions/FunctionsExternalModels.h>
|
||||
#include <Functions/FunctionHelpers.h>
|
||||
#include <Functions/FunctionFactory.h>
|
||||
|
||||
#include <Interpreters/Context.h>
|
||||
#include <Interpreters/ExternalModels.h>
|
||||
#include <DataTypes/DataTypeString.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <Columns/ColumnString.h>
|
||||
#include <ext/range.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
FunctionPtr FunctionModelEvaluate::create(const Context & context)
|
||||
{
|
||||
return std::make_shared<FunctionModelEvaluate>(context.getExternalModels());
|
||||
}
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
extern const int TOO_LESS_ARGUMENTS_FOR_FUNCTION;
|
||||
extern const int ILLEGAL_COLUMN;
|
||||
}
|
||||
|
||||
DataTypePtr FunctionModelEvaluate::getReturnTypeImpl(const DataTypes & arguments) const
|
||||
{
|
||||
if (arguments.size() < 2)
|
||||
throw Exception("Function " + getName() + " expects at least 2 arguments",
|
||||
ErrorCodes::TOO_LESS_ARGUMENTS_FOR_FUNCTION);
|
||||
|
||||
if (!checkDataType<DataTypeString>(arguments[0].get()))
|
||||
throw Exception("Illegal type " + arguments[0]->getName() + " of first argument of function " + getName()
|
||||
+ ", expected a string.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
return std::make_shared<DataTypeFloat64>();
|
||||
}
|
||||
|
||||
void FunctionModelEvaluate::executeImpl(Block & block, const ColumnNumbers & arguments, size_t result)
|
||||
{
|
||||
const auto name_col = checkAndGetColumnConst<ColumnString>(block.getByPosition(arguments[0]).column.get());
|
||||
if (!name_col)
|
||||
throw Exception("First argument of function " + getName() + " must be a constant string",
|
||||
ErrorCodes::ILLEGAL_COLUMN);
|
||||
|
||||
auto model = models.getModel(name_col->getValue<String>());
|
||||
|
||||
ConstColumnPlainPtrs columns;
|
||||
columns.reserve(arguments.size());
|
||||
for (auto i : ext::range(1, arguments.size()))
|
||||
columns.push_back(block.getByPosition(arguments[i]).column.get());
|
||||
|
||||
block.getByPosition(result).column = model->evaluate(columns);
|
||||
}
|
||||
|
||||
void registerFunctionsExternalModels(FunctionFactory & factory)
|
||||
{
|
||||
factory.registerFunction<FunctionModelEvaluate>();
|
||||
}
|
||||
|
||||
}
|
36
dbms/src/Functions/FunctionsExternalModels.h
Normal file
36
dbms/src/Functions/FunctionsExternalModels.h
Normal file
@ -0,0 +1,36 @@
|
||||
#pragma once
|
||||
#include <Functions/IFunction.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
class ExternalModels;
|
||||
|
||||
/// Evaluate external model.
|
||||
/// First argument - model name, the others - model arguments.
|
||||
/// * for CatBoost model - float features first, then categorical
|
||||
/// Result - Float64.
|
||||
class FunctionModelEvaluate final : public IFunction
|
||||
{
|
||||
public:
|
||||
static constexpr auto name = "modelEvaluate";
|
||||
|
||||
static FunctionPtr create(const Context & context);
|
||||
|
||||
explicit FunctionModelEvaluate(const ExternalModels & models) : models(models) {}
|
||||
|
||||
String getName() const override { return name; }
|
||||
|
||||
bool isVariadic() const override { return true; }
|
||||
|
||||
size_t getNumberOfArguments() const override { return 0; }
|
||||
|
||||
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override;
|
||||
|
||||
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) override;
|
||||
|
||||
private:
|
||||
const ExternalModels & models;
|
||||
};
|
||||
|
||||
}
|
@ -20,6 +20,7 @@ void registerFunctionsConversion(FunctionFactory &);
|
||||
void registerFunctionsDateTime(FunctionFactory &);
|
||||
void registerFunctionsEmbeddedDictionaries(FunctionFactory &);
|
||||
void registerFunctionsExternalDictionaries(FunctionFactory &);
|
||||
void registerFunctionsExternalModels(FunctionFactory &);
|
||||
void registerFunctionsFormatting(FunctionFactory &);
|
||||
void registerFunctionsHashing(FunctionFactory &);
|
||||
void registerFunctionsHigherOrder(FunctionFactory &);
|
||||
@ -54,6 +55,7 @@ void registerFunctions()
|
||||
registerFunctionsDateTime(factory);
|
||||
registerFunctionsEmbeddedDictionaries(factory);
|
||||
registerFunctionsExternalDictionaries(factory);
|
||||
registerFunctionsExternalModels(factory);
|
||||
registerFunctionsFormatting(factory);
|
||||
registerFunctionsHashing(factory);
|
||||
registerFunctionsHigherOrder(factory);
|
||||
|
@ -29,6 +29,7 @@
|
||||
#include <Interpreters/Quota.h>
|
||||
#include <Interpreters/EmbeddedDictionaries.h>
|
||||
#include <Interpreters/ExternalDictionaries.h>
|
||||
#include <Interpreters/ExternalModels.h>
|
||||
#include <Interpreters/ProcessList.h>
|
||||
#include <Interpreters/Cluster.h>
|
||||
#include <Interpreters/InterserverIOHandler.h>
|
||||
@ -94,6 +95,7 @@ struct ContextShared
|
||||
/// Separate mutex for access of dictionaries. Separate mutex to avoid locks when server doing request to itself.
|
||||
mutable std::mutex embedded_dictionaries_mutex;
|
||||
mutable std::mutex external_dictionaries_mutex;
|
||||
mutable std::mutex external_models_mutex;
|
||||
/// Separate mutex for re-initialization of zookeer session. This operation could take a long time and must not interfere with another operations.
|
||||
mutable std::mutex zookeeper_mutex;
|
||||
|
||||
@ -111,6 +113,7 @@ struct ContextShared
|
||||
FormatFactory format_factory; /// Formats.
|
||||
mutable std::shared_ptr<EmbeddedDictionaries> embedded_dictionaries; /// Metrica's dictionaeis. Have lazy initialization.
|
||||
mutable std::shared_ptr<ExternalDictionaries> external_dictionaries;
|
||||
mutable std::shared_ptr<ExternalModels> external_models;
|
||||
String default_profile_name; /// Default profile name used for default values.
|
||||
Users users; /// Known users.
|
||||
Quotas quotas; /// Known quotas for resource use.
|
||||
@ -1062,6 +1065,17 @@ ExternalDictionaries & Context::getExternalDictionaries()
|
||||
}
|
||||
|
||||
|
||||
const ExternalModels & Context::getExternalModels() const
|
||||
{
|
||||
return getExternalModelsImpl(false);
|
||||
}
|
||||
|
||||
ExternalModels & Context::getExternalModels()
|
||||
{
|
||||
return getExternalModelsImpl(false);
|
||||
}
|
||||
|
||||
|
||||
EmbeddedDictionaries & Context::getEmbeddedDictionariesImpl(const bool throw_on_error) const
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(shared->embedded_dictionaries_mutex);
|
||||
@ -1087,6 +1101,19 @@ ExternalDictionaries & Context::getExternalDictionariesImpl(const bool throw_on_
|
||||
return *shared->external_dictionaries;
|
||||
}
|
||||
|
||||
ExternalModels & Context::getExternalModelsImpl(bool throw_on_error) const
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(shared->external_models_mutex);
|
||||
|
||||
if (!shared->external_models)
|
||||
{
|
||||
if (!this->global_context)
|
||||
throw Exception("Logical error: there is no global context", ErrorCodes::LOGICAL_ERROR);
|
||||
shared->external_models = std::make_shared<ExternalModels>(*this->global_context, throw_on_error);
|
||||
}
|
||||
|
||||
return *shared->external_models;
|
||||
}
|
||||
|
||||
void Context::tryCreateEmbeddedDictionaries() const
|
||||
{
|
||||
@ -1100,6 +1127,12 @@ void Context::tryCreateExternalDictionaries() const
|
||||
}
|
||||
|
||||
|
||||
void Context::tryCreateExternalModels() const
|
||||
{
|
||||
static_cast<void>(getExternalModelsImpl(true));
|
||||
}
|
||||
|
||||
|
||||
void Context::setProgressCallback(ProgressCallback callback)
|
||||
{
|
||||
/// Callback is set to a session or to a query. In the session, only one query is processed at a time. Therefore, the lock is not needed.
|
||||
|
@ -35,6 +35,7 @@ struct ContextShared;
|
||||
class QuotaForIntervals;
|
||||
class EmbeddedDictionaries;
|
||||
class ExternalDictionaries;
|
||||
class ExternalModels;
|
||||
class InterserverIOHandler;
|
||||
class BackgroundProcessingPool;
|
||||
class ReshardingWorker;
|
||||
@ -209,10 +210,13 @@ public:
|
||||
|
||||
const EmbeddedDictionaries & getEmbeddedDictionaries() const;
|
||||
const ExternalDictionaries & getExternalDictionaries() const;
|
||||
const ExternalModels & getExternalModels() const;
|
||||
EmbeddedDictionaries & getEmbeddedDictionaries();
|
||||
ExternalDictionaries & getExternalDictionaries();
|
||||
ExternalModels & getExternalModels();
|
||||
void tryCreateEmbeddedDictionaries() const;
|
||||
void tryCreateExternalDictionaries() const;
|
||||
void tryCreateExternalModels() const;
|
||||
|
||||
/// I/O formats.
|
||||
BlockInputStreamPtr getInputFormat(const String & name, ReadBuffer & buf, const Block & sample, size_t max_block_size) const;
|
||||
@ -362,6 +366,7 @@ private:
|
||||
|
||||
EmbeddedDictionaries & getEmbeddedDictionariesImpl(bool throw_on_error) const;
|
||||
ExternalDictionaries & getExternalDictionariesImpl(bool throw_on_error) const;
|
||||
ExternalModels & getExternalModelsImpl(bool throw_on_error) const;
|
||||
|
||||
StoragePtr getTableImpl(const String & database_name, const String & table_name, Exception * exception) const;
|
||||
|
||||
|
@ -24,7 +24,7 @@ namespace ErrorCodes
|
||||
}
|
||||
|
||||
|
||||
DictionaryPtr DictionaryFactory::create(const std::string & name, Poco::Util::AbstractConfiguration & config,
|
||||
DictionaryPtr DictionaryFactory::create(const std::string & name, const Poco::Util::AbstractConfiguration & config,
|
||||
const std::string & config_prefix, Context & context) const
|
||||
{
|
||||
Poco::Util::AbstractConfiguration::Keys keys;
|
||||
|
@ -1,427 +1,46 @@
|
||||
#include <Interpreters/ExternalDictionaries.h>
|
||||
#include <Interpreters/Context.h>
|
||||
#include <Dictionaries/DictionaryFactory.h>
|
||||
#include <Dictionaries/DictionaryStructure.h>
|
||||
#include <Dictionaries/IDictionarySource.h>
|
||||
#include <Common/StringUtils.h>
|
||||
#include <Common/MemoryTracker.h>
|
||||
#include <Common/getMultipleKeysFromConfig.h>
|
||||
#include <ext/scope_guard.h>
|
||||
#include <Poco/Util/Application.h>
|
||||
#include <Poco/Glob.h>
|
||||
#include <Poco/File.h>
|
||||
|
||||
|
||||
namespace
|
||||
{
|
||||
const auto check_period_sec = 5;
|
||||
const auto backoff_initial_sec = 5;
|
||||
/// 10 minutes
|
||||
const auto backoff_max_sec = 10 * 60;
|
||||
}
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int LOGICAL_ERROR;
|
||||
extern const int BAD_ARGUMENTS;
|
||||
}
|
||||
|
||||
|
||||
void ExternalDictionaries::reloadPeriodically()
|
||||
{
|
||||
setThreadName("ExterDictReload");
|
||||
|
||||
while (true)
|
||||
{
|
||||
if (destroy.tryWait(check_period_sec * 1000))
|
||||
return;
|
||||
|
||||
reloadAndUpdate();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
ExternalDictionaries::ExternalDictionaries(Context & context, const bool throw_on_error)
|
||||
: context(context), log(&Logger::get("ExternalDictionaries"))
|
||||
{
|
||||
{
|
||||
/** During synchronous loading of external dictionaries at moment of query execution,
|
||||
* we should not use per query memory limit.
|
||||
*/
|
||||
TemporarilyDisableMemoryTracker temporarily_disable_memory_tracker;
|
||||
|
||||
reloadAndUpdate(throw_on_error);
|
||||
}
|
||||
|
||||
reloading_thread = std::thread{&ExternalDictionaries::reloadPeriodically, this};
|
||||
}
|
||||
|
||||
|
||||
ExternalDictionaries::~ExternalDictionaries()
|
||||
{
|
||||
destroy.set();
|
||||
reloading_thread.join();
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
std::set<std::string> getDictionariesConfigPaths(const Poco::Util::AbstractConfiguration & config)
|
||||
{
|
||||
std::set<std::string> files;
|
||||
auto patterns = getMultipleValuesFromConfig(config, "", "dictionaries_config");
|
||||
for (auto & pattern : patterns)
|
||||
const ExternalLoaderUpdateSettings externalDictionariesUpdateSettings;
|
||||
|
||||
const ExternalLoaderConfigSettings & getExternalDictionariesConfigSettings()
|
||||
{
|
||||
if (pattern.empty())
|
||||
continue;
|
||||
static ExternalLoaderConfigSettings settings;
|
||||
static std::once_flag flag;
|
||||
|
||||
if (pattern[0] != '/')
|
||||
{
|
||||
const auto app_config_path = config.getString("config-file", "config.xml");
|
||||
const auto config_dir = Poco::Path{app_config_path}.parent().toString();
|
||||
const auto absolute_path = config_dir + pattern;
|
||||
Poco::Glob::glob(absolute_path, files, 0);
|
||||
if (!files.empty())
|
||||
continue;
|
||||
}
|
||||
std::call_once(flag, [] {
|
||||
settings.external_config = "dictionary";
|
||||
settings.external_name = "name";
|
||||
|
||||
Poco::Glob::glob(pattern, files, 0);
|
||||
}
|
||||
settings.path_setting_name = "dictionaries_config";
|
||||
});
|
||||
|
||||
return files;
|
||||
}
|
||||
}
|
||||
|
||||
void ExternalDictionaries::reloadAndUpdate(bool throw_on_error)
|
||||
{
|
||||
reloadFromConfigFiles(throw_on_error);
|
||||
|
||||
/// list of recreated dictionaries to perform delayed removal from unordered_map
|
||||
std::list<std::string> recreated_failed_dictionaries;
|
||||
|
||||
std::unique_lock<std::mutex> all_lock(all_mutex);
|
||||
|
||||
/// retry loading failed dictionaries
|
||||
for (auto & failed_dictionary : failed_dictionaries)
|
||||
{
|
||||
if (std::chrono::system_clock::now() < failed_dictionary.second.next_attempt_time)
|
||||
continue;
|
||||
|
||||
const auto & name = failed_dictionary.first;
|
||||
|
||||
try
|
||||
{
|
||||
auto dict_ptr = failed_dictionary.second.dict->clone();
|
||||
if (const auto exception_ptr = dict_ptr->getCreationException())
|
||||
{
|
||||
/// recalculate next attempt time
|
||||
std::uniform_int_distribution<UInt64> distribution(
|
||||
0, static_cast<UInt64>(std::exp2(failed_dictionary.second.error_count)));
|
||||
|
||||
failed_dictionary.second.next_attempt_time = std::chrono::system_clock::now() +
|
||||
std::chrono::seconds{
|
||||
std::min<UInt64>(backoff_max_sec, backoff_initial_sec + distribution(rnd_engine))};
|
||||
|
||||
++failed_dictionary.second.error_count;
|
||||
|
||||
std::rethrow_exception(exception_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
const std::lock_guard<std::mutex> lock{dictionaries_mutex};
|
||||
|
||||
const auto & lifetime = dict_ptr->getLifetime();
|
||||
std::uniform_int_distribution<UInt64> distribution{lifetime.min_sec, lifetime.max_sec};
|
||||
update_times[name] = std::chrono::system_clock::now() + std::chrono::seconds{distribution(rnd_engine)};
|
||||
|
||||
const auto dict_it = dictionaries.find(name);
|
||||
if (dict_it->second.dict)
|
||||
dict_it->second.dict->set(dict_ptr.release());
|
||||
else
|
||||
dict_it->second.dict = std::make_shared<MultiVersion<IDictionaryBase>>(dict_ptr.release());
|
||||
|
||||
/// clear stored exception on success
|
||||
dict_it->second.exception = std::exception_ptr{};
|
||||
|
||||
recreated_failed_dictionaries.push_back(name);
|
||||
}
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
tryLogCurrentException(log, "Failed reloading '" + name + "' dictionary");
|
||||
|
||||
if (throw_on_error)
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
/// do not undertake further attempts to recreate these dictionaries
|
||||
for (const auto & name : recreated_failed_dictionaries)
|
||||
failed_dictionaries.erase(name);
|
||||
|
||||
/// periodic update
|
||||
for (auto & dictionary : dictionaries)
|
||||
{
|
||||
const auto & name = dictionary.first;
|
||||
|
||||
try
|
||||
{
|
||||
/// If the dictionary failed to load or even failed to initialize from the config.
|
||||
if (!dictionary.second.dict)
|
||||
continue;
|
||||
|
||||
auto current = dictionary.second.dict->get();
|
||||
const auto & lifetime = current->getLifetime();
|
||||
|
||||
/// do not update dictionaries with zero as lifetime
|
||||
if (lifetime.min_sec == 0 || lifetime.max_sec == 0)
|
||||
continue;
|
||||
|
||||
/// update only non-cached dictionaries
|
||||
if (!current->isCached())
|
||||
{
|
||||
auto & update_time = update_times[current->getName()];
|
||||
|
||||
/// check that timeout has passed
|
||||
if (std::chrono::system_clock::now() < update_time)
|
||||
continue;
|
||||
|
||||
SCOPE_EXIT({
|
||||
/// calculate next update time
|
||||
std::uniform_int_distribution<UInt64> distribution{lifetime.min_sec, lifetime.max_sec};
|
||||
update_time = std::chrono::system_clock::now() + std::chrono::seconds{distribution(rnd_engine)};
|
||||
});
|
||||
|
||||
/// check source modified
|
||||
if (current->getSource()->isModified())
|
||||
{
|
||||
/// create new version of dictionary
|
||||
auto new_version = current->clone();
|
||||
|
||||
if (const auto exception_ptr = new_version->getCreationException())
|
||||
std::rethrow_exception(exception_ptr);
|
||||
|
||||
dictionary.second.dict->set(new_version.release());
|
||||
}
|
||||
}
|
||||
|
||||
/// erase stored exception on success
|
||||
dictionary.second.exception = std::exception_ptr{};
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
dictionary.second.exception = std::current_exception();
|
||||
|
||||
tryLogCurrentException(log, "Cannot update external dictionary '" + name + "', leaving old version");
|
||||
|
||||
if (throw_on_error)
|
||||
throw;
|
||||
}
|
||||
return settings;
|
||||
}
|
||||
}
|
||||
|
||||
void ExternalDictionaries::reloadFromConfigFiles(const bool throw_on_error, const bool force_reload, const std::string & only_dictionary)
|
||||
|
||||
ExternalDictionaries::ExternalDictionaries(Context & context, bool throw_on_error)
|
||||
: ExternalLoader(context.getConfigRef(),
|
||||
externalDictionariesUpdateSettings,
|
||||
getExternalDictionariesConfigSettings(),
|
||||
&Logger::get("ExternalDictionaries"),
|
||||
"external dictionary"),
|
||||
context(context)
|
||||
{
|
||||
const auto config_paths = getDictionariesConfigPaths(context.getConfigRef());
|
||||
|
||||
for (const auto & config_path : config_paths)
|
||||
{
|
||||
try
|
||||
{
|
||||
reloadFromConfigFile(config_path, throw_on_error, force_reload, only_dictionary);
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
tryLogCurrentException(log, "reloadFromConfigFile has thrown while reading from " + config_path);
|
||||
|
||||
if (throw_on_error)
|
||||
throw;
|
||||
}
|
||||
}
|
||||
init(throw_on_error);
|
||||
}
|
||||
|
||||
void ExternalDictionaries::reloadFromConfigFile(const std::string & config_path, const bool throw_on_error, const bool force_reload,
|
||||
const std::string & only_dictionary)
|
||||
std::unique_ptr<IExternalLoadable> ExternalDictionaries::create(
|
||||
const std::string & name, const Configuration & config, const std::string & config_prefix)
|
||||
{
|
||||
const Poco::File config_file{config_path};
|
||||
|
||||
if (config_path.empty() || !config_file.exists())
|
||||
{
|
||||
LOG_WARNING(log, "config file '" + config_path + "' does not exist");
|
||||
}
|
||||
else
|
||||
{
|
||||
std::unique_lock<std::mutex> all_lock(all_mutex);
|
||||
|
||||
auto modification_time_it = last_modification_times.find(config_path);
|
||||
if (modification_time_it == std::end(last_modification_times))
|
||||
modification_time_it = last_modification_times.emplace(config_path, Poco::Timestamp{0}).first;
|
||||
auto & config_last_modified = modification_time_it->second;
|
||||
|
||||
const auto last_modified = config_file.getLastModified();
|
||||
if (force_reload || last_modified > config_last_modified)
|
||||
{
|
||||
Poco::AutoPtr<Poco::Util::XMLConfiguration> config = new Poco::Util::XMLConfiguration(config_path);
|
||||
|
||||
/// Definitions of dictionaries may have changed, recreate all of them
|
||||
|
||||
/// If we need update only one dictionary, don't update modification time: might be other dictionaries in the config file
|
||||
if (only_dictionary.empty())
|
||||
config_last_modified = last_modified;
|
||||
|
||||
/// get all dictionaries' definitions
|
||||
Poco::Util::AbstractConfiguration::Keys keys;
|
||||
config->keys(keys);
|
||||
|
||||
/// for each dictionary defined in xml config
|
||||
for (const auto & key : keys)
|
||||
{
|
||||
std::string name;
|
||||
|
||||
if (!startsWith(key, "dictionary"))
|
||||
{
|
||||
if (!startsWith(key.data(), "comment"))
|
||||
LOG_WARNING(log,
|
||||
config_path << ": unknown node in dictionaries file: '" << key + "', 'dictionary'");
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
name = config->getString(key + ".name");
|
||||
if (name.empty())
|
||||
{
|
||||
LOG_WARNING(log, config_path << ": dictionary name cannot be empty");
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!only_dictionary.empty() && name != only_dictionary)
|
||||
continue;
|
||||
|
||||
decltype(dictionaries.begin()) dict_it;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock{dictionaries_mutex};
|
||||
dict_it = dictionaries.find(name);
|
||||
}
|
||||
|
||||
if (dict_it != std::end(dictionaries) && dict_it->second.origin != config_path)
|
||||
throw std::runtime_error{"Overriding dictionary from file " + dict_it->second.origin};
|
||||
|
||||
auto dict_ptr = DictionaryFactory::instance().create(name, *config, key, context);
|
||||
|
||||
/// If the dictionary could not be loaded.
|
||||
if (const auto exception_ptr = dict_ptr->getCreationException())
|
||||
{
|
||||
const auto failed_dict_it = failed_dictionaries.find(name);
|
||||
if (failed_dict_it != std::end(failed_dictionaries))
|
||||
{
|
||||
failed_dict_it->second = FailedDictionaryInfo{
|
||||
std::move(dict_ptr),
|
||||
std::chrono::system_clock::now() + std::chrono::seconds{backoff_initial_sec}};
|
||||
}
|
||||
else
|
||||
failed_dictionaries.emplace(name, FailedDictionaryInfo{
|
||||
std::move(dict_ptr),
|
||||
std::chrono::system_clock::now() + std::chrono::seconds{backoff_initial_sec}});
|
||||
|
||||
std::rethrow_exception(exception_ptr);
|
||||
}
|
||||
else if (!dict_ptr->isCached())
|
||||
{
|
||||
const auto & lifetime = dict_ptr->getLifetime();
|
||||
if (lifetime.min_sec != 0 && lifetime.max_sec != 0)
|
||||
{
|
||||
std::uniform_int_distribution<UInt64> distribution{
|
||||
lifetime.min_sec,
|
||||
lifetime.max_sec
|
||||
};
|
||||
update_times[name] = std::chrono::system_clock::now() +
|
||||
std::chrono::seconds{distribution(rnd_engine)};
|
||||
}
|
||||
}
|
||||
|
||||
const std::lock_guard<std::mutex> lock{dictionaries_mutex};
|
||||
|
||||
/// add new dictionary or update an existing version
|
||||
if (dict_it == std::end(dictionaries))
|
||||
dictionaries.emplace(name, DictionaryInfo{
|
||||
std::make_shared<MultiVersion<IDictionaryBase>>(dict_ptr.release()),
|
||||
config_path
|
||||
});
|
||||
else
|
||||
{
|
||||
if (dict_it->second.dict)
|
||||
dict_it->second.dict->set(dict_ptr.release());
|
||||
else
|
||||
dict_it->second.dict = std::make_shared<MultiVersion<IDictionaryBase>>(dict_ptr.release());
|
||||
|
||||
/// erase stored exception on success
|
||||
dict_it->second.exception = std::exception_ptr{};
|
||||
failed_dictionaries.erase(name);
|
||||
}
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
if (!name.empty())
|
||||
{
|
||||
/// If the dictionary could not load data or even failed to initialize from the config.
|
||||
/// - all the same we insert information into the `dictionaries`, with the zero pointer `dict`.
|
||||
|
||||
const std::lock_guard<std::mutex> lock{dictionaries_mutex};
|
||||
|
||||
const auto exception_ptr = std::current_exception();
|
||||
const auto dict_it = dictionaries.find(name);
|
||||
if (dict_it == std::end(dictionaries))
|
||||
dictionaries.emplace(name, DictionaryInfo{nullptr, config_path, exception_ptr});
|
||||
else
|
||||
dict_it->second.exception = exception_ptr;
|
||||
}
|
||||
|
||||
tryLogCurrentException(log, "Cannot create external dictionary '" + name + "' from config path " + config_path);
|
||||
|
||||
/// propagate exception
|
||||
if (throw_on_error)
|
||||
throw;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ExternalDictionaries::reload()
|
||||
{
|
||||
reloadFromConfigFiles(true, true);
|
||||
}
|
||||
|
||||
void ExternalDictionaries::reloadDictionary(const std::string & name)
|
||||
{
|
||||
reloadFromConfigFiles(true, true, name);
|
||||
|
||||
/// Check that specified dict was loaded
|
||||
const std::lock_guard<std::mutex> lock{dictionaries_mutex};
|
||||
if (!dictionaries.count(name))
|
||||
throw Exception("Dictionary " + name + " wasn't loaded during the reload process", ErrorCodes::BAD_ARGUMENTS);
|
||||
}
|
||||
|
||||
MultiVersion<IDictionaryBase>::Version ExternalDictionaries::getDictionary(const std::string & name) const
|
||||
{
|
||||
const std::lock_guard<std::mutex> lock{dictionaries_mutex};
|
||||
|
||||
const auto it = dictionaries.find(name);
|
||||
if (it == std::end(dictionaries))
|
||||
throw Exception{
|
||||
"No such dictionary: " + name,
|
||||
ErrorCodes::BAD_ARGUMENTS
|
||||
};
|
||||
|
||||
if (!it->second.dict)
|
||||
it->second.exception ? std::rethrow_exception(it->second.exception) :
|
||||
throw Exception{"No dictionary", ErrorCodes::LOGICAL_ERROR};
|
||||
|
||||
return it->second.dict->get();
|
||||
return DictionaryFactory::instance().create(name, config, config_prefix, context);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -1,19 +1,9 @@
|
||||
#pragma once
|
||||
|
||||
#include <Dictionaries/IDictionary.h>
|
||||
#include <Common/Exception.h>
|
||||
#include <Common/setThreadName.h>
|
||||
#include <Common/randomSeed.h>
|
||||
#include <common/MultiVersion.h>
|
||||
#include <Interpreters/ExternalLoader.h>
|
||||
#include <common/logger_useful.h>
|
||||
#include <Poco/Event.h>
|
||||
#include <unistd.h>
|
||||
#include <time.h>
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
#include <unordered_map>
|
||||
#include <chrono>
|
||||
#include <pcg_random.hpp>
|
||||
#include <memory>
|
||||
|
||||
|
||||
namespace DB
|
||||
@ -21,93 +11,36 @@ namespace DB
|
||||
|
||||
class Context;
|
||||
|
||||
/** Manages user-defined dictionaries.
|
||||
* Monitors configuration file and automatically reloads dictionaries in a separate thread.
|
||||
* The monitoring thread wakes up every @check_period_sec seconds and checks
|
||||
* modification time of dictionaries' configuration file. If said time is greater than
|
||||
* @config_last_modified, the dictionaries are created from scratch using configuration file,
|
||||
* possibly overriding currently existing dictionaries with the same name (previous versions of
|
||||
* overridden dictionaries will live as long as there are any users retaining them).
|
||||
*
|
||||
* Apart from checking configuration file for modifications, each non-cached dictionary
|
||||
* has a lifetime of its own and may be updated if it's source reports that it has been
|
||||
* modified. The time of next update is calculated by choosing uniformly a random number
|
||||
* distributed between lifetime.min_sec and lifetime.max_sec.
|
||||
* If either of lifetime.min_sec and lifetime.max_sec is zero, such dictionary is never updated.
|
||||
*/
|
||||
class ExternalDictionaries
|
||||
/// Manages user-defined dictionaries.
|
||||
class ExternalDictionaries : public ExternalLoader
|
||||
{
|
||||
private:
|
||||
public:
|
||||
using DictPtr = std::shared_ptr<IDictionaryBase>;
|
||||
|
||||
/// Dictionaries will be loaded immediately and then will be updated in separate thread, each 'reload_period' seconds.
|
||||
ExternalDictionaries(Context & context, bool throw_on_error);
|
||||
|
||||
/// Forcibly reloads specified dictionary.
|
||||
void reloadDictionary(const std::string & name) { reload(name); }
|
||||
|
||||
DictPtr getDictionary(const std::string & name) const
|
||||
{
|
||||
return std::static_pointer_cast<IDictionaryBase>(getLoadable(name));
|
||||
}
|
||||
|
||||
protected:
|
||||
|
||||
std::unique_ptr<IExternalLoadable> create(const std::string & name, const Configuration & config,
|
||||
const std::string & config_prefix) override;
|
||||
|
||||
using ExternalLoader::getObjectsMap;
|
||||
|
||||
friend class StorageSystemDictionaries;
|
||||
friend class DatabaseDictionary;
|
||||
|
||||
/// Protects only dictionaries map.
|
||||
mutable std::mutex dictionaries_mutex;
|
||||
|
||||
/// Protects all data, currently used to avoid races between updating thread and SYSTEM queries
|
||||
mutable std::mutex all_mutex;
|
||||
|
||||
using DictionaryPtr = std::shared_ptr<MultiVersion<IDictionaryBase>>;
|
||||
struct DictionaryInfo final
|
||||
{
|
||||
DictionaryPtr dict;
|
||||
std::string origin;
|
||||
std::exception_ptr exception;
|
||||
};
|
||||
|
||||
struct FailedDictionaryInfo final
|
||||
{
|
||||
std::unique_ptr<IDictionaryBase> dict;
|
||||
std::chrono::system_clock::time_point next_attempt_time;
|
||||
UInt64 error_count;
|
||||
};
|
||||
|
||||
/** name -> dictionary.
|
||||
*/
|
||||
std::unordered_map<std::string, DictionaryInfo> dictionaries;
|
||||
|
||||
/** Here are dictionaries, that has been never loaded successfully.
|
||||
* They are also in 'dictionaries', but with nullptr as 'dict'.
|
||||
*/
|
||||
std::unordered_map<std::string, FailedDictionaryInfo> failed_dictionaries;
|
||||
|
||||
/** Both for dictionaries and failed_dictionaries.
|
||||
*/
|
||||
std::unordered_map<std::string, std::chrono::system_clock::time_point> update_times;
|
||||
|
||||
pcg64 rnd_engine{randomSeed()};
|
||||
private:
|
||||
|
||||
Context & context;
|
||||
|
||||
std::thread reloading_thread;
|
||||
Poco::Event destroy;
|
||||
|
||||
Logger * log;
|
||||
|
||||
std::unordered_map<std::string, Poco::Timestamp> last_modification_times;
|
||||
|
||||
/// Check dictionaries definitions in config files and reload or/and add new ones if the definition is changed
|
||||
void reloadFromConfigFiles(const bool throw_on_error, const bool force_reload = false, const std::string & only_dictionary = "");
|
||||
void reloadFromConfigFile(const std::string & config_path, const bool throw_on_error, const bool force_reload,
|
||||
const std::string & only_dictionary);
|
||||
|
||||
/// Check config files and update expired dictionaries
|
||||
void reloadAndUpdate(bool throw_on_error = false);
|
||||
|
||||
void reloadPeriodically();
|
||||
|
||||
public:
|
||||
/// Dictionaries will be loaded immediately and then will be updated in separate thread, each 'reload_period' seconds.
|
||||
ExternalDictionaries(Context & context, const bool throw_on_error);
|
||||
~ExternalDictionaries();
|
||||
|
||||
/// Forcibly reloads all dictionaries.
|
||||
void reload();
|
||||
|
||||
/// Forcibly reloads specified dictionary.
|
||||
void reloadDictionary(const std::string & name);
|
||||
|
||||
MultiVersion<IDictionaryBase>::Version getDictionary(const std::string & name) const;
|
||||
};
|
||||
|
||||
}
|
||||
|
433
dbms/src/Interpreters/ExternalLoader.cpp
Normal file
433
dbms/src/Interpreters/ExternalLoader.cpp
Normal file
@ -0,0 +1,433 @@
|
||||
#include <Interpreters/ExternalLoader.h>
|
||||
#include <Common/StringUtils.h>
|
||||
#include <Common/MemoryTracker.h>
|
||||
#include <Common/Exception.h>
|
||||
#include <Common/getMultipleKeysFromConfig.h>
|
||||
#include <Common/setThreadName.h>
|
||||
#include <ext/scope_guard.h>
|
||||
#include <Poco/Util/Application.h>
|
||||
#include <Poco/Glob.h>
|
||||
#include <Poco/File.h>
|
||||
#include <cmath>
|
||||
#include <Poco/Util/XMLConfiguration.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int LOGICAL_ERROR;
|
||||
extern const int BAD_ARGUMENTS;
|
||||
}
|
||||
|
||||
|
||||
ExternalLoadableLifetime::ExternalLoadableLifetime(const Poco::Util::AbstractConfiguration & config,
|
||||
const std::string & config_prefix)
|
||||
{
|
||||
const auto & lifetime_min_key = config_prefix + ".min";
|
||||
const auto has_min = config.has(lifetime_min_key);
|
||||
|
||||
min_sec = has_min ? config.getUInt64(lifetime_min_key) : config.getUInt64(config_prefix);
|
||||
max_sec = has_min ? config.getUInt64(config_prefix + ".max") : min_sec;
|
||||
}
|
||||
|
||||
void ExternalLoader::reloadPeriodically()
|
||||
{
|
||||
setThreadName("ExterLdrReload");
|
||||
|
||||
while (true)
|
||||
{
|
||||
if (destroy.tryWait(update_settings.check_period_sec * 1000))
|
||||
return;
|
||||
|
||||
reloadAndUpdate();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
ExternalLoader::ExternalLoader(const Poco::Util::AbstractConfiguration & config,
|
||||
const ExternalLoaderUpdateSettings & update_settings,
|
||||
const ExternalLoaderConfigSettings & config_settings,
|
||||
Logger * log, const std::string & loadable_object_name)
|
||||
: config(config), update_settings(update_settings), config_settings(config_settings),
|
||||
log(log), object_name(loadable_object_name)
|
||||
{
|
||||
}
|
||||
|
||||
void ExternalLoader::init(bool throw_on_error)
|
||||
{
|
||||
if (is_initialized)
|
||||
return;
|
||||
|
||||
is_initialized = true;
|
||||
|
||||
{
|
||||
/// During synchronous loading of external dictionaries at moment of query execution,
|
||||
/// we should not use per query memory limit.
|
||||
TemporarilyDisableMemoryTracker temporarily_disable_memory_tracker;
|
||||
|
||||
reloadAndUpdate(throw_on_error);
|
||||
}
|
||||
|
||||
reloading_thread = std::thread{&ExternalLoader::reloadPeriodically, this};
|
||||
}
|
||||
|
||||
|
||||
ExternalLoader::~ExternalLoader()
|
||||
{
|
||||
destroy.set();
|
||||
reloading_thread.join();
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
std::set<std::string> getConfigPaths(const Poco::Util::AbstractConfiguration & config,
|
||||
const std::string & external_config_paths_setting)
|
||||
{
|
||||
std::set<std::string> files;
|
||||
auto patterns = getMultipleValuesFromConfig(config, "", external_config_paths_setting);
|
||||
for (auto & pattern : patterns)
|
||||
{
|
||||
if (pattern.empty())
|
||||
continue;
|
||||
|
||||
if (pattern[0] != '/')
|
||||
{
|
||||
const auto app_config_path = config.getString("config-file", "config.xml");
|
||||
const auto config_dir = Poco::Path{app_config_path}.parent().toString();
|
||||
const auto absolute_path = config_dir + pattern;
|
||||
Poco::Glob::glob(absolute_path, files, 0);
|
||||
if (!files.empty())
|
||||
continue;
|
||||
}
|
||||
|
||||
Poco::Glob::glob(pattern, files, 0);
|
||||
}
|
||||
|
||||
return files;
|
||||
}
|
||||
}
|
||||
|
||||
void ExternalLoader::reloadAndUpdate(bool throw_on_error)
|
||||
{
|
||||
reloadFromConfigFiles(throw_on_error);
|
||||
|
||||
/// list of recreated loadable objects to perform delayed removal from unordered_map
|
||||
std::list<std::string> recreated_failed_loadable_objects;
|
||||
|
||||
std::unique_lock<std::mutex> all_lock(all_mutex);
|
||||
|
||||
/// retry loading failed loadable objects
|
||||
for (auto & failed_loadable_object : failed_loadable_objects)
|
||||
{
|
||||
if (std::chrono::system_clock::now() < failed_loadable_object.second.next_attempt_time)
|
||||
continue;
|
||||
|
||||
const auto & name = failed_loadable_object.first;
|
||||
|
||||
try
|
||||
{
|
||||
auto loadable_ptr = failed_loadable_object.second.loadable->clone();
|
||||
if (const auto exception_ptr = loadable_ptr->getCreationException())
|
||||
{
|
||||
/// recalculate next attempt time
|
||||
std::uniform_int_distribution<UInt64> distribution(
|
||||
0, static_cast<UInt64>(std::exp2(failed_loadable_object.second.error_count)));
|
||||
|
||||
std::chrono::seconds delay(std::min<UInt64>(
|
||||
update_settings.backoff_max_sec,
|
||||
update_settings.backoff_initial_sec + distribution(rnd_engine)));
|
||||
failed_loadable_object.second.next_attempt_time = std::chrono::system_clock::now() + delay;
|
||||
|
||||
++failed_loadable_object.second.error_count;
|
||||
|
||||
std::rethrow_exception(exception_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
const std::lock_guard<std::mutex> lock{map_mutex};
|
||||
|
||||
const auto & lifetime = loadable_ptr->getLifetime();
|
||||
std::uniform_int_distribution<UInt64> distribution{lifetime.min_sec, lifetime.max_sec};
|
||||
update_times[name] = std::chrono::system_clock::now() + std::chrono::seconds{distribution(rnd_engine)};
|
||||
|
||||
const auto dict_it = loadable_objects.find(name);
|
||||
|
||||
dict_it->second.loadable.reset();
|
||||
dict_it->second.loadable = std::move(loadable_ptr);
|
||||
|
||||
/// clear stored exception on success
|
||||
dict_it->second.exception = std::exception_ptr{};
|
||||
|
||||
recreated_failed_loadable_objects.push_back(name);
|
||||
}
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
tryLogCurrentException(log, "Failed reloading '" + name + "' " + object_name);
|
||||
|
||||
if (throw_on_error)
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
/// do not undertake further attempts to recreate these loadable objects
|
||||
for (const auto & name : recreated_failed_loadable_objects)
|
||||
failed_loadable_objects.erase(name);
|
||||
|
||||
/// periodic update
|
||||
for (auto & loadable_object : loadable_objects)
|
||||
{
|
||||
const auto & name = loadable_object.first;
|
||||
|
||||
try
|
||||
{
|
||||
/// If the loadable objects failed to load or even failed to initialize from the config.
|
||||
if (!loadable_object.second.loadable)
|
||||
continue;
|
||||
|
||||
auto current = loadable_object.second.loadable;
|
||||
const auto & lifetime = current->getLifetime();
|
||||
|
||||
/// do not update loadable objects with zero as lifetime
|
||||
if (lifetime.min_sec == 0 || lifetime.max_sec == 0)
|
||||
continue;
|
||||
|
||||
if (current->supportUpdates())
|
||||
{
|
||||
auto & update_time = update_times[current->getName()];
|
||||
|
||||
/// check that timeout has passed
|
||||
if (std::chrono::system_clock::now() < update_time)
|
||||
continue;
|
||||
|
||||
SCOPE_EXIT({
|
||||
/// calculate next update time
|
||||
std::uniform_int_distribution<UInt64> distribution{lifetime.min_sec, lifetime.max_sec};
|
||||
update_time = std::chrono::system_clock::now() + std::chrono::seconds{distribution(rnd_engine)};
|
||||
});
|
||||
|
||||
/// check source modified
|
||||
if (current->isModified())
|
||||
{
|
||||
/// create new version of loadable object
|
||||
auto new_version = current->clone();
|
||||
|
||||
if (const auto exception_ptr = new_version->getCreationException())
|
||||
std::rethrow_exception(exception_ptr);
|
||||
|
||||
loadable_object.second.loadable.reset();
|
||||
loadable_object.second.loadable = std::move(new_version);
|
||||
}
|
||||
}
|
||||
|
||||
/// erase stored exception on success
|
||||
loadable_object.second.exception = std::exception_ptr{};
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
loadable_object.second.exception = std::current_exception();
|
||||
|
||||
tryLogCurrentException(log, "Cannot update " + object_name + " '" + name + "', leaving old version");
|
||||
|
||||
if (throw_on_error)
|
||||
throw;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ExternalLoader::reloadFromConfigFiles(const bool throw_on_error, const bool force_reload, const std::string & only_dictionary)
|
||||
{
|
||||
const auto config_paths = getConfigPaths(config, config_settings.path_setting_name);
|
||||
|
||||
for (const auto & config_path : config_paths)
|
||||
{
|
||||
try
|
||||
{
|
||||
reloadFromConfigFile(config_path, throw_on_error, force_reload, only_dictionary);
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
tryLogCurrentException(log, "reloadFromConfigFile has thrown while reading from " + config_path);
|
||||
|
||||
if (throw_on_error)
|
||||
throw;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ExternalLoader::reloadFromConfigFile(const std::string & config_path, const bool throw_on_error,
|
||||
const bool force_reload, const std::string & loadable_name)
|
||||
{
|
||||
const Poco::File config_file{config_path};
|
||||
|
||||
if (config_path.empty() || !config_file.exists())
|
||||
{
|
||||
LOG_WARNING(log, "config file '" + config_path + "' does not exist");
|
||||
}
|
||||
else
|
||||
{
|
||||
std::unique_lock<std::mutex> all_lock(all_mutex);
|
||||
|
||||
auto modification_time_it = last_modification_times.find(config_path);
|
||||
if (modification_time_it == std::end(last_modification_times))
|
||||
modification_time_it = last_modification_times.emplace(config_path, Poco::Timestamp{0}).first;
|
||||
auto & config_last_modified = modification_time_it->second;
|
||||
|
||||
const auto last_modified = config_file.getLastModified();
|
||||
if (force_reload || last_modified > config_last_modified)
|
||||
{
|
||||
Poco::AutoPtr<Poco::Util::XMLConfiguration> config = new Poco::Util::XMLConfiguration(config_path);
|
||||
|
||||
/// Definitions of loadable objects may have changed, recreate all of them
|
||||
|
||||
/// If we need update only one object, don't update modification time: might be other objects in the config file
|
||||
if (loadable_name.empty())
|
||||
config_last_modified = last_modified;
|
||||
|
||||
/// get all objects' definitions
|
||||
Poco::Util::AbstractConfiguration::Keys keys;
|
||||
config->keys(keys);
|
||||
|
||||
/// for each loadable object defined in xml config
|
||||
for (const auto & key : keys)
|
||||
{
|
||||
std::string name;
|
||||
|
||||
if (!startsWith(key, config_settings.external_config))
|
||||
{
|
||||
if (!startsWith(key, "comment"))
|
||||
LOG_WARNING(log, config_path << ": unknown node in file: '" << key
|
||||
<< "', expected '" << config_settings.external_config << "'");
|
||||
continue;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
name = config->getString(key + "." + config_settings.external_name);
|
||||
if (name.empty())
|
||||
{
|
||||
LOG_WARNING(log, config_path << ": " + config_settings.external_name + " name cannot be empty");
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!loadable_name.empty() && name != loadable_name)
|
||||
continue;
|
||||
|
||||
decltype(loadable_objects.begin()) object_it;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock{map_mutex};
|
||||
object_it = loadable_objects.find(name);
|
||||
}
|
||||
|
||||
if (object_it != std::end(loadable_objects) && object_it->second.origin != config_path)
|
||||
throw std::runtime_error{"Overriding " + object_name + " from file " + object_it->second.origin};
|
||||
|
||||
auto object_ptr = create(name, *config, key);
|
||||
|
||||
/// If the object could not be loaded.
|
||||
if (const auto exception_ptr = object_ptr->getCreationException())
|
||||
{
|
||||
std::chrono::seconds delay(update_settings.backoff_initial_sec);
|
||||
const auto failed_dict_it = failed_loadable_objects.find(name);
|
||||
FailedLoadableInfo info{std::move(object_ptr), std::chrono::system_clock::now() + delay, 0};
|
||||
if (failed_dict_it != std::end(failed_loadable_objects))
|
||||
(*failed_dict_it).second = std::move(info);
|
||||
else
|
||||
failed_loadable_objects.emplace(name, std::move(info));
|
||||
|
||||
std::rethrow_exception(exception_ptr);
|
||||
}
|
||||
else if (object_ptr->supportUpdates())
|
||||
{
|
||||
const auto & lifetime = object_ptr->getLifetime();
|
||||
if (lifetime.min_sec != 0 && lifetime.max_sec != 0)
|
||||
{
|
||||
std::uniform_int_distribution<UInt64> distribution(lifetime.min_sec, lifetime.max_sec);
|
||||
|
||||
update_times[name] = std::chrono::system_clock::now() +
|
||||
std::chrono::seconds{distribution(rnd_engine)};
|
||||
}
|
||||
}
|
||||
|
||||
const std::lock_guard<std::mutex> lock{map_mutex};
|
||||
|
||||
/// add new loadable object or update an existing version
|
||||
if (object_it == std::end(loadable_objects))
|
||||
loadable_objects.emplace(name, LoadableInfo{std::move(object_ptr), config_path});
|
||||
else
|
||||
{
|
||||
if (object_it->second.loadable)
|
||||
object_it->second.loadable.reset();
|
||||
object_it->second.loadable = std::move(object_ptr);
|
||||
|
||||
/// erase stored exception on success
|
||||
object_it->second.exception = std::exception_ptr{};
|
||||
failed_loadable_objects.erase(name);
|
||||
}
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
if (!name.empty())
|
||||
{
|
||||
/// If the loadable object could not load data or even failed to initialize from the config.
|
||||
/// - all the same we insert information into the `loadable_objects`, with the zero pointer `loadable`.
|
||||
|
||||
const std::lock_guard<std::mutex> lock{map_mutex};
|
||||
|
||||
const auto exception_ptr = std::current_exception();
|
||||
const auto loadable_it = loadable_objects.find(name);
|
||||
if (loadable_it == std::end(loadable_objects))
|
||||
loadable_objects.emplace(name, LoadableInfo{nullptr, config_path, exception_ptr});
|
||||
else
|
||||
loadable_it->second.exception = exception_ptr;
|
||||
}
|
||||
|
||||
tryLogCurrentException(log, "Cannot create " + object_name + " '"
|
||||
+ name + "' from config path " + config_path);
|
||||
|
||||
/// propagate exception
|
||||
if (throw_on_error)
|
||||
throw;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ExternalLoader::reload()
|
||||
{
|
||||
reloadFromConfigFiles(true, true);
|
||||
}
|
||||
|
||||
void ExternalLoader::reload(const std::string & name)
|
||||
{
|
||||
reloadFromConfigFiles(true, true, name);
|
||||
|
||||
/// Check that specified object was loaded
|
||||
const std::lock_guard<std::mutex> lock{map_mutex};
|
||||
if (!loadable_objects.count(name))
|
||||
throw Exception("Failed to load " + object_name + " '" + name + "' during the reload process", ErrorCodes::BAD_ARGUMENTS);
|
||||
}
|
||||
|
||||
ExternalLoader::LoadablePtr ExternalLoader::getLoadable(const std::string & name) const
|
||||
{
|
||||
const std::lock_guard<std::mutex> lock{map_mutex};
|
||||
|
||||
const auto it = loadable_objects.find(name);
|
||||
if (it == std::end(loadable_objects))
|
||||
throw Exception("No such " + object_name + ": " + name, ErrorCodes::BAD_ARGUMENTS);
|
||||
|
||||
if (!it->second.loadable)
|
||||
it->second.exception ? std::rethrow_exception(it->second.exception) :
|
||||
throw Exception{object_name + " '" + name + "' is not loaded", ErrorCodes::LOGICAL_ERROR};
|
||||
|
||||
return it->second.loadable;
|
||||
}
|
||||
|
||||
ExternalLoader::LockedObjectsMap ExternalLoader::getObjectsMap() const
|
||||
{
|
||||
return LockedObjectsMap(map_mutex, loadable_objects);
|
||||
}
|
||||
|
||||
}
|
173
dbms/src/Interpreters/ExternalLoader.h
Normal file
173
dbms/src/Interpreters/ExternalLoader.h
Normal file
@ -0,0 +1,173 @@
|
||||
#pragma once
|
||||
|
||||
#include <common/logger_useful.h>
|
||||
#include <Poco/Event.h>
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
#include <unordered_map>
|
||||
#include <chrono>
|
||||
#include <tuple>
|
||||
#include <Interpreters/IExternalLoadable.h>
|
||||
#include <Core/Types.h>
|
||||
#include <pcg_random.hpp>
|
||||
#include <Common/randomSeed.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
class Context;
|
||||
|
||||
struct ExternalLoaderUpdateSettings
|
||||
{
|
||||
UInt64 check_period_sec = 5;
|
||||
UInt64 backoff_initial_sec = 5;
|
||||
/// 10 minutes
|
||||
UInt64 backoff_max_sec = 10 * 60;
|
||||
|
||||
ExternalLoaderUpdateSettings() = default;
|
||||
ExternalLoaderUpdateSettings(UInt64 check_period_sec, UInt64 backoff_initial_sec, UInt64 backoff_max_sec)
|
||||
: check_period_sec(check_period_sec),
|
||||
backoff_initial_sec(backoff_initial_sec),
|
||||
backoff_max_sec(backoff_max_sec) {}
|
||||
};
|
||||
|
||||
|
||||
/* External configuration structure.
|
||||
*
|
||||
* <external_group>
|
||||
* <external_config>
|
||||
* <external_name>name</external_name>
|
||||
* ....
|
||||
* </external_config>
|
||||
* </external_group>
|
||||
*/
|
||||
struct ExternalLoaderConfigSettings
|
||||
{
|
||||
std::string external_config;
|
||||
std::string external_name;
|
||||
|
||||
std::string path_setting_name;
|
||||
};
|
||||
|
||||
/** Manages user-defined objects.
|
||||
* Monitors configuration file and automatically reloads objects in a separate thread.
|
||||
* The monitoring thread wakes up every @check_period_sec seconds and checks
|
||||
* modification time of objects' configuration file. If said time is greater than
|
||||
* @config_last_modified, the objects are created from scratch using configuration file,
|
||||
* possibly overriding currently existing objects with the same name (previous versions of
|
||||
* overridden objects will live as long as there are any users retaining them).
|
||||
*
|
||||
* Apart from checking configuration file for modifications, each object
|
||||
* has a lifetime of its own and may be updated if it supportUpdates.
|
||||
* The time of next update is calculated by choosing uniformly a random number
|
||||
* distributed between lifetime.min_sec and lifetime.max_sec.
|
||||
* If either of lifetime.min_sec and lifetime.max_sec is zero, such object is never updated.
|
||||
*/
|
||||
class ExternalLoader
|
||||
{
|
||||
public:
|
||||
using LoadablePtr = std::shared_ptr<IExternalLoadable>;
|
||||
|
||||
private:
|
||||
struct LoadableInfo final
|
||||
{
|
||||
LoadablePtr loadable;
|
||||
std::string origin;
|
||||
std::exception_ptr exception;
|
||||
};
|
||||
|
||||
struct FailedLoadableInfo final
|
||||
{
|
||||
std::unique_ptr<IExternalLoadable> loadable;
|
||||
std::chrono::system_clock::time_point next_attempt_time;
|
||||
UInt64 error_count;
|
||||
};
|
||||
|
||||
public:
|
||||
using Configuration = Poco::Util::AbstractConfiguration;
|
||||
using ObjectsMap = std::unordered_map<std::string, LoadableInfo>;
|
||||
|
||||
/// Objects will be loaded immediately and then will be updated in separate thread, each 'reload_period' seconds.
|
||||
ExternalLoader(const Configuration & config,
|
||||
const ExternalLoaderUpdateSettings & update_settings,
|
||||
const ExternalLoaderConfigSettings & config_settings,
|
||||
Logger * log, const std::string & loadable_object_name);
|
||||
virtual ~ExternalLoader();
|
||||
|
||||
/// Forcibly reloads all loadable objects.
|
||||
void reload();
|
||||
|
||||
/// Forcibly reloads specified loadable object.
|
||||
void reload(const std::string & name);
|
||||
|
||||
LoadablePtr getLoadable(const std::string & name) const;
|
||||
|
||||
protected:
|
||||
virtual std::unique_ptr<IExternalLoadable> create(const std::string & name, const Configuration & config,
|
||||
const std::string & config_prefix) = 0;
|
||||
|
||||
class LockedObjectsMap
|
||||
{
|
||||
public:
|
||||
LockedObjectsMap(std::mutex & mutex, const ObjectsMap & objectsMap) : lock(mutex), objectsMap(objectsMap) {}
|
||||
const ObjectsMap & get() { return objectsMap; }
|
||||
private:
|
||||
std::unique_lock<std::mutex> lock;
|
||||
const ObjectsMap & objectsMap;
|
||||
};
|
||||
|
||||
/// Direct access to objects.
|
||||
LockedObjectsMap getObjectsMap() const;
|
||||
|
||||
/// Should be called in derived constructor (to avoid pure virtual call).
|
||||
void init(bool throw_on_error);
|
||||
|
||||
private:
|
||||
|
||||
bool is_initialized = false;
|
||||
|
||||
/// Protects only objects map.
|
||||
mutable std::mutex map_mutex;
|
||||
|
||||
/// Protects all data, currently used to avoid races between updating thread and SYSTEM queries
|
||||
mutable std::mutex all_mutex;
|
||||
|
||||
/// name -> loadable.
|
||||
ObjectsMap loadable_objects;
|
||||
|
||||
/// Here are loadable objects, that has been never loaded successfully.
|
||||
/// They are also in 'loadable_objects', but with nullptr as 'loadable'.
|
||||
std::unordered_map<std::string, FailedLoadableInfo> failed_loadable_objects;
|
||||
|
||||
/// Both for loadable_objects and failed_loadable_objects.
|
||||
std::unordered_map<std::string, std::chrono::system_clock::time_point> update_times;
|
||||
|
||||
pcg64 rnd_engine{randomSeed()};
|
||||
|
||||
const Configuration & config;
|
||||
const ExternalLoaderUpdateSettings & update_settings;
|
||||
const ExternalLoaderConfigSettings & config_settings;
|
||||
|
||||
std::thread reloading_thread;
|
||||
Poco::Event destroy;
|
||||
|
||||
Logger * log;
|
||||
/// Loadable object name to use in log messages.
|
||||
std::string object_name;
|
||||
|
||||
std::unordered_map<std::string, Poco::Timestamp> last_modification_times;
|
||||
|
||||
/// Check objects definitions in config files and reload or/and add new ones if the definition is changed
|
||||
/// If loadable_name is not empty, load only loadable object with name loadable_name
|
||||
void reloadFromConfigFiles(bool throw_on_error, bool force_reload = false, const std::string & loadable_name = "");
|
||||
void reloadFromConfigFile(const std::string & config_path, bool throw_on_error, bool force_reload,
|
||||
const std::string & loadable_name);
|
||||
|
||||
/// Check config files and update expired loadable objects
|
||||
void reloadAndUpdate(bool throw_on_error = false);
|
||||
|
||||
void reloadPeriodically();
|
||||
};
|
||||
|
||||
}
|
67
dbms/src/Interpreters/ExternalModels.cpp
Normal file
67
dbms/src/Interpreters/ExternalModels.cpp
Normal file
@ -0,0 +1,67 @@
|
||||
#include <Interpreters/ExternalModels.h>
|
||||
#include <Interpreters/Context.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int INVALID_CONFIG_PARAMETER;
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
const ExternalLoaderUpdateSettings externalModelsUpdateSettings;
|
||||
|
||||
const ExternalLoaderConfigSettings & getExternalModelsConfigSettings()
|
||||
{
|
||||
static ExternalLoaderConfigSettings settings;
|
||||
static std::once_flag flag;
|
||||
|
||||
std::call_once(flag, [] {
|
||||
settings.external_config = "model";
|
||||
settings.external_name = "name";
|
||||
|
||||
settings.path_setting_name = "models_config";
|
||||
});
|
||||
|
||||
return settings;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
ExternalModels::ExternalModels(Context & context, bool throw_on_error)
|
||||
: ExternalLoader(context.getConfigRef(),
|
||||
externalModelsUpdateSettings,
|
||||
getExternalModelsConfigSettings(),
|
||||
&Logger::get("ExternalModels"),
|
||||
"external model"),
|
||||
context(context)
|
||||
{
|
||||
init(throw_on_error);
|
||||
}
|
||||
|
||||
std::unique_ptr<IExternalLoadable> ExternalModels::create(
|
||||
const std::string & name, const Configuration & config, const std::string & config_prefix)
|
||||
{
|
||||
String type = config.getString(config_prefix + ".type");
|
||||
ExternalLoadableLifetime lifetime(config, config_prefix + ".lifetime");
|
||||
|
||||
/// TODO: add models factory.
|
||||
if (type == "catboost")
|
||||
{
|
||||
return std::make_unique<CatBoostModel>(
|
||||
name, config.getString(config_prefix + ".path"),
|
||||
context.getConfigRef().getString("catboost_dynamic_library_path"),
|
||||
lifetime, config.getUInt(config_prefix + ".float_features_count"),
|
||||
config.getUInt(config_prefix + ".cat_features_count")
|
||||
);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw Exception("Unknown model type: " + type, ErrorCodes::INVALID_CONFIG_PARAMETER);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
41
dbms/src/Interpreters/ExternalModels.h
Normal file
41
dbms/src/Interpreters/ExternalModels.h
Normal file
@ -0,0 +1,41 @@
|
||||
#pragma once
|
||||
|
||||
#include <Dictionaries/CatBoostModel.h>
|
||||
#include <Interpreters/ExternalLoader.h>
|
||||
#include <common/logger_useful.h>
|
||||
#include <memory>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
class Context;
|
||||
|
||||
/// Manages user-defined models.
|
||||
class ExternalModels : public ExternalLoader
|
||||
{
|
||||
public:
|
||||
using ModelPtr = std::shared_ptr<IModel>;
|
||||
|
||||
/// Models will be loaded immediately and then will be updated in separate thread, each 'reload_period' seconds.
|
||||
ExternalModels(Context & context, bool throw_on_error);
|
||||
|
||||
/// Forcibly reloads specified model.
|
||||
void reloadModel(const std::string & name) { reload(name); }
|
||||
|
||||
ModelPtr getModel(const std::string & name) const
|
||||
{
|
||||
return std::static_pointer_cast<IModel>(getLoadable(name));
|
||||
}
|
||||
|
||||
protected:
|
||||
|
||||
std::unique_ptr<IExternalLoadable> create(const std::string & name, const Configuration & config,
|
||||
const std::string & config_prefix) override;
|
||||
|
||||
private:
|
||||
|
||||
Context & context;
|
||||
};
|
||||
|
||||
}
|
45
dbms/src/Interpreters/IExternalLoadable.h
Normal file
45
dbms/src/Interpreters/IExternalLoadable.h
Normal file
@ -0,0 +1,45 @@
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <Core/Types.h>
|
||||
|
||||
|
||||
namespace Poco::Util
|
||||
{
|
||||
class AbstractConfiguration;
|
||||
}
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
/// Min and max lifetimes for a loadable object or it's entry
|
||||
struct ExternalLoadableLifetime final
|
||||
{
|
||||
UInt64 min_sec;
|
||||
UInt64 max_sec;
|
||||
|
||||
ExternalLoadableLifetime(const Poco::Util::AbstractConfiguration & config, const std::string & config_prefix);
|
||||
};
|
||||
|
||||
|
||||
/// Basic interface for external loadable objects. Is used in ExternalLoader.
|
||||
class IExternalLoadable : public std::enable_shared_from_this<IExternalLoadable>
|
||||
{
|
||||
public:
|
||||
virtual ~IExternalLoadable() = default;
|
||||
|
||||
virtual const ExternalLoadableLifetime & getLifetime() const = 0;
|
||||
|
||||
virtual std::string getName() const = 0;
|
||||
/// True if object can be updated when lifetime exceeded.
|
||||
virtual bool supportUpdates() const = 0;
|
||||
/// If lifetime exceeded and isModified() ExternalLoader replace current object with the result of clone().
|
||||
virtual bool isModified() const = 0;
|
||||
/// Returns new object with the same configuration. Is used to update modified object when lifetime exceeded.
|
||||
virtual std::unique_ptr<IExternalLoadable> clone() const = 0;
|
||||
|
||||
virtual std::exception_ptr getCreationException() const = 0;
|
||||
};
|
||||
|
||||
}
|
@ -76,16 +76,17 @@ BlockInputStreams StorageSystemDictionaries::read(
|
||||
ColumnWithTypeAndName col_source{std::make_shared<ColumnString>(), std::make_shared<DataTypeString>(), "source"};
|
||||
|
||||
const auto & external_dictionaries = context.getExternalDictionaries();
|
||||
const std::lock_guard<std::mutex> lock{external_dictionaries.dictionaries_mutex};
|
||||
auto objects_map = external_dictionaries.getObjectsMap();
|
||||
const auto & dictionaries = objects_map.get();
|
||||
|
||||
for (const auto & dict_info : external_dictionaries.dictionaries)
|
||||
for (const auto & dict_info : dictionaries)
|
||||
{
|
||||
col_name.column->insert(dict_info.first);
|
||||
col_origin.column->insert(dict_info.second.origin);
|
||||
|
||||
if (dict_info.second.dict)
|
||||
if (dict_info.second.loadable)
|
||||
{
|
||||
const auto dict_ptr = dict_info.second.dict->get();
|
||||
const auto dict_ptr = std::static_pointer_cast<IDictionaryBase>(dict_info.second.loadable);
|
||||
|
||||
col_type.column->insert(dict_ptr->getTypeName());
|
||||
|
||||
|
17
dbms/tests/external_models/catboost/data/build_catboost.sh
Executable file
17
dbms/tests/external_models/catboost/data/build_catboost.sh
Executable file
@ -0,0 +1,17 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
|
||||
|
||||
cd $DIR
|
||||
git clone https://github.com/catboost/catboost.git
|
||||
|
||||
|
||||
cd "${DIR}/catboost/catboost/libs/model_interface"
|
||||
../../../ya make -r -o "${DIR}/build/lib" -j4
|
||||
cd $DIR
|
||||
ln -sf "${DIR}/build/lib/catboost/libs/model_interface/libcatboostmodel.so" libcatboostmodel.so
|
||||
|
||||
cd "${DIR}/catboost/catboost/python-package/catboost"
|
||||
../../../ya make -r -DUSE_ARCADIA_PYTHON=no -DPYTHON_CONFIG=python2-config -j4
|
||||
cd $DIR
|
||||
ln -sf "${DIR}/catboost/catboost/python-package" python-package
|
42
dbms/tests/external_models/catboost/helpers/client.py
Normal file
42
dbms/tests/external_models/catboost/helpers/client.py
Normal file
@ -0,0 +1,42 @@
|
||||
import subprocess
|
||||
import threading
|
||||
import os
|
||||
|
||||
|
||||
class ClickHouseClient:
|
||||
def __init__(self, binary_path, port):
|
||||
self.binary_path = binary_path
|
||||
self.port = port
|
||||
|
||||
def query(self, query, timeout=10, pipe=None):
|
||||
|
||||
result = []
|
||||
process = []
|
||||
|
||||
def run(path, port, text, result, in_pipe, process):
|
||||
|
||||
if in_pipe is None:
|
||||
in_pipe = subprocess.PIPE
|
||||
|
||||
pipe = subprocess.Popen([path, 'client', '--port', str(port), '-q', text],
|
||||
stdin=in_pipe, stdout=subprocess.PIPE, stderr=subprocess.PIPE, close_fds=True)
|
||||
process.append(pipe)
|
||||
stdout_data, stderr_data = pipe.communicate()
|
||||
|
||||
if stderr_data:
|
||||
raise Exception('Error while executing query: {}\nstdout:\n{}\nstderr:\n{}'
|
||||
.format(text, stdout_data, stderr_data))
|
||||
|
||||
result.append(stdout_data)
|
||||
|
||||
thread = threading.Thread(target=run, args=(self.binary_path, self.port, query, result, pipe, process))
|
||||
thread.start()
|
||||
thread.join(timeout)
|
||||
if thread.isAlive():
|
||||
if len(process):
|
||||
process[0].kill()
|
||||
thread.join()
|
||||
raise Exception('timeout exceed for query: ' + query)
|
||||
|
||||
if len(result):
|
||||
return result[0]
|
15
dbms/tests/external_models/catboost/helpers/generate.py
Normal file
15
dbms/tests/external_models/catboost/helpers/generate.py
Normal file
@ -0,0 +1,15 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
def generate_uniform_int_column(size, low, high, seed=0):
|
||||
np.random.seed(seed)
|
||||
return np.random.randint(low, high, size)
|
||||
|
||||
|
||||
def generate_uniform_float_column(size, low, high, seed=0):
|
||||
np.random.seed(seed)
|
||||
return np.random.random(size) * (high - low) + low
|
||||
|
||||
|
||||
def generate_uniform_string_column(size, samples, seed):
|
||||
return np.array(samples)[generate_uniform_int_column(size, 0, len(samples), seed)]
|
67
dbms/tests/external_models/catboost/helpers/server.py
Normal file
67
dbms/tests/external_models/catboost/helpers/server.py
Normal file
@ -0,0 +1,67 @@
|
||||
import subprocess
|
||||
import threading
|
||||
import socket
|
||||
import time
|
||||
|
||||
|
||||
class ClickHouseServer:
|
||||
def __init__(self, binary_path, config_path, stdout_file=None, stderr_file=None, shutdown_timeout=10):
|
||||
self.binary_path = binary_path
|
||||
self.config_path = config_path
|
||||
self.pipe = None
|
||||
self.stdout_file = stdout_file
|
||||
self.stderr_file = stderr_file
|
||||
self.shutdown_timeout = shutdown_timeout
|
||||
|
||||
def start(self):
|
||||
cmd = [self.binary_path, 'server', '--config', self.config_path]
|
||||
out_pipe = None
|
||||
err_pipe = None
|
||||
if self.stdout_file is not None:
|
||||
out_pipe = open(self.stdout_file, 'w')
|
||||
if self.stderr_file is not None:
|
||||
err_pipe = open(self.stderr_file, 'w')
|
||||
self.pipe = subprocess.Popen(cmd, stdout=out_pipe, stderr=err_pipe)
|
||||
|
||||
def wait_for_request(self, port, timeout=1):
|
||||
try:
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
# is not working
|
||||
# s.settimeout(timeout)
|
||||
|
||||
step = 0.01
|
||||
for iter in range(int(timeout / step)):
|
||||
if s.connect_ex(('127.0.0.1', port)) == 0:
|
||||
return
|
||||
time.sleep(step)
|
||||
|
||||
s.connect(('127.0.0.1', port))
|
||||
except socket.error as socketerror:
|
||||
print "Error: ", socketerror
|
||||
raise
|
||||
|
||||
def shutdown(self, timeout=10):
|
||||
|
||||
def wait(pipe):
|
||||
pipe.wait()
|
||||
|
||||
if self.pipe is not None:
|
||||
self.pipe.terminate()
|
||||
thread = threading.Thread(target=wait, args=(self.pipe,))
|
||||
thread.start()
|
||||
thread.join(timeout)
|
||||
if thread.isAlive():
|
||||
self.pipe.kill()
|
||||
thread.join()
|
||||
|
||||
if self.pipe.stdout is not None:
|
||||
self.pipe.stdout.close()
|
||||
if self.pipe.stderr is not None:
|
||||
self.pipe.stderr.close()
|
||||
|
||||
def __enter__(self):
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
self.shutdown(self.shutdown_timeout)
|
@ -0,0 +1,166 @@
|
||||
from server import ClickHouseServer
|
||||
from client import ClickHouseClient
|
||||
from table import ClickHouseTable
|
||||
import os
|
||||
import errno
|
||||
from shutil import rmtree
|
||||
|
||||
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
|
||||
CATBOOST_ROOT = os.path.dirname(SCRIPT_DIR)
|
||||
|
||||
CLICKHOUSE_CONFIG = \
|
||||
'''
|
||||
<yandex>
|
||||
<timezone>Europe/Moscow</timezone>
|
||||
<listen_host>::</listen_host>
|
||||
<path>{path}</path>
|
||||
<tmp_path>{tmp_path}</tmp_path>
|
||||
<models_config>{models_config}</models_config>
|
||||
<mark_cache_size>5368709120</mark_cache_size>
|
||||
<users_config>users.xml</users_config>
|
||||
<tcp_port>{tcp_port}</tcp_port>
|
||||
<catboost_dynamic_library_path>{catboost_dynamic_library_path}</catboost_dynamic_library_path>
|
||||
</yandex>
|
||||
'''
|
||||
|
||||
CLICKHOUSE_USERS = \
|
||||
'''
|
||||
<yandex>
|
||||
<profiles>
|
||||
<default>
|
||||
</default>
|
||||
<readonly>
|
||||
<readonly>1</readonly>
|
||||
</readonly>
|
||||
</profiles>
|
||||
|
||||
<users>
|
||||
<readonly>
|
||||
<password></password>
|
||||
<profile>readonly</profile>
|
||||
<quota>default</quota>
|
||||
</readonly>
|
||||
|
||||
<default>
|
||||
<password></password>
|
||||
<profile>default</profile>
|
||||
<quota>default</quota>
|
||||
<networks incl="networks" replace="replace">
|
||||
<ip>::1</ip>
|
||||
<ip>127.0.0.1</ip>
|
||||
</networks>
|
||||
|
||||
</default>
|
||||
</users>
|
||||
|
||||
<quotas>
|
||||
<default>
|
||||
</default>
|
||||
</quotas>
|
||||
</yandex>
|
||||
'''
|
||||
|
||||
CATBOOST_MODEL_CONFIG = \
|
||||
'''
|
||||
<models>
|
||||
<model>
|
||||
<type>catboost</type>
|
||||
<name>{name}</name>
|
||||
<path>{path}</path>
|
||||
<float_features_count>{float_features_count}</float_features_count>
|
||||
<cat_features_count>{cat_features_count}</cat_features_count>
|
||||
<lifetime>0</lifetime>
|
||||
</model>
|
||||
</models>
|
||||
'''
|
||||
|
||||
|
||||
class ClickHouseServerWithCatboostModels:
|
||||
def __init__(self, name, binary_path, port, shutdown_timeout=10, clean_folder=False):
|
||||
self.models = {}
|
||||
self.name = name
|
||||
self.binary_path = binary_path
|
||||
self.port = port
|
||||
self.shutdown_timeout = shutdown_timeout
|
||||
self.clean_folder = clean_folder
|
||||
self.root = os.path.join(CATBOOST_ROOT, 'data', 'servers')
|
||||
self.config_path = os.path.join(self.root, 'config.xml')
|
||||
self.users_path = os.path.join(self.root, 'users.xml')
|
||||
self.models_dir = os.path.join(self.root, 'models')
|
||||
self.server = None
|
||||
|
||||
def _get_server(self):
|
||||
stdout_file = os.path.join(self.root, 'server_stdout.txt')
|
||||
stderr_file = os.path.join(self.root, 'server_stderr.txt')
|
||||
return ClickHouseServer(self.binary_path, self.config_path, stdout_file, stderr_file, self.shutdown_timeout)
|
||||
|
||||
def add_model(self, model_name, model, float_features_count, cat_features_count):
|
||||
self.models[model_name] = (float_features_count, cat_features_count, model)
|
||||
|
||||
def apply_model(self, name, df, cat_feature_names):
|
||||
names = list(df)
|
||||
float_feature_names = tuple(name for name in names if name not in cat_feature_names)
|
||||
with ClickHouseTable(self.server, self.port, name, df) as table:
|
||||
return table.apply_model(name, cat_feature_names, float_feature_names)
|
||||
|
||||
def _create_root(self):
|
||||
try:
|
||||
os.makedirs(self.root)
|
||||
except OSError as exc: # Python >2.5
|
||||
if exc.errno == errno.EEXIST and os.path.isdir(self.root):
|
||||
pass
|
||||
else:
|
||||
raise
|
||||
|
||||
def _clean_root(self):
|
||||
rmtree(self.root)
|
||||
|
||||
def _save_config(self):
|
||||
params = {
|
||||
'tcp_port': self.port,
|
||||
'path': os.path.join(self.root, 'clickhouse'),
|
||||
'tmp_path': os.path.join(self.root, 'clickhouse', 'tmp'),
|
||||
'models_config': os.path.join(self.models_dir, '*_model.xml'),
|
||||
'catboost_dynamic_library_path': os.path.join(CATBOOST_ROOT, 'data', 'libcatboostmodel.so')
|
||||
}
|
||||
config = CLICKHOUSE_CONFIG.format(**params)
|
||||
|
||||
with open(self.config_path, 'w') as f:
|
||||
f.write(config)
|
||||
|
||||
with open(self.users_path, 'w') as f:
|
||||
f.write(CLICKHOUSE_USERS)
|
||||
|
||||
def _save_models(self):
|
||||
if not os.path.exists(self.models_dir):
|
||||
os.makedirs(self.models_dir)
|
||||
|
||||
for name, params in self.models.items():
|
||||
float_features_count, cat_features_count, model = params
|
||||
model_path = os.path.join(self.models_dir, name + '.cbm')
|
||||
config_path = os.path.join(self.models_dir, name + '_model.xml')
|
||||
params = {
|
||||
'name': name,
|
||||
'path': model_path,
|
||||
'float_features_count': float_features_count,
|
||||
'cat_features_count': cat_features_count
|
||||
}
|
||||
config = CATBOOST_MODEL_CONFIG.format(**params)
|
||||
with open(config_path, 'w') as f:
|
||||
f.write(config)
|
||||
|
||||
model.save_model(model_path)
|
||||
|
||||
def __enter__(self):
|
||||
self._create_root()
|
||||
self._save_config()
|
||||
self._save_models()
|
||||
self.server = self._get_server().__enter__()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
res = self.server.__exit__(exc_type, exc_val, exc_tb)
|
||||
if self.clean_folder:
|
||||
self._clean_root()
|
||||
return res
|
||||
|
69
dbms/tests/external_models/catboost/helpers/table.py
Normal file
69
dbms/tests/external_models/catboost/helpers/table.py
Normal file
@ -0,0 +1,69 @@
|
||||
from server import ClickHouseServer
|
||||
from client import ClickHouseClient
|
||||
from pandas import DataFrame
|
||||
import os
|
||||
import threading
|
||||
import tempfile
|
||||
|
||||
|
||||
class ClickHouseTable:
|
||||
def __init__(self, server, port, table_name, df):
|
||||
self.server = server
|
||||
self.port = port
|
||||
self.table_name = table_name
|
||||
self.df = df
|
||||
|
||||
if not isinstance(self.server, ClickHouseServer):
|
||||
raise Exception('Expected ClickHouseServer, got ' + repr(self.server))
|
||||
if not isinstance(self.df, DataFrame):
|
||||
raise Exception('Expected DataFrame, got ' + repr(self.df))
|
||||
|
||||
self.server.wait_for_request(port)
|
||||
self.client = ClickHouseClient(server.binary_path, port)
|
||||
|
||||
def _convert(self, name):
|
||||
types_map = {
|
||||
'float64': 'Float64',
|
||||
'int64': 'Int64',
|
||||
'float32': 'Float32',
|
||||
'int32': 'Int32'
|
||||
}
|
||||
|
||||
if name in types_map:
|
||||
return types_map[name]
|
||||
return 'String'
|
||||
|
||||
def _create_table_from_df(self):
|
||||
self.client.query('create database if not exists test')
|
||||
self.client.query('drop table if exists test.{}'.format(self.table_name))
|
||||
|
||||
column_types = list(self.df.dtypes)
|
||||
column_names = list(self.df)
|
||||
schema = ', '.join((name + ' ' + self._convert(str(t)) for name, t in zip(column_names, column_types)))
|
||||
print 'schema:', schema
|
||||
|
||||
create_query = 'create table test.{} (date Date DEFAULT today(), {}) engine = MergeTree(date, (date), 8192)'
|
||||
self.client.query(create_query.format(self.table_name, schema))
|
||||
|
||||
insert_query = 'insert into test.{} ({}) format CSV'
|
||||
|
||||
with tempfile.TemporaryFile() as tmp_file:
|
||||
self.df.to_csv(tmp_file, header=False, index=False)
|
||||
tmp_file.seek(0)
|
||||
self.client.query(insert_query.format(self.table_name, ', '.join(column_names)), pipe=tmp_file)
|
||||
|
||||
def apply_model(self, model_name, float_columns, cat_columns):
|
||||
columns = ', '.join(list(float_columns) + list(cat_columns))
|
||||
query = "select modelEvaluate('{}', {}) from test.{} format TSV"
|
||||
result = self.client.query(query.format(model_name, columns, self.table_name))
|
||||
return tuple(map(float, filter(len, map(str.strip, result.split()))))
|
||||
|
||||
def _drop_table(self):
|
||||
self.client.query('drop table test.{}'.format(self.table_name))
|
||||
|
||||
def __enter__(self):
|
||||
self._create_table_from_df()
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
self._drop_table()
|
28
dbms/tests/external_models/catboost/helpers/train.py
Normal file
28
dbms/tests/external_models/catboost/helpers/train.py
Normal file
@ -0,0 +1,28 @@
|
||||
import os
|
||||
import sys
|
||||
from pandas import DataFrame
|
||||
|
||||
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
|
||||
CATBOOST_ROOT = os.path.dirname(SCRIPT_DIR)
|
||||
CATBOOST_PYTHON_DIR = os.path.join(CATBOOST_ROOT, 'data', 'python-package')
|
||||
|
||||
if CATBOOST_PYTHON_DIR not in sys.path:
|
||||
sys.path.append(CATBOOST_PYTHON_DIR)
|
||||
|
||||
|
||||
import catboost
|
||||
from catboost import CatBoostClassifier
|
||||
|
||||
|
||||
def train_catboost_model(df, target, cat_features, params, verbose=True):
|
||||
|
||||
if not isinstance(df, DataFrame):
|
||||
raise Exception('DataFrame object expected, but got ' + repr(df))
|
||||
|
||||
print 'features:', df.columns.tolist()
|
||||
|
||||
cat_features_index = list(df.columns.get_loc(feature) for feature in cat_features)
|
||||
print 'cat features:', cat_features_index
|
||||
model = CatBoostClassifier(**params)
|
||||
model.fit(df, target, cat_features=cat_features_index, verbose=verbose)
|
||||
return model
|
3
dbms/tests/external_models/catboost/pytest.ini
Normal file
3
dbms/tests/external_models/catboost/pytest.ini
Normal file
@ -0,0 +1,3 @@
|
||||
[pytest]
|
||||
python_files = test.py
|
||||
norecursedirs=data
|
@ -0,0 +1,236 @@
|
||||
from helpers.server_with_models import ClickHouseServerWithCatboostModels
|
||||
from helpers.generate import generate_uniform_string_column, generate_uniform_float_column, generate_uniform_int_column
|
||||
from helpers.train import train_catboost_model
|
||||
import os
|
||||
import numpy as np
|
||||
from pandas import DataFrame
|
||||
|
||||
|
||||
PORT = int(os.environ.get('CLICKHOUSE_TESTS_PORT', '9000'))
|
||||
CLICKHOUSE_TESTS_SERVER_BIN_PATH = os.environ.get('CLICKHOUSE_TESTS_SERVER_BIN_PATH', '/usr/bin/clickhouse')
|
||||
|
||||
|
||||
def add_noise_to_target(target, seed, threshold=0.05):
|
||||
col = generate_uniform_float_column(len(target), 0., 1., seed + 1) < threshold
|
||||
return target * (1 - col) + (1 - target) * col
|
||||
|
||||
|
||||
def check_predictions(test_name, target, pred_python, pred_ch, acc_threshold):
|
||||
ch_class = pred_ch.astype(int)
|
||||
python_class = pred_python.astype(int)
|
||||
if not np.array_equal(ch_class, python_class):
|
||||
raise Exception('Got different results:\npython:\n' + str(python_class) + '\nClickHouse:\n' + str(ch_class))
|
||||
|
||||
acc = 1 - np.sum(np.abs(ch_class - np.array(target))) / (len(target) + .0)
|
||||
assert acc >= acc_threshold
|
||||
print test_name, 'accuracy: {:.10f}'.format(acc)
|
||||
|
||||
|
||||
def test_apply_float_features_only():
|
||||
|
||||
name = 'test_apply_float_features_only'
|
||||
|
||||
train_size = 10000
|
||||
test_size = 10000
|
||||
|
||||
def gen_data(size, seed):
|
||||
data = {
|
||||
'a': generate_uniform_float_column(size, 0., 1., seed + 1),
|
||||
'b': generate_uniform_float_column(size, 0., 1., seed + 2),
|
||||
'c': generate_uniform_float_column(size, 0., 1., seed + 3)
|
||||
}
|
||||
return DataFrame.from_dict(data)
|
||||
|
||||
def get_target(df):
|
||||
def target_filter(row):
|
||||
return 1 if (row['a'] > .3 and row['b'] > .3) or (row['c'] < .4 and row['a'] * row['b'] > 0.1) else 0
|
||||
return df.apply(target_filter, axis=1).as_matrix()
|
||||
|
||||
train_df = gen_data(train_size, 42)
|
||||
test_df = gen_data(test_size, 43)
|
||||
|
||||
train_target = get_target(train_df)
|
||||
test_target = get_target(test_df)
|
||||
|
||||
print
|
||||
print 'train target', train_target
|
||||
print 'test target', test_target
|
||||
|
||||
params = {
|
||||
'iterations': 4,
|
||||
'depth': 2,
|
||||
'learning_rate': 1,
|
||||
'loss_function': 'Logloss'
|
||||
}
|
||||
|
||||
model = train_catboost_model(train_df, train_target, [], params)
|
||||
pred_python = model.predict(test_df)
|
||||
|
||||
server = ClickHouseServerWithCatboostModels(name, CLICKHOUSE_TESTS_SERVER_BIN_PATH, PORT)
|
||||
server.add_model(name, model, 3, 0)
|
||||
with server:
|
||||
pred_ch = (np.array(server.apply_model(name, test_df, [])) > 0).astype(int)
|
||||
|
||||
print 'python predictions', pred_python
|
||||
print 'clickhouse predictions', pred_ch
|
||||
|
||||
check_predictions(name, test_target, pred_python, pred_ch, 0.9)
|
||||
|
||||
|
||||
def test_apply_float_features_with_string_cat_features():
|
||||
|
||||
name = 'test_apply_float_features_with_string_cat_features'
|
||||
|
||||
train_size = 10000
|
||||
test_size = 10000
|
||||
|
||||
def gen_data(size, seed):
|
||||
data = {
|
||||
'a': generate_uniform_float_column(size, 0., 1., seed + 1),
|
||||
'b': generate_uniform_float_column(size, 0., 1., seed + 2),
|
||||
'c': generate_uniform_string_column(size, ['a', 'b', 'c'], seed + 3),
|
||||
'd': generate_uniform_string_column(size, ['e', 'f', 'g'], seed + 4)
|
||||
}
|
||||
return DataFrame.from_dict(data)
|
||||
|
||||
def get_target(df):
|
||||
def target_filter(row):
|
||||
return 1 if (row['a'] > .3 and row['b'] > .3 and row['c'] != 'a') \
|
||||
or (row['a'] * row['b'] > 0.1 and row['c'] != 'b' and row['d'] != 'e') else 0
|
||||
return df.apply(target_filter, axis=1).as_matrix()
|
||||
|
||||
train_df = gen_data(train_size, 42)
|
||||
test_df = gen_data(test_size, 43)
|
||||
|
||||
train_target = get_target(train_df)
|
||||
test_target = get_target(test_df)
|
||||
|
||||
print
|
||||
print 'train target', train_target
|
||||
print 'test target', test_target
|
||||
|
||||
params = {
|
||||
'iterations': 6,
|
||||
'depth': 2,
|
||||
'learning_rate': 1,
|
||||
'loss_function': 'Logloss'
|
||||
}
|
||||
|
||||
model = train_catboost_model(train_df, train_target, ['c', 'd'], params)
|
||||
pred_python = model.predict(test_df)
|
||||
|
||||
server = ClickHouseServerWithCatboostModels(name, CLICKHOUSE_TESTS_SERVER_BIN_PATH, PORT)
|
||||
server.add_model(name, model, 2, 2)
|
||||
with server:
|
||||
pred_ch = (np.array(server.apply_model(name, test_df, [])) > 0).astype(int)
|
||||
|
||||
print 'python predictions', pred_python
|
||||
print 'clickhouse predictions', pred_ch
|
||||
|
||||
check_predictions(name, test_target, pred_python, pred_ch, 0.9)
|
||||
|
||||
|
||||
def test_apply_float_features_with_int_cat_features():
|
||||
|
||||
name = 'test_apply_float_features_with_int_cat_features'
|
||||
|
||||
train_size = 10000
|
||||
test_size = 10000
|
||||
|
||||
def gen_data(size, seed):
|
||||
data = {
|
||||
'a': generate_uniform_float_column(size, 0., 1., seed + 1),
|
||||
'b': generate_uniform_float_column(size, 0., 1., seed + 2),
|
||||
'c': generate_uniform_int_column(size, 1, 4, seed + 3),
|
||||
'd': generate_uniform_int_column(size, 1, 4, seed + 4)
|
||||
}
|
||||
return DataFrame.from_dict(data)
|
||||
|
||||
def get_target(df):
|
||||
def target_filter(row):
|
||||
return 1 if (row['a'] > .3 and row['b'] > .3 and row['c'] != 1) \
|
||||
or (row['a'] * row['b'] > 0.1 and row['c'] != 2 and row['d'] != 3) else 0
|
||||
return df.apply(target_filter, axis=1).as_matrix()
|
||||
|
||||
train_df = gen_data(train_size, 42)
|
||||
test_df = gen_data(test_size, 43)
|
||||
|
||||
train_target = get_target(train_df)
|
||||
test_target = get_target(test_df)
|
||||
|
||||
print
|
||||
print 'train target', train_target
|
||||
print 'test target', test_target
|
||||
|
||||
params = {
|
||||
'iterations': 6,
|
||||
'depth': 4,
|
||||
'learning_rate': 1,
|
||||
'loss_function': 'Logloss'
|
||||
}
|
||||
|
||||
model = train_catboost_model(train_df, train_target, ['c', 'd'], params)
|
||||
pred_python = model.predict(test_df)
|
||||
|
||||
server = ClickHouseServerWithCatboostModels(name, CLICKHOUSE_TESTS_SERVER_BIN_PATH, PORT)
|
||||
server.add_model(name, model, 2, 2)
|
||||
with server:
|
||||
pred_ch = (np.array(server.apply_model(name, test_df, [])) > 0).astype(int)
|
||||
|
||||
print 'python predictions', pred_python
|
||||
print 'clickhouse predictions', pred_ch
|
||||
|
||||
check_predictions(name, test_target, pred_python, pred_ch, 0.9)
|
||||
|
||||
|
||||
def test_apply_float_features_with_mixed_cat_features():
|
||||
|
||||
name = 'test_apply_float_features_with_mixed_cat_features'
|
||||
|
||||
train_size = 10000
|
||||
test_size = 10000
|
||||
|
||||
def gen_data(size, seed):
|
||||
data = {
|
||||
'a': generate_uniform_float_column(size, 0., 1., seed + 1),
|
||||
'b': generate_uniform_float_column(size, 0., 1., seed + 2),
|
||||
'c': generate_uniform_string_column(size, ['a', 'b', 'c'], seed + 3),
|
||||
'd': generate_uniform_int_column(size, 1, 4, seed + 4)
|
||||
}
|
||||
return DataFrame.from_dict(data)
|
||||
|
||||
def get_target(df):
|
||||
def target_filter(row):
|
||||
return 1 if (row['a'] > .3 and row['b'] > .3 and row['c'] != 'a') \
|
||||
or (row['a'] * row['b'] > 0.1 and row['c'] != 'b' and row['d'] != 2) else 0
|
||||
return df.apply(target_filter, axis=1).as_matrix()
|
||||
|
||||
train_df = gen_data(train_size, 42)
|
||||
test_df = gen_data(test_size, 43)
|
||||
|
||||
train_target = get_target(train_df)
|
||||
test_target = get_target(test_df)
|
||||
|
||||
print
|
||||
print 'train target', train_target
|
||||
print 'test target', test_target
|
||||
|
||||
params = {
|
||||
'iterations': 6,
|
||||
'depth': 4,
|
||||
'learning_rate': 1,
|
||||
'loss_function': 'Logloss'
|
||||
}
|
||||
|
||||
model = train_catboost_model(train_df, train_target, ['c', 'd'], params)
|
||||
pred_python = model.predict(test_df)
|
||||
|
||||
server = ClickHouseServerWithCatboostModels(name, CLICKHOUSE_TESTS_SERVER_BIN_PATH, PORT)
|
||||
server.add_model(name, model, 2, 2)
|
||||
with server:
|
||||
pred_ch = (np.array(server.apply_model(name, test_df, [])) > 0).astype(int)
|
||||
|
||||
print 'python predictions', pred_python
|
||||
print 'clickhouse predictions', pred_ch
|
||||
|
||||
check_predictions(name, test_target, pred_python, pred_ch, 0.9)
|
Loading…
Reference in New Issue
Block a user