mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-28 18:42:26 +00:00
Added multiclass model test.
This commit is contained in:
parent
bdaf1ac109
commit
2632c080d7
@ -20,6 +20,13 @@ CLICKHOUSE_CONFIG = \
|
||||
<users_config>users.xml</users_config>
|
||||
<tcp_port>{tcp_port}</tcp_port>
|
||||
<catboost_dynamic_library_path>{catboost_dynamic_library_path}</catboost_dynamic_library_path>
|
||||
<logger>
|
||||
<level>trace</level>
|
||||
<log>{path}/clickhouse-server.log</log>
|
||||
<errorlog>{path}/clickhouse-server.err.log</errorlog>
|
||||
<size>never</size>
|
||||
<count>50</count>
|
||||
</logger>
|
||||
</yandex>
|
||||
'''
|
||||
|
||||
|
@ -56,7 +56,12 @@ class ClickHouseTable:
|
||||
columns = ', '.join(list(float_columns) + list(cat_columns))
|
||||
query = "select modelEvaluate('{}', {}) from test.{} format TSV"
|
||||
result = self.client.query(query.format(model_name, columns, self.table_name))
|
||||
return tuple(map(float, filter(len, map(str.strip, result.split()))))
|
||||
|
||||
def parse_row(row):
    """Convert one line of query output into a float or a tuple of floats.

    Parentheses are removed, the remainder is split on commas, and each
    non-empty, whitespace-stripped piece is converted with ``float``.

    Returns the bare float when exactly one value is present; otherwise a
    tuple of floats (empty when the row contains no values).
    """
    without_parens = row.replace('(', '').replace(')', '')
    pieces = [str.strip(piece) for piece in without_parens.split(',')]
    values = tuple(float(piece) for piece in pieces if piece)
    if len(values) == 1:
        # A single value is returned unwrapped rather than as a 1-tuple.
        return values[0]
    return values
|
||||
|
||||
return tuple(map(parse_row, filter(len, map(str.strip, result.split('\n')))))
|
||||
|
||||
def _drop_table(self):
    """Send a ``drop table`` query for ``test.<table_name>`` through ``self.client``."""
    statement = 'drop table test.{}'.format(self.table_name)
    self.client.query(statement)
|
||||
|
@ -234,3 +234,61 @@ def test_apply_float_features_with_mixed_cat_features():
|
||||
print 'clickhouse predictions', pred_ch
|
||||
|
||||
check_predictions(name, test_target, pred_python, pred_ch, 0.9)
|
||||
|
||||
|
||||
def test_apply_multiclass():
    """Compare CatBoost and ClickHouse predictions for a 'MultiClass' model.

    Generates train/test frames with two float columns ('a', 'b'), a string
    column ('c') and an int column ('d'), labels each row with class 0, 1
    or 2, trains a model with ``loss_function='MultiClass'`` and then
    evaluates the same model through a ClickHouse server.  The ClickHouse
    class is taken as the argmax over the values returned for each row.
    Both prediction vectors are handed to ``check_predictions`` together
    with the true test labels.
    """

    # Use this test's own name; the previous value was copied from
    # test_apply_float_features_with_mixed_cat_features, so both tests
    # shared the same server/model/report name.
    name = 'test_apply_multiclass'

    train_size = 10000
    test_size = 10000

    def gen_data(size, seed):
        """Build a DataFrame with columns 'a', 'b', 'c', 'd' of `size` rows."""
        # Each column gets its own seed offset.
        data = {
            'a': generate_uniform_float_column(size, 0., 1., seed + 1),
            'b': generate_uniform_float_column(size, 0., 1., seed + 2),
            'c': generate_uniform_string_column(size, ['a', 'b', 'c'], seed + 3),
            'd': generate_uniform_int_column(size, 1, 4, seed + 4),
        }
        return DataFrame.from_dict(data)

    def get_target(df):
        """Return an array with the class label (0, 1 or 2) of every row."""
        def target_filter(row):
            if row['a'] > .3 and row['b'] > .3 and row['c'] != 'a':
                return 0
            elif row['a'] * row['b'] > 0.1 and row['c'] != 'b' and row['d'] != 2:
                return 1
            else:
                return 2

        # `.values` is equivalent to the deprecated `.as_matrix()` and is
        # available in both old and current pandas releases.
        return df.apply(target_filter, axis=1).values

    train_df = gen_data(train_size, 42)
    test_df = gen_data(test_size, 43)

    train_target = get_target(train_df)
    test_target = get_target(test_df)

    # Single-argument print(...) emits the same text under the Python 2
    # print statement and the Python 3 print function.
    print('')
    print('train target {}'.format(train_target))
    print('test target {}'.format(test_target))

    params = {
        'iterations': 10,
        'depth': 4,
        'learning_rate': 1,
        'loss_function': 'MultiClass',
    }

    # 'c' and 'd' are passed to the training helper as the third argument
    # (presumably the categorical feature names -- confirm against
    # train_catboost_model).
    model = train_catboost_model(train_df, train_target, ['c', 'd'], params)
    # Only the first column of the prediction matrix is used, cast to int.
    pred_python = model.predict(test_df)[:, 0].astype(int)

    server = ClickHouseServerWithCatboostModels(name, CLICKHOUSE_TESTS_SERVER_BIN_PATH, PORT)
    server.add_model(name, model)
    with server:
        # Each row of the ClickHouse result holds several values; the index
        # of the largest one is taken as the predicted class.
        pred_ch = np.argmax(np.array(server.apply_model(name, test_df, [])), axis=1)

    print('python predictions {}'.format(pred_python))
    print('clickhouse predictions {}'.format(pred_ch))

    # NOTE(review): 0.9 looks like a minimum-quality threshold -- confirm
    # against check_predictions.
    check_predictions(name, test_target, pred_python, pred_ch, 0.9)
|
||||
|
Loading…
Reference in New Issue
Block a user