diff --git a/src/Interpreters/DictionaryReader.cpp b/src/Interpreters/DictionaryReader.cpp new file mode 100644 index 00000000000..301fe9d57c6 --- /dev/null +++ b/src/Interpreters/DictionaryReader.cpp @@ -0,0 +1,167 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int NUMBER_OF_COLUMNS_DOESNT_MATCH; + extern const int TYPE_MISMATCH; +} + + +DictionaryReader::FunctionWrapper::FunctionWrapper(FunctionOverloadResolverPtr resolver, const ColumnsWithTypeAndName & arguments, + Block & block, const ColumnNumbers & arg_positions_, const String & column_name, + TypeIndex expected_type) + : arg_positions(arg_positions_) + , result_pos(block.columns()) +{ + FunctionBasePtr prepared_function = resolver->build(arguments); + + ColumnWithTypeAndName result; + result.name = "get_" + column_name; + result.type = prepared_function->getReturnType(); + if (result.type->getTypeId() != expected_type) + throw Exception("Type mismatch in dictionary reader for: " + column_name, ErrorCodes::TYPE_MISMATCH); + block.insert(result); + + function = prepared_function->prepare(block, arg_positions, result_pos); +} + +static constexpr const size_t key_size = 1; + +DictionaryReader::DictionaryReader(const String & dictionary_name, const Names & src_column_names, const NamesAndTypesList & result_columns, + const Context & context) + : result_header(makeResultBlock(result_columns)) + , key_position(key_size + result_header.columns()) +{ + if (src_column_names.size() != result_columns.size()) + throw Exception("Columns number mismatch in dictionary reader", ErrorCodes::NUMBER_OF_COLUMNS_DOESNT_MATCH); + + ColumnWithTypeAndName dict_name; + ColumnWithTypeAndName key; + ColumnWithTypeAndName column_name; + + { + dict_name.name = "dict"; + dict_name.type = std::make_shared(); + dict_name.column = dict_name.type->createColumnConst(1, dictionary_name); + + /// TODO: composite key (key_size > 1) + key.name = "key"; + key.type = std::make_shared(); + + column_name.name = "column"; + column_name.type = std::make_shared(); + } + + /// dictHas('dict_name', id) + ColumnsWithTypeAndName arguments_has; + arguments_has.push_back(dict_name); + arguments_has.push_back(key); + + /// dictGet('dict_name', 'attr_name', id) + ColumnsWithTypeAndName arguments_get; + arguments_get.push_back(dict_name); + arguments_get.push_back(column_name); + arguments_get.push_back(key); + + sample_block.insert(dict_name); + + for (auto & columns_name : src_column_names) + { + ColumnWithTypeAndName name; + name.name = "col_" + columns_name; + name.type = std::make_shared(); + name.column = name.type->createColumnConst(1, columns_name); + + sample_block.insert(name); + } + + sample_block.insert(key); + + ColumnNumbers positions_has{0, key_position}; + function_has = std::make_unique(FunctionFactory::instance().get("dictHas", context), + arguments_has, sample_block, positions_has, "has", DataTypeUInt8().getTypeId()); + functions_get.reserve(result_header.columns()); + + for (size_t i = 0; i < result_header.columns(); ++i) + { + size_t column_name_pos = key_size + i; + auto & column = result_header.getByPosition(i); + arguments_get[1].column = DataTypeString().createColumnConst(1, src_column_names[i]); + ColumnNumbers positions_get{0, column_name_pos, key_position}; + functions_get.emplace_back( + FunctionWrapper(FunctionFactory::instance().get("dictGet", context), + arguments_get, sample_block, positions_get, column.name, column.type->getTypeId())); + } +} + +void DictionaryReader::readKeys(const IColumn & keys, Block & out_block, ColumnVector::Container & found, + std::vector & positions) const +{ + Block working_block = sample_block; + size_t has_position = key_position + 1; + size_t size = keys.size(); + + /// set keys for dictHas() + ColumnWithTypeAndName & key_column = working_block.getByPosition(key_position); + key_column.column = keys.cloneResized(size); /// just a copy we cannot avoid + + /// calculate and extract dictHas() + function_has->execute(working_block, size); + ColumnWithTypeAndName & has_column = working_block.getByPosition(has_position); + auto mutable_has = (*std::move(has_column.column)).mutate(); + found.swap(typeid_cast &>(*mutable_has).getData()); + has_column.column = nullptr; + + /// set mapping form source keys to resulting rows in output block + positions.clear(); + positions.resize(size, 0); + size_t pos = 0; + for (size_t i = 0; i < size; ++i) + if (found[i]) + positions[i] = pos++; + + /// set keys for dictGet(): remove not found keys + key_column.column = key_column.column->filter(found, -1); + size_t rows = key_column.column->size(); + + /// calculate dictGet() + for (auto & func : functions_get) + func.execute(working_block, rows); + + /// make result: copy header block with correct names and move data columns + out_block = result_header.cloneEmpty(); + size_t first_get_position = has_position + 1; + for (size_t i = 0; i < out_block.columns(); ++i) + { + auto & src_column = working_block.getByPosition(first_get_position + i); + auto & dst_column = out_block.getByPosition(i); + dst_column.column = src_column.column; + src_column.column = nullptr; + } +} + +Block DictionaryReader::makeResultBlock(const NamesAndTypesList & names) +{ + Block block; + for (auto & nm : names) + { + ColumnWithTypeAndName column{nullptr, nm.type, nm.name}; + if (column.type->isNullable()) + column.type = typeid_cast(*column.type).getNestedType(); + block.insert(std::move(column)); + } + return block; +} + +} diff --git a/src/Interpreters/DictionaryReader.h b/src/Interpreters/DictionaryReader.h index 823a3690669..92e4924ae80 100644 --- a/src/Interpreters/DictionaryReader.h +++ b/src/Interpreters/DictionaryReader.h @@ -1,25 +1,16 @@ #pragma once -#include #include -#include -#include -#include #include -#include -#include -#include -#include +#include namespace DB { -namespace ErrorCodes -{ - extern const int NUMBER_OF_COLUMNS_DOESNT_MATCH; - extern const int TYPE_MISMATCH; -} +class Context; +/// Read block of required columns from Dictionary by UInt64 key column. Rename columns if needed. +/// Current implementation uses dictHas() + N * dictGet() functions. class DictionaryReader { public: @@ -30,21 +21,7 @@ public: size_t result_pos = 0; FunctionWrapper(FunctionOverloadResolverPtr resolver, const ColumnsWithTypeAndName & arguments, Block & block, - const ColumnNumbers & arg_positions_, const String & column_name, TypeIndex expected_type) - : arg_positions(arg_positions_) - , result_pos(block.columns()) - { - FunctionBasePtr prepared_function = resolver->build(arguments); - - ColumnWithTypeAndName result; - result.name = "get_" + column_name; - result.type = prepared_function->getReturnType(); - if (result.type->getTypeId() != expected_type) - throw Exception("Type mismatch in dictionary reader for: " + column_name, ErrorCodes::TYPE_MISMATCH); - block.insert(result); - - function = prepared_function->prepare(block, arg_positions, result_pos); - } + const ColumnNumbers & arg_positions_, const String & column_name, TypeIndex expected_type); void execute(Block & block, size_t rows) const { @@ -53,116 +30,8 @@ public: }; DictionaryReader(const String & dictionary_name, const Names & src_column_names, const NamesAndTypesList & result_columns, - const Context & context, size_t key_size = 1) - : result_header(makeResultBlock(result_columns)) - , key_position(key_size + result_header.columns()) - { - if (src_column_names.size() != result_columns.size()) - throw Exception("Columns number mismatch in dictionary reader", ErrorCodes::NUMBER_OF_COLUMNS_DOESNT_MATCH); - - ColumnWithTypeAndName dict_name; - ColumnWithTypeAndName key; - ColumnWithTypeAndName column_name; - - { - dict_name.name = "dict"; - dict_name.type = std::make_shared(); - dict_name.column = dict_name.type->createColumnConst(1, dictionary_name); - - /// TODO: composite key (key_size > 1) - key.name = "key"; - key.type = std::make_shared(); - - column_name.name = "column"; - column_name.type = std::make_shared(); - } - - /// dictHas('dict_name', id) - ColumnsWithTypeAndName arguments_has; - arguments_has.push_back(dict_name); - arguments_has.push_back(key); - - /// dictGet('dict_name', 'attr_name', id) - ColumnsWithTypeAndName arguments_get; - arguments_get.push_back(dict_name); - arguments_get.push_back(column_name); - arguments_get.push_back(key); - - sample_block.insert(dict_name); - - for (auto & columns_name : src_column_names) - { - ColumnWithTypeAndName name; - name.name = "col_" + columns_name; - name.type = std::make_shared(); - name.column = name.type->createColumnConst(1, columns_name); - - sample_block.insert(name); - } - - sample_block.insert(key); - - ColumnNumbers positions_has{0, key_position}; - function_has = std::make_unique(FunctionFactory::instance().get("dictHas", context), - arguments_has, sample_block, positions_has, "has", DataTypeUInt8().getTypeId()); - functions_get.reserve(result_header.columns()); - - for (size_t i = 0; i < result_header.columns(); ++i) - { - size_t column_name_pos = key_size + i; - auto & column = result_header.getByPosition(i); - arguments_get[1].column = DataTypeString().createColumnConst(1, src_column_names[i]); - ColumnNumbers positions_get{0, column_name_pos, key_position}; - functions_get.emplace_back( - FunctionWrapper(FunctionFactory::instance().get("dictGet", context), - arguments_get, sample_block, positions_get, column.name, column.type->getTypeId())); - } - } - - void readKeys(const IColumn & keys, size_t size, Block & out_block, ColumnVector::Container & found, - std::vector & positions) const - { - Block working_block = sample_block; - size_t has_position = key_position + 1; - - /// set keys for dictHas() - ColumnWithTypeAndName & key_column = working_block.getByPosition(key_position); - key_column.column = keys.cloneResized(size); /// just a copy we cannot avoid - - /// calculate and extract dictHas() - function_has->execute(working_block, size); - ColumnWithTypeAndName & has_column = working_block.getByPosition(has_position); - auto mutable_has = (*std::move(has_column.column)).mutate(); - found.swap(typeid_cast &>(*mutable_has).getData()); - has_column.column = nullptr; - - /// set mapping form source keys to resulting rows in output block - positions.clear(); - positions.resize(size, 0); - size_t pos = 0; - for (size_t i = 0; i < size; ++i) - if (found[i]) - positions[i] = pos++; - - /// set keys for dictGet(): remove not found keys - key_column.column = key_column.column->filter(found, -1); - size_t rows = key_column.column->size(); - - /// calculate dictGet() - for (auto & func : functions_get) - func.execute(working_block, rows); - - /// make result: copy header block with correct names and move data columns - out_block = result_header.cloneEmpty(); - size_t first_get_position = has_position + 1; - for (size_t i = 0; i < out_block.columns(); ++i) - { - auto & src_column = working_block.getByPosition(first_get_position + i); - auto & dst_column = out_block.getByPosition(i); - dst_column.column = src_column.column; - src_column.column = nullptr; - } - } + const Context & context); + void readKeys(const IColumn & keys, Block & out_block, ColumnVector::Container & found, std::vector & positions) const; private: Block result_header; @@ -171,18 +40,7 @@ private: std::unique_ptr function_has; std::vector functions_get; - static Block makeResultBlock(const NamesAndTypesList & names) - { - Block block; - for (auto & nm : names) - { - ColumnWithTypeAndName column{nullptr, nm.type, nm.name}; - if (column.type->isNullable()) - column.type = typeid_cast(*column.type).getNestedType(); - block.insert(std::move(column)); - } - return block; - } + static Block makeResultBlock(const NamesAndTypesList & names); }; } diff --git a/src/Interpreters/HashJoin.cpp b/src/Interpreters/HashJoin.cpp index f58efa1920f..22a8a87cbe0 100644 --- a/src/Interpreters/HashJoin.cpp +++ b/src/Interpreters/HashJoin.cpp @@ -300,7 +300,7 @@ public: const DictionaryReader & reader = *table_join.dictionary_reader; if (!read_result) { - reader.readKeys(*key_columns[0], key_columns[0]->size(), read_result, found, positions); + reader.readKeys(*key_columns[0], read_result, found, positions); result.block = &read_result; if (table_join.forceNullableRight())