diff --git a/dbms/src/Interpreters/ExpressionAnalyzer.cpp b/dbms/src/Interpreters/ExpressionAnalyzer.cpp index 37239e0bd11..fef0b05ae51 100644 --- a/dbms/src/Interpreters/ExpressionAnalyzer.cpp +++ b/dbms/src/Interpreters/ExpressionAnalyzer.cpp @@ -451,10 +451,10 @@ bool SelectQueryExpressionAnalyzer::appendJoin(ExpressionActionsChain & chain, b return true; } -static JoinPtr tryGetStorageJoin(const ASTTablesInSelectQueryElement & join_element, const Context & context) +static JoinPtr tryGetStorageJoin(const ASTTablesInSelectQueryElement & join_element, std::shared_ptr analyzed_join, + const Context & context) { const auto & table_to_join = join_element.table_expression->as(); - auto & join_params = join_element.table_join->as(); /// TODO This syntax does not support specifying a database name. if (table_to_join.database_and_table_name) @@ -465,14 +465,8 @@ static JoinPtr tryGetStorageJoin(const ASTTablesInSelectQueryElement & join_elem if (table) { auto * storage_join = dynamic_cast(table.get()); - if (storage_join) - { - storage_join->assertCompatible(join_params.kind, join_params.strictness); - /// TODO Check the set of keys. - - return storage_join->getJoin(); - } + return storage_join->getJoin(analyzed_join); } } @@ -497,7 +491,7 @@ JoinPtr SelectQueryExpressionAnalyzer::makeTableJoin(const ASTTablesInSelectQuer /// Special case - if table name is specified on the right of JOIN, then the table has the type Join (the previously prepared mapping). if (!subquery_for_join.join) - subquery_for_join.join = tryGetStorageJoin(join_element, context); + subquery_for_join.join = tryGetStorageJoin(join_element, syntax->analyzed_join, context); if (!subquery_for_join.join) { diff --git a/dbms/src/Interpreters/Join.cpp b/dbms/src/Interpreters/Join.cpp index 4da687ac1e4..ff6e2c0690f 100644 --- a/dbms/src/Interpreters/Join.cpp +++ b/dbms/src/Interpreters/Join.cpp @@ -125,6 +125,7 @@ Join::Join(std::shared_ptr table_join_, const Block & right_sample , nullable_left_side(table_join->forceNullableLeft()) , any_take_last_row(any_take_last_row_) , asof_inequality(table_join->getAsofInequality()) + , data(std::make_shared()) , log(&Logger::get("Join")) { setSampleBlock(right_sample_block); @@ -260,26 +261,26 @@ struct KeyGetterForType void Join::init(Type type_) { - type = type_; + data->type = type_; if (kind == ASTTableJoin::Kind::Cross) return; - joinDispatchInit(kind, strictness, maps); - joinDispatch(kind, strictness, maps, [&](auto, auto, auto & map) { map.create(type); }); + joinDispatchInit(kind, strictness, data->maps); + joinDispatch(kind, strictness, data->maps, [&](auto, auto, auto & map) { map.create(data->type); }); } size_t Join::getTotalRowCount() const { size_t res = 0; - if (type == Type::CROSS) + if (data->type == Type::CROSS) { - for (const auto & block : blocks) + for (const auto & block : data->blocks) res += block.rows(); } else { - joinDispatch(kind, strictness, maps, [&](auto, auto, auto & map) { res += map.getTotalRowCount(type); }); + joinDispatch(kind, strictness, data->maps, [&](auto, auto, auto & map) { res += map.getTotalRowCount(data->type); }); } return res; @@ -289,15 +290,15 @@ size_t Join::getTotalByteCount() const { size_t res = 0; - if (type == Type::CROSS) + if (data->type == Type::CROSS) { - for (const auto & block : blocks) + for (const auto & block : data->blocks) res += block.bytes(); } else { - joinDispatch(kind, strictness, maps, [&](auto, auto, auto & map) { res += map.getTotalByteCountImpl(type); }); - res += pool.size(); + joinDispatch(kind, strictness, data->maps, [&](auto, auto, auto & map) { res += map.getTotalByteCountImpl(data->type); }); + res += data->pool.size(); } return res; @@ -482,6 +483,8 @@ void Join::initRequiredRightKeys() void Join::initRightBlockStructure() { + auto & saved_block_sample = data->sample_block; + if (isRightOrFull(kind)) { /// Save keys for NonJoinedBlockInputStream @@ -504,7 +507,7 @@ void Join::initRightBlockStructure() Block Join::structureRightBlock(const Block & block) const { Block structured_block; - for (auto & sample_column : saved_block_sample.getColumnsWithTypeAndName()) + for (auto & sample_column : savedBlockSample().getColumnsWithTypeAndName()) { ColumnWithTypeAndName column = block.getByName(sample_column.name); if (sample_column.column->isNullable()) @@ -543,24 +546,24 @@ bool Join::addJoinedBlock(const Block & source_block) size_t total_bytes = 0; { - std::unique_lock lock(rwlock); + std::unique_lock lock(data->rwlock); - blocks.emplace_back(std::move(structured_block)); - Block * stored_block = &blocks.back(); + data->blocks.emplace_back(std::move(structured_block)); + Block * stored_block = &data->blocks.back(); if (rows) - has_no_rows_in_maps = false; + data->empty = false; if (kind != ASTTableJoin::Kind::Cross) { - joinDispatch(kind, strictness, maps, [&](auto, auto strictness_, auto & map) + joinDispatch(kind, strictness, data->maps, [&](auto, auto strictness_, auto & map) { - insertFromBlockImpl(*this, type, map, rows, key_columns, key_sizes, stored_block, null_map, pool); + insertFromBlockImpl(*this, data->type, map, rows, key_columns, key_sizes, stored_block, null_map, data->pool); }); } if (save_nullmap) - blocks_nullmaps.emplace_back(stored_block, null_map_holder); + data->blocks_nullmaps.emplace_back(stored_block, null_map_holder); /// TODO: Do not calculate them every time total_rows = getTotalRowCount(); @@ -915,12 +918,12 @@ void Join::joinBlockImpl( if constexpr (is_asof_join) extras.push_back(right_table_keys.getByName(key_names_right.back())); - AddedColumns added_columns(sample_block_with_columns_to_add, block_with_columns_to_add, block, saved_block_sample, + AddedColumns added_columns(sample_block_with_columns_to_add, block_with_columns_to_add, block, savedBlockSample(), extras, *this, key_columns, key_sizes); bool has_required_right_keys = (required_right_keys.columns() != 0); added_columns.need_filter = need_filter || has_required_right_keys; - IColumn::Filter row_filter = switchJoinRightColumns(maps_, added_columns, type, null_map); + IColumn::Filter row_filter = switchJoinRightColumns(maps_, added_columns, data->type, null_map); for (size_t i = 0; i < added_columns.size(); ++i) block.insert(added_columns.moveColumn(i)); @@ -1012,7 +1015,7 @@ void Join::joinBlockImplCross(Block & block) const for (size_t i = 0; i < rows_left; ++i) { - for (const Block & block_right : blocks) + for (const Block & block_right : data->blocks) { size_t rows_right = block_right.rows(); @@ -1050,7 +1053,7 @@ static void checkTypeOfKey(const Block & block_left, const Block & block_right) DataTypePtr Join::joinGetReturnType(const String & column_name) const { - std::shared_lock lock(rwlock); + std::shared_lock lock(data->rwlock); if (!sample_block_with_columns_to_add.has(column_name)) throw Exception("StorageJoin doesn't contain column " + column_name, ErrorCodes::LOGICAL_ERROR); @@ -1071,7 +1074,7 @@ void Join::joinGetImpl(Block & block, const String & column_name, const Maps & m // TODO: return array of values when strictness == ASTTableJoin::Strictness::All void Join::joinGet(Block & block, const String & column_name) const { - std::shared_lock lock(rwlock); + std::shared_lock lock(data->rwlock); if (key_names_right.size() != 1) throw Exception("joinGet only supports StorageJoin containing exactly one key", ErrorCodes::LOGICAL_ERROR); @@ -1081,7 +1084,7 @@ void Join::joinGet(Block & block, const String & column_name) const if ((strictness == ASTTableJoin::Strictness::Any || strictness == ASTTableJoin::Strictness::RightAny) && kind == ASTTableJoin::Kind::Left) { - joinGetImpl(block, column_name, std::get(maps)); + joinGetImpl(block, column_name, std::get(data->maps)); } else throw Exception("joinGet only supports StorageJoin of type Left Any", ErrorCodes::LOGICAL_ERROR); @@ -1090,12 +1093,12 @@ void Join::joinGet(Block & block, const String & column_name) const void Join::joinBlock(Block & block) { - std::shared_lock lock(rwlock); + std::shared_lock lock(data->rwlock); const Names & key_names_left = table_join->keyNamesLeft(); JoinCommon::checkTypesOfKeys(block, key_names_left, right_table_keys, key_names_right); - if (joinDispatch(kind, strictness, maps, [&](auto kind_, auto strictness_, auto & map) + if (joinDispatch(kind, strictness, data->maps, [&](auto kind_, auto strictness_, auto & map) { joinBlockImpl(block, key_names_left, sample_block_with_columns_to_add, map); })) @@ -1172,7 +1175,7 @@ public: const String & right_key_name = parent.table_join->keyNamesRight()[i]; size_t left_key_pos = result_sample_block.getPositionByName(left_key_name); - size_t right_key_pos = parent.saved_block_sample.getPositionByName(right_key_name); + size_t right_key_pos = parent.savedBlockSample().getPositionByName(right_key_name); if (remap_keys && !parent.required_right_keys.has(right_key_name)) left_to_right_key_remap[left_key_pos] = right_key_pos; @@ -1194,9 +1197,10 @@ public: column_indices_left.emplace_back(left_pos); } - for (size_t right_pos = 0; right_pos < parent.saved_block_sample.columns(); ++right_pos) + const auto & saved_block_sample = parent.savedBlockSample(); + for (size_t right_pos = 0; right_pos < saved_block_sample.columns(); ++right_pos) { - const String & name = parent.saved_block_sample.getByPosition(right_pos).name; + const String & name = saved_block_sample.getByPosition(right_pos).name; if (!result_sample_block.has(name)) continue; @@ -1225,7 +1229,7 @@ public: protected: Block readImpl() override { - if (parent.blocks.empty()) + if (parent.data->blocks.empty()) return Block(); return createBlock(); } @@ -1262,14 +1266,14 @@ private: bool hasNullabilityChange(size_t right_pos, size_t result_pos) const { - const auto & src = parent.saved_block_sample.getByPosition(right_pos).column; + const auto & src = parent.savedBlockSample().getByPosition(right_pos).column; const auto & dst = result_sample_block.getByPosition(result_pos).column; return src->isNullable() != dst->isNullable(); } Block createBlock() { - MutableColumns columns_right = parent.saved_block_sample.cloneEmptyColumns(); + MutableColumns columns_right = parent.savedBlockSample().cloneEmptyColumns(); size_t rows_added = 0; @@ -1278,7 +1282,7 @@ private: rows_added = fillColumnsFromMap(map, columns_right); }; - if (!joinDispatch(parent.kind, parent.strictness, parent.maps, fill_callback)) + if (!joinDispatch(parent.kind, parent.strictness, parent.data->maps, fill_callback)) throw Exception("Logical error: unknown JOIN strictness (must be on of: ANY, ALL, ASOF)", ErrorCodes::LOGICAL_ERROR); fillNullsFromBlocks(columns_right, rows_added); @@ -1329,7 +1333,7 @@ private: template size_t fillColumnsFromMap(const Maps & maps, MutableColumns & columns_keys_and_right) { - switch (parent.type) + switch (parent.data->type) { #define M(TYPE) \ case Join::Type::TYPE: \ @@ -1337,7 +1341,7 @@ private: APPLY_FOR_JOIN_VARIANTS(M) #undef M default: - throw Exception("Unsupported JOIN keys. Type: " + toString(static_cast(parent.type)), + throw Exception("Unsupported JOIN keys. Type: " + toString(static_cast(parent.data->type)), ErrorCodes::UNSUPPORTED_JOIN_KEYS); } @@ -1380,9 +1384,9 @@ private: void fillNullsFromBlocks(MutableColumns & columns_keys_and_right, size_t & rows_added) { if (!nulls_position.has_value()) - nulls_position = parent.blocks_nullmaps.begin(); + nulls_position = parent.data->blocks_nullmaps.begin(); - auto end = parent.blocks_nullmaps.end(); + auto end = parent.data->blocks_nullmaps.end(); for (auto & it = *nulls_position; it != end && rows_added < max_block_size; ++it) { diff --git a/dbms/src/Interpreters/Join.h b/dbms/src/Interpreters/Join.h index ff46380db13..378bc2ef51a 100644 --- a/dbms/src/Interpreters/Join.h +++ b/dbms/src/Interpreters/Join.h @@ -148,7 +148,7 @@ class Join : public IJoin public: Join(std::shared_ptr table_join_, const Block & right_sample_block, bool any_take_last_row_ = false); - bool empty() { return type == Type::EMPTY; } + bool empty() { return data->type == Type::EMPTY; } /** Add block of data from right hand of JOIN to the map. * Returns false, if some limit was exceeded and you should not insert more data. @@ -185,7 +185,7 @@ public: /// Sum size in bytes of all buffers, used for JOIN maps and for all memory pools. size_t getTotalByteCount() const; - bool alwaysReturnsEmptySet() const final { return isInnerOrRight(getKind()) && has_no_rows_in_maps; } + bool alwaysReturnsEmptySet() const final { return isInnerOrRight(getKind()) && data->empty; } ASTTableJoin::Kind getKind() const { return kind; } ASTTableJoin::Strictness getStrictness() const { return strictness; } @@ -294,6 +294,30 @@ public: using MapsAsof = MapsTemplate; using MapsVariant = std::variant; + using BlockNullmapList = std::deque>; + + struct RightTableData + { + /// Protect state for concurrent use in insertFromBlock and joinBlock. + /// @note that these methods could be called simultaneously only while use of StorageJoin. + mutable std::shared_mutex rwlock; + + Type type = Type::EMPTY; + bool empty = true; + + MapsVariant maps; + Block sample_block; /// Block as it would appear in the BlockList + BlocksList blocks; /// Blocks of "right" table. + BlockNullmapList blocks_nullmaps; /// Nullmaps for blocks of "right" table (if needed) + + /// Additional data - strings for string keys and continuation elements of single-linked lists of references to rows. + Arena pool; + }; + + void reuseJoinedData(const Join & join) + { + data = join.data; + } private: friend class NonJoinedBlockInputStream; @@ -306,33 +330,14 @@ private: /// Names of key columns in right-side table (in the order they appear in ON/USING clause). @note It could contain duplicates. const Names & key_names_right; - /// In case of LEFT and FULL joins, if use_nulls, convert right-side columns to Nullable. - bool nullable_right_side; - /// In case of RIGHT and FULL joins, if use_nulls, convert left-side columns to Nullable. - bool nullable_left_side; - - /// Overwrite existing values when encountering the same key again - bool any_take_last_row; - - /// Blocks of "right" table. - BlocksList blocks; - - /// Nullmaps for blocks of "right" table (if needed) - using BlockNullmapList = std::deque>; - BlockNullmapList blocks_nullmaps; - - MapsVariant maps; - bool has_no_rows_in_maps = true; - - /// Additional data - strings for string keys and continuation elements of single-linked lists of references to rows. - Arena pool; - - Type type = Type::EMPTY; + bool nullable_right_side; /// In case of LEFT and FULL joins, if use_nulls, convert right-side columns to Nullable. + bool nullable_left_side; /// In case of RIGHT and FULL joins, if use_nulls, convert left-side columns to Nullable. + bool any_take_last_row; /// Overwrite existing values when encountering the same key again std::optional asof_type; ASOF::Inequality asof_inequality; - static Type chooseMethod(const ColumnRawPtrs & key_columns, Sizes & key_sizes); - + /// Right table data. StorageJoin shares it between many Join objects. + std::shared_ptr data; Sizes key_sizes; /// Block with columns from the right-side table except key columns. @@ -344,26 +349,18 @@ private: /// Left table column names that are sources for required_right_keys columns std::vector required_right_keys_sources; - /// Block as it would appear in the BlockList - Block saved_block_sample; - Poco::Logger * log; Block totals; - /** Protect state for concurrent use in insertFromBlock and joinBlock. - * Note that these methods could be called simultaneously only while use of StorageJoin, - * and StorageJoin only calls these two methods. - * That's why another methods are not guarded. - */ - mutable std::shared_mutex rwlock; - void init(Type type_); /** Set information about structure of right hand of JOIN (joined data). */ void setSampleBlock(const Block & block); + const Block & savedBlockSample() const { return data->sample_block; } + /// Modify (structure) right block to save it in block list Block structureRightBlock(const Block & stored_block) const; void initRightBlockStructure(); @@ -380,6 +377,8 @@ private: template void joinGetImpl(Block & block, const String & column_name, const Maps & maps) const; + + static Type chooseMethod(const ColumnRawPtrs & key_columns, Sizes & key_sizes); }; } diff --git a/dbms/src/Storages/StorageJoin.cpp b/dbms/src/Storages/StorageJoin.cpp index 12444867b6b..6c9f3fecd75 100644 --- a/dbms/src/Storages/StorageJoin.cpp +++ b/dbms/src/Storages/StorageJoin.cpp @@ -67,11 +67,16 @@ void StorageJoin::truncate(const ASTPtr &, const Context &, TableStructureWriteL } -void StorageJoin::assertCompatible(ASTTableJoin::Kind kind_, ASTTableJoin::Strictness strictness_) const +HashJoinPtr StorageJoin::getJoin(std::shared_ptr analyzed_join) const { - /// NOTE Could be more loose. - if (!(kind == kind_ && strictness == strictness_)) + if (!(kind == analyzed_join->kind() && strictness == analyzed_join->strictness())) throw Exception("Table " + table_name + " has incompatible type of JOIN.", ErrorCodes::INCOMPATIBLE_TYPE_OF_JOIN); + + /// TODO: check key columns + + HashJoinPtr join_clone = std::make_shared(analyzed_join, getSampleBlock().sortColumns()); + join_clone->reuseJoinedData(*join); + return join_clone; } @@ -201,7 +206,7 @@ class JoinBlockInputStream : public IBlockInputStream { public: JoinBlockInputStream(const Join & parent_, UInt64 max_block_size_, Block && sample_block_) - : parent(parent_), lock(parent.rwlock), max_block_size(max_block_size_), sample_block(std::move(sample_block_)) + : parent(parent_), lock(parent.data->rwlock), max_block_size(max_block_size_), sample_block(std::move(sample_block_)) { columns.resize(sample_block.columns()); column_indices.resize(sample_block.columns()); @@ -231,11 +236,11 @@ public: protected: Block readImpl() override { - if (parent.blocks.empty()) + if (parent.data->blocks.empty()) return Block(); Block block; - if (!joinDispatch(parent.kind, parent.strictness, parent.maps, + if (!joinDispatch(parent.kind, parent.strictness, parent.data->maps, [&](auto, auto strictness, auto & map) { block = createBlock(map); })) throw Exception("Logical error: unknown JOIN strictness (must be ANY or ALL)", ErrorCodes::LOGICAL_ERROR); return block; @@ -278,7 +283,7 @@ private: size_t rows_added = 0; - switch (parent.type) + switch (parent.data->type) { #define M(TYPE) \ case Join::Type::TYPE: \ @@ -288,7 +293,7 @@ private: #undef M default: - throw Exception("Unsupported JOIN keys in StorageJoin. Type: " + toString(static_cast(parent.type)), + throw Exception("Unsupported JOIN keys in StorageJoin. Type: " + toString(static_cast(parent.data->type)), ErrorCodes::UNSUPPORTED_JOIN_KEYS); } diff --git a/dbms/src/Storages/StorageJoin.h b/dbms/src/Storages/StorageJoin.h index cfafd118768..ab974a07bfa 100644 --- a/dbms/src/Storages/StorageJoin.h +++ b/dbms/src/Storages/StorageJoin.h @@ -31,6 +31,7 @@ public: /// Access the innards. HashJoinPtr & getJoin() { return join; } + HashJoinPtr getJoin(std::shared_ptr analyzed_join) const; /// Verify that the data structure is suitable for implementing this type of JOIN. void assertCompatible(ASTTableJoin::Kind kind_, ASTTableJoin::Strictness strictness_) const; diff --git a/dbms/tests/queries/0_stateless/01050_engine_join_crash.reference b/dbms/tests/queries/0_stateless/01050_engine_join_crash.reference new file mode 100644 index 00000000000..e5e5f07ad2c --- /dev/null +++ b/dbms/tests/queries/0_stateless/01050_engine_join_crash.reference @@ -0,0 +1,11 @@ +1 1 +2 2 +3 3 +1 1 +2 2 +3 3 +3 3 +2 2 +1 1 +- +- diff --git a/dbms/tests/queries/0_stateless/01050_engine_join_crash.sql b/dbms/tests/queries/0_stateless/01050_engine_join_crash.sql new file mode 100644 index 00000000000..836d2e26e3c --- /dev/null +++ b/dbms/tests/queries/0_stateless/01050_engine_join_crash.sql @@ -0,0 +1,43 @@ +DROP TABLE IF EXISTS testJoinTable; + +CREATE TABLE testJoinTable (number UInt64, data String) ENGINE = Join(ANY, INNER, number); + +INSERT INTO testJoinTable VALUES (1, '1'), (2, '2'), (3, '3'); + +SELECT * FROM (SELECT * FROM numbers(10)) INNER JOIN testJoinTable USING number; +SELECT * FROM (SELECT * FROM numbers(10)) INNER JOIN (SELECT * FROM testJoinTable) USING number; +SELECT * FROM testJoinTable; + +DROP TABLE testJoinTable; + +SELECT '-'; + + SET any_join_distinct_right_table_keys = 1; + +DROP TABLE IF EXISTS master; +DROP TABLE IF EXISTS transaction; + +CREATE TABLE master (id Int32, name String) ENGINE = Join (ANY, LEFT, id); +CREATE TABLE transaction (id Int32, value Float64, master_id Int32) ENGINE = MergeTree() ORDER BY id; + +INSERT INTO master VALUES (1, 'ONE'); +INSERT INTO transaction VALUES (1, 52.5, 1); + +SELECT tx.id, tx.value, m.name FROM transaction tx ANY LEFT JOIN master m ON m.id = tx.master_id; + +DROP TABLE master; +DROP TABLE transaction; + +SELECT '-'; + +DROP TABLE IF EXISTS some_join; +DROP TABLE IF EXISTS tbl; + +CREATE TABLE some_join (id String, value String) ENGINE = Join(ANY, LEFT, id); +CREATE TABLE tbl (eventDate Date, id String) ENGINE = MergeTree() PARTITION BY tuple() ORDER BY eventDate; + +SELECT * FROM tbl AS t ANY LEFT JOIN some_join USING (id); +SELECT * FROM tbl AS t ANY LEFT JOIN some_join AS d USING (id); + +DROP TABLE some_join; +DROP TABLE tbl;