fixed tests [#CLICKHOUSE-3305]

added aliases for catboost pool [#CLICKHOUSE-3305]
This commit is contained in:
Nikolai Kochetov 2017-10-31 15:22:42 +03:00
parent 554bb889ac
commit 2ae3f0b3b4
4 changed files with 37 additions and 18 deletions

View File

@ -8,6 +8,7 @@
#include <DataStreams/FilterColumnsBlockInputStream.h>
#include <Interpreters/Context.h>
#include <boost/filesystem.hpp>
#include <Parsers/ASTIdentifier.h>
namespace DB
{
@ -167,6 +168,10 @@ void StorageCatBoostPool::parseColumnDescription()
auto column_types_map = getColumnTypesMap();
auto column_types_string = getColumnTypesString(column_types_map);
/// Enumerate default names for columns as Auxiliary, Auxiliary1, Auxiliary2, ...
std::map<DatasetColumnType, size_t> columns_per_type_count;
size_t features_column_count = 0;
while (std::getline(in, line))
{
++line_num;
@ -188,7 +193,7 @@ void StorageCatBoostPool::parseColumnDescription()
std::string str_id = tokens[0];
std::string col_type = tokens[1];
std::string col_name = "feature" + (tokens.size() > 2 ? tokens[2] : str_id);
std::string col_alias = tokens.size() > 2 ? tokens[2] : "";
size_t num_id;
try
@ -211,9 +216,20 @@ void StorageCatBoostPool::parseColumnDescription()
ErrorCodes::CANNOT_PARSE_TEXT);
auto type = column_types_map[col_type];
std::string col_name;
if (type != DatasetColumnType::Num && type != DatasetColumnType::Categ)
col_name = col_type;
columns_description[num_id] = ColumnDescription(col_name, type);
{
auto & col_number = columns_per_type_count[type];
col_name = col_type + (col_number ? std::to_string(col_number) : "");
++col_number;
}
else
{
col_name = "feature" + std::to_string(features_column_count);
++features_column_count;
}
columns_description[num_id] = ColumnDescription(col_name, col_alias, type);
}
}
@ -240,6 +256,13 @@ void StorageCatBoostPool::createSampleBlockAndColumns()
else
materialized_columns.emplace_back(desc.column_name, type);
if (!desc.alias.empty())
{
auto alias = std::make_shared<ASTIdentifier>();
alias->name = desc.alias;
column_defaults[desc.alias] = {ColumnDefaultType::Alias, alias};
}
sample_block.insert(ColumnWithTypeAndName(type->createColumn(), type, desc.column_name));
}
columns.insert(columns.end(), num_columns.begin(), num_columns.end());

View File

@ -67,11 +67,12 @@ private:
/// Describes a single column: its name, an optional alias and its dataset column type.
struct ColumnDescription
{
    /// Name under which the column is inserted into the sample block.
    std::string column_name;
    /// Alternative name for the column; left empty when no alias is given.
    std::string alias;
    /// Dataset column type of this column.
    DatasetColumnType column_type;

    /// A default-constructed description is a Num column with empty name and alias.
    ColumnDescription() : column_type(DatasetColumnType::Num) {}

    /// Column without an alias: forwards to the full constructor with an empty alias.
    ColumnDescription(std::string name_, DatasetColumnType type_)
        : ColumnDescription(std::move(name_), std::string(), type_)
    {
    }

    /// Column with an explicit alias. Strings are taken by value and moved into place.
    ColumnDescription(std::string name_, std::string alias_, DatasetColumnType type_)
        : column_name(std::move(name_))
        , alias(std::move(alias_))
        , column_type(type_)
    {
    }
};
std::vector<ColumnDescription> columns_description;

View File

@ -67,8 +67,6 @@ CATBOOST_MODEL_CONFIG = \
<type>catboost</type>
<name>{name}</name>
<path>{path}</path>
<float_features_count>{float_features_count}</float_features_count>
<cat_features_count>{cat_features_count}</cat_features_count>
<lifetime>0</lifetime>
</model>
</models>
@ -94,8 +92,8 @@ class ClickHouseServerWithCatboostModels:
stderr_file = os.path.join(self.root, 'server_stderr.txt')
return ClickHouseServer(self.binary_path, self.config_path, stdout_file, stderr_file, self.shutdown_timeout)
def add_model(self, model_name, model, float_features_count, cat_features_count):
self.models[model_name] = (float_features_count, cat_features_count, model)
def add_model(self, model_name, model):
self.models[model_name] = model
def apply_model(self, name, df, cat_feature_names):
names = list(df)
@ -135,15 +133,12 @@ class ClickHouseServerWithCatboostModels:
if not os.path.exists(self.models_dir):
os.makedirs(self.models_dir)
for name, params in self.models.items():
float_features_count, cat_features_count, model = params
for name, model in self.models.items():
model_path = os.path.join(self.models_dir, name + '.cbm')
config_path = os.path.join(self.models_dir, name + '_model.xml')
params = {
'name': name,
'path': model_path,
'float_features_count': float_features_count,
'cat_features_count': cat_features_count
'path': model_path
}
config = CATBOOST_MODEL_CONFIG.format(**params)
with open(config_path, 'w') as f:

View File

@ -67,7 +67,7 @@ def test_apply_float_features_only():
pred_python = model.predict(test_df)
server = ClickHouseServerWithCatboostModels(name, CLICKHOUSE_TESTS_SERVER_BIN_PATH, PORT)
server.add_model(name, model, 3, 0)
server.add_model(name, model)
with server:
pred_ch = (np.array(server.apply_model(name, test_df, [])) > 0).astype(int)
@ -120,7 +120,7 @@ def test_apply_float_features_with_string_cat_features():
pred_python = model.predict(test_df)
server = ClickHouseServerWithCatboostModels(name, CLICKHOUSE_TESTS_SERVER_BIN_PATH, PORT)
server.add_model(name, model, 2, 2)
server.add_model(name, model)
with server:
pred_ch = (np.array(server.apply_model(name, test_df, [])) > 0).astype(int)
@ -173,7 +173,7 @@ def test_apply_float_features_with_int_cat_features():
pred_python = model.predict(test_df)
server = ClickHouseServerWithCatboostModels(name, CLICKHOUSE_TESTS_SERVER_BIN_PATH, PORT)
server.add_model(name, model, 2, 2)
server.add_model(name, model)
with server:
pred_ch = (np.array(server.apply_model(name, test_df, [])) > 0).astype(int)
@ -226,7 +226,7 @@ def test_apply_float_features_with_mixed_cat_features():
pred_python = model.predict(test_df)
server = ClickHouseServerWithCatboostModels(name, CLICKHOUSE_TESTS_SERVER_BIN_PATH, PORT)
server.add_model(name, model, 2, 2)
server.add_model(name, model)
with server:
pred_ch = (np.array(server.apply_model(name, test_df, [])) > 0).astype(int)