diff --git a/dbms/src/Interpreters/AnalyzedJoin.cpp b/dbms/src/Interpreters/AnalyzedJoin.cpp index e7115816920..5d79ad71ae2 100644 --- a/dbms/src/Interpreters/AnalyzedJoin.cpp +++ b/dbms/src/Interpreters/AnalyzedJoin.cpp @@ -1,6 +1,4 @@ #include -#include -#include #include @@ -229,27 +227,14 @@ bool AnalyzedJoin::sameStrictnessAndKind(ASTTableJoin::Strictness strictness_, A return false; } -JoinPtr makeJoin(std::shared_ptr table_join, const Block & right_sample_block) +bool AnalyzedJoin::allowMergeJoin() const { - auto kind = table_join->kind(); - auto strictness = table_join->strictness(); + bool is_any = (strictness() == ASTTableJoin::Strictness::Any); + bool is_all = (strictness() == ASTTableJoin::Strictness::All); + bool is_semi = (strictness() == ASTTableJoin::Strictness::Semi); - bool is_any = (strictness == ASTTableJoin::Strictness::Any); - bool is_all = (strictness == ASTTableJoin::Strictness::All); - bool is_semi = (strictness == ASTTableJoin::Strictness::Semi); - - bool allow_merge_join = (isLeft(kind) && (is_any || is_all || is_semi)) || (isInner(kind) && is_all); - - if (table_join->partial_merge_join && allow_merge_join) - return std::make_shared(table_join, right_sample_block); - return std::make_shared(table_join, right_sample_block); -} - -bool isMergeJoin(const JoinPtr & join) -{ - if (join) - return typeid_cast(join.get()); - return false; + bool allow_merge_join = (isLeft(kind()) && (is_any || is_all || is_semi)) || (isInner(kind()) && is_all); + return allow_merge_join && partial_merge_join; } } diff --git a/dbms/src/Interpreters/AnalyzedJoin.h b/dbms/src/Interpreters/AnalyzedJoin.h index fe89b6f47ef..a96ea54d5fe 100644 --- a/dbms/src/Interpreters/AnalyzedJoin.h +++ b/dbms/src/Interpreters/AnalyzedJoin.h @@ -87,6 +87,7 @@ public: bool sameStrictnessAndKind(ASTTableJoin::Strictness, ASTTableJoin::Kind) const; const SizeLimits & sizeLimits() const { return size_limits; } VolumePtr getTemporaryVolume() { return tmp_volume; } + bool allowMergeJoin() const; bool forceNullableRight() const { return join_use_nulls && isLeftOrFull(table_join.kind); } bool forceNullableLeft() const { return join_use_nulls && isRightOrFull(table_join.kind); } @@ -128,9 +129,6 @@ public: void setRightKeys(const Names & keys) { key_names_right = keys; } static bool sameJoin(const AnalyzedJoin * x, const AnalyzedJoin * y); - friend JoinPtr makeJoin(std::shared_ptr table_join, const Block & right_sample_block); }; -bool isMergeJoin(const JoinPtr &); - } diff --git a/dbms/src/Interpreters/ExpressionAnalyzer.cpp b/dbms/src/Interpreters/ExpressionAnalyzer.cpp index f131afb86c6..c430e348e13 100644 --- a/dbms/src/Interpreters/ExpressionAnalyzer.cpp +++ b/dbms/src/Interpreters/ExpressionAnalyzer.cpp @@ -29,7 +29,7 @@ #include #include #include -#include +#include #include #include @@ -564,7 +564,7 @@ JoinPtr SelectQueryExpressionAnalyzer::makeTableJoin(const ASTTablesInSelectQuer /// TODO You do not need to set this up when JOIN is only needed on remote servers. subquery_for_join.setJoinActions(joined_block_actions); /// changes subquery_for_join.sample_block inside - subquery_for_join.join = makeJoin(syntax->analyzed_join, subquery_for_join.sample_block); + subquery_for_join.join = std::make_shared(syntax->analyzed_join, subquery_for_join.sample_block); } return subquery_for_join.join; diff --git a/dbms/src/Interpreters/JoinSwitcher.h b/dbms/src/Interpreters/JoinSwitcher.h new file mode 100644 index 00000000000..4c627bd7b8e --- /dev/null +++ b/dbms/src/Interpreters/JoinSwitcher.h @@ -0,0 +1,72 @@ +#pragma once + +#include +#include +#include +#include + +namespace DB +{ + +class JoinSwitcher : public IJoin +{ +public: + JoinSwitcher(std::shared_ptr table_join, const Block & right_sample_block) + { + if (table_join->allowMergeJoin()) + join = std::make_shared(table_join, right_sample_block); + else + join = std::make_shared(table_join, right_sample_block); + } + + bool addJoinedBlock(const Block & block) override + { + /// TODO: switch Join -> MergeJoin + return join->addJoinedBlock(block); + } + + void joinBlock(Block & block, std::shared_ptr & not_processed) override + { + join->joinBlock(block, not_processed); + } + + bool hasTotals() const override + { + return join->hasTotals(); + } + + void setTotals(const Block & block) override + { + join->setTotals(block); + } + + void joinTotals(Block & block) const override + { + join->joinTotals(block); + } + + size_t getTotalRowCount() const override + { + return join->getTotalRowCount(); + } + + bool alwaysReturnsEmptySet() const override + { + return join->alwaysReturnsEmptySet(); + } + + BlockInputStreamPtr createStreamWithNonJoinedRows(const Block & block, UInt64 max_block_size) const override + { + return join->createStreamWithNonJoinedRows(block, max_block_size); + } + + bool hasStreamWithNonJoinedRows() const override + { + return join->hasStreamWithNonJoinedRows(); + } + +private: + JoinPtr join; +}; + +}