From 2632c080d7f304649ee614eb56e000a8f75f2262 Mon Sep 17 00:00:00 2001 From: Nikolai Kochetov Date: Fri, 28 Dec 2018 21:22:39 +0300 Subject: [PATCH] Added multiclass model test. --- .../catboost/helpers/server_with_models.py | 7 +++ .../external_models/catboost/helpers/table.py | 7 ++- .../test_apply_catboost_model/test.py | 58 +++++++++++++++++++ 3 files changed, 71 insertions(+), 1 deletion(-) diff --git a/dbms/tests/external_models/catboost/helpers/server_with_models.py b/dbms/tests/external_models/catboost/helpers/server_with_models.py index e0ed81980e1..ad9feea99fe 100644 --- a/dbms/tests/external_models/catboost/helpers/server_with_models.py +++ b/dbms/tests/external_models/catboost/helpers/server_with_models.py @@ -20,6 +20,13 @@ CLICKHOUSE_CONFIG = \ users.xml {tcp_port} {catboost_dynamic_library_path} + + trace + {path}/clickhouse-server.log + {path}/clickhouse-server.err.log + never + 50 + ''' diff --git a/dbms/tests/external_models/catboost/helpers/table.py b/dbms/tests/external_models/catboost/helpers/table.py index 2e9c454ab10..e6b05ac7b7b 100644 --- a/dbms/tests/external_models/catboost/helpers/table.py +++ b/dbms/tests/external_models/catboost/helpers/table.py @@ -56,7 +56,12 @@ class ClickHouseTable: columns = ', '.join(list(float_columns) + list(cat_columns)) query = "select modelEvaluate('{}', {}) from test.{} format TSV" result = self.client.query(query.format(model_name, columns, self.table_name)) - return tuple(map(float, filter(len, map(str.strip, result.split())))) + + def parse_row(row): + values = tuple(map(float, filter(len, map(str.strip, row.replace('(', '').replace(')', '').split(','))))) + return values if len(values) != 1 else values[0] + + return tuple(map(parse_row, filter(len, map(str.strip, result.split('\n'))))) def _drop_table(self): self.client.query('drop table test.{}'.format(self.table_name)) diff --git a/dbms/tests/external_models/catboost/test_apply_catboost_model/test.py 
# Hunk from dbms/tests/external_models/catboost/test_apply_catboost_model/test.py
# (index 792ba9a13c8..00b9fe0dce1, @@ -234,3 +234,61 @@). The function below is
# appended after test_apply_float_features_with_mixed_cat_features().


def test_apply_multiclass():
    """Train a CatBoost 'MultiClass' model and compare its predicted classes
    with the ones obtained by applying the same model through the ClickHouse
    server helper.

    The data has two float columns ('a', 'b') and two columns passed to
    training as categorical ('c', 'd'); the target takes the values 0, 1, 2.
    """
    # Named after this test so the server/model cannot be confused with those
    # of another test.
    name = 'test_apply_multiclass'

    train_size = 10000
    test_size = 10000

    def gen_data(size, seed):
        # Every column gets its own seed offset.
        data = {
            'a': generate_uniform_float_column(size, 0., 1., seed + 1),
            'b': generate_uniform_float_column(size, 0., 1., seed + 2),
            'c': generate_uniform_string_column(size, ['a', 'b', 'c'], seed + 3),
            'd': generate_uniform_int_column(size, 1, 4, seed + 4)
        }
        return DataFrame.from_dict(data)

    def get_target(df):
        def target_filter(row):
            # Three-class rule mixing float and categorical columns.
            if row['a'] > .3 and row['b'] > .3 and row['c'] != 'a':
                return 0
            elif row['a'] * row['b'] > 0.1 and row['c'] != 'b' and row['d'] != 2:
                return 1
            else:
                return 2

        # `.values` yields the same ndarray as the deprecated `.as_matrix()`.
        return df.apply(target_filter, axis=1).values

    train_df = gen_data(train_size, 42)
    test_df = gen_data(test_size, 43)

    train_target = get_target(train_df)
    test_target = get_target(test_df)

    # Single-argument print calls: same output as the Python 2 print
    # statements, and valid syntax on Python 3 as well.
    print('')
    print('train target %s' % (train_target,))
    print('test target %s' % (test_target,))

    params = {
        'iterations': 10,
        'depth': 4,
        'learning_rate': 1,
        'loss_function': 'MultiClass'
    }

    model = train_catboost_model(train_df, train_target, ['c', 'd'], params)
    # NOTE(review): assumes predict() returns a 2-D array with the class label
    # in column 0 -- confirm against the CatBoost version in use.
    pred_python = model.predict(test_df)[:, 0].astype(int)

    server = ClickHouseServerWithCatboostModels(name, CLICKHOUSE_TESTS_SERVER_BIN_PATH, PORT)
    server.add_model(name, model)
    with server:
        # Each row of the server's result holds several values (presumably one
        # per class); the predicted class is the index of the largest one.
        pred_ch = np.argmax(np.array(server.apply_model(name, test_df, [])), axis=1)

    print('python predictions %s' % (pred_python,))
    print('clickhouse predictions %s' % (pred_ch,))

    # NOTE(review): 0.9 is presumably the accuracy threshold expected by
    # check_predictions -- confirm against its definition.
    check_predictions(name, test_target, pred_python, pred_ch, 0.9)