From bb3dedf1dc441bf07fa8b5869eced214c018eea1 Mon Sep 17 00:00:00 2001 From: chertus Date: Tue, 3 Sep 2019 17:36:02 +0300 Subject: [PATCH] move Join object from ExpressionAction into AnalyzedJoin --- dbms/src/Interpreters/AnalyzedJoin.cpp | 25 +++++++++- dbms/src/Interpreters/AnalyzedJoin.h | 11 +++- .../Interpreters/CollectJoinOnKeysVisitor.h | 2 +- dbms/src/Interpreters/ExpressionActions.cpp | 50 ++++++++----------- dbms/src/Interpreters/ExpressionActions.h | 11 ++-- dbms/src/Interpreters/ExpressionAnalyzer.cpp | 9 ++-- dbms/src/Interpreters/ExpressionAnalyzer.h | 2 +- .../Interpreters/InterpreterSelectQuery.cpp | 20 ++++---- dbms/src/Interpreters/Join.h | 2 +- 9 files changed, 73 insertions(+), 59 deletions(-) diff --git a/dbms/src/Interpreters/AnalyzedJoin.cpp b/dbms/src/Interpreters/AnalyzedJoin.cpp index 36b573c4093..f60afe81276 100644 --- a/dbms/src/Interpreters/AnalyzedJoin.cpp +++ b/dbms/src/Interpreters/AnalyzedJoin.cpp @@ -209,7 +209,15 @@ bool AnalyzedJoin::sameJoin(const AnalyzedJoin * x, const AnalyzedJoin * y) && x->table_join.strictness == y->table_join.strictness && x->key_names_left == y->key_names_left && x->key_names_right == y->key_names_right - && x->columns_added_by_join == y->columns_added_by_join; + && x->columns_added_by_join == y->columns_added_by_join + && x->hash_join == y->hash_join; +} + +BlockInputStreamPtr AnalyzedJoin::createStreamWithNonJoinedDataIfFullOrRightJoin(const Block & source_header, UInt64 max_block_size) const +{ + if (isRightOrFull(table_join.kind)) + return hash_join->createStreamWithNonJoinedRows(source_header, *this, max_block_size); + return {}; } JoinPtr AnalyzedJoin::makeHashJoin(const Block & sample_block, const SizeLimits & size_limits_for_join) const @@ -219,6 +227,21 @@ JoinPtr AnalyzedJoin::makeHashJoin(const Block & sample_block, const SizeLimits return join; } +void AnalyzedJoin::joinBlock(Block & block) const +{ + hash_join->joinBlock(block, *this); +} + +void AnalyzedJoin::joinTotals(Block & block) const +{ + hash_join->joinTotals(block); +} + +bool AnalyzedJoin::hasTotals() const +{ + return hash_join->hasTotals(); +} + NamesAndTypesList getNamesAndTypeListFromTableExpression(const ASTTableExpression & table_expression, const Context & context) { NamesAndTypesList names_and_type_list; diff --git a/dbms/src/Interpreters/AnalyzedJoin.h b/dbms/src/Interpreters/AnalyzedJoin.h index 34fbede0d89..2622f35a941 100644 --- a/dbms/src/Interpreters/AnalyzedJoin.h +++ b/dbms/src/Interpreters/AnalyzedJoin.h @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -19,7 +20,7 @@ class Block; class Join; using JoinPtr = std::shared_ptr; -struct AnalyzedJoin +class AnalyzedJoin { /** Query of the form `SELECT expr(x) AS k FROM t1 ANY LEFT JOIN (SELECT expr(x) AS k FROM t2) USING k` * The join is made by column k. @@ -33,7 +34,6 @@ struct AnalyzedJoin * It's possible to use name `expr(t2 columns)`. */ -private: friend class SyntaxAnalyzer; Names key_names_left; @@ -53,6 +53,8 @@ private: /// Original name -> name. Only ranamed columns. std::unordered_map renames; + JoinPtr hash_join; + public: void addUsingKey(const ASTPtr & ast); void addOnKeys(ASTPtr & left_table_ast, ASTPtr & right_table_ast); @@ -79,7 +81,12 @@ public: const NamesAndTypesList & columnsFromJoinedTable() const { return columns_from_joined_table; } const NamesAndTypesList & columnsAddedByJoin() const { return columns_added_by_join; } + void setHashJoin(JoinPtr join) { hash_join = join; } JoinPtr makeHashJoin(const Block & sample_block, const SizeLimits & size_limits_for_join) const; + BlockInputStreamPtr createStreamWithNonJoinedDataIfFullOrRightJoin(const Block & source_header, UInt64 max_block_size) const; + void joinBlock(Block & block) const; + void joinTotals(Block & block) const; + bool hasTotals() const; static bool sameJoin(const AnalyzedJoin * x, const AnalyzedJoin * y); }; diff --git a/dbms/src/Interpreters/CollectJoinOnKeysVisitor.h b/dbms/src/Interpreters/CollectJoinOnKeysVisitor.h index bae6781a18a..024ad7c7cd8 100644 --- a/dbms/src/Interpreters/CollectJoinOnKeysVisitor.h +++ b/dbms/src/Interpreters/CollectJoinOnKeysVisitor.h @@ -10,7 +10,7 @@ namespace DB { class ASTIdentifier; -struct AnalyzedJoin; +class AnalyzedJoin; class CollectJoinOnKeysMatcher { diff --git a/dbms/src/Interpreters/ExpressionActions.cpp b/dbms/src/Interpreters/ExpressionActions.cpp index 160f9d68672..d6c38417899 100644 --- a/dbms/src/Interpreters/ExpressionActions.cpp +++ b/dbms/src/Interpreters/ExpressionActions.cpp @@ -3,11 +3,12 @@ #include #include #include -#include +#include #include #include #include #include +#include #include #include #include @@ -44,8 +45,8 @@ Names ExpressionAction::getNeededColumns() const res.insert(res.end(), array_joined_columns.begin(), array_joined_columns.end()); - if (join_params) - res.insert(res.end(), join_params->keyNamesLeft().begin(), join_params->keyNamesLeft().end()); + if (table_join) + res.insert(res.end(), table_join->keyNamesLeft().begin(), table_join->keyNamesLeft().end()); for (const auto & column : projection) res.push_back(column.first); @@ -159,12 +160,11 @@ ExpressionAction ExpressionAction::arrayJoin(const NameSet & array_joined_column return a; } -ExpressionAction ExpressionAction::ordinaryJoin(std::shared_ptr join_params, std::shared_ptr hash_join) +ExpressionAction ExpressionAction::ordinaryJoin(std::shared_ptr table_join) { ExpressionAction a; a.type = JOIN; - a.join_params = join_params; - a.join = hash_join; + a.table_join = table_join; return a; } @@ -269,7 +269,7 @@ void ExpressionAction::prepare(Block & sample_block, const Settings & settings, case JOIN: { - join_params->addJoinedColumnsAndCorrectNullability(sample_block); + table_join->addJoinedColumnsAndCorrectNullability(sample_block); break; } @@ -475,7 +475,7 @@ void ExpressionAction::execute(Block & block, bool dry_run) const case JOIN: { - join->joinBlock(block, *join_params); + table_join->joinBlock(block); break; } @@ -543,7 +543,7 @@ void ExpressionAction::executeOnTotals(Block & block) const if (type != JOIN) execute(block, false); else - join->joinTotals(block); + table_join->joinTotals(block); } @@ -593,10 +593,10 @@ std::string ExpressionAction::toString() const case JOIN: ss << "JOIN "; - for (NamesAndTypesList::const_iterator it = join_params->columnsAddedByJoin().begin(); - it != join_params->columnsAddedByJoin().end(); ++it) + for (NamesAndTypesList::const_iterator it = table_join->columnsAddedByJoin().begin(); + it != table_join->columnsAddedByJoin().end(); ++it) { - if (it != join_params->columnsAddedByJoin().begin()) + if (it != table_join->columnsAddedByJoin().begin()) ss << ", "; ss << it->name; } @@ -762,17 +762,10 @@ void ExpressionActions::execute(Block & block, bool dry_run) const bool ExpressionActions::hasTotalsInJoin() const { - bool has_totals_in_join = false; for (const auto & action : actions) - { - if (action.join && action.join->hasTotals()) - { - has_totals_in_join = true; - break; - } - } - - return has_totals_in_join; + if (action.table_join && action.table_join->hasTotals()) + return true; + return false; } void ExpressionActions::executeOnTotals(Block & block) const @@ -1164,13 +1157,11 @@ void ExpressionActions::optimizeArrayJoin() } -BlockInputStreamPtr ExpressionActions::createStreamWithNonJoinedDataIfFullOrRightJoin(const Block & source_header, UInt64 max_block_size) const +std::shared_ptr ExpressionActions::getTableJoin() const { for (const auto & action : actions) - if (action.join && isRightOrFull(action.join->getKind())) - return action.join->createStreamWithNonJoinedRows( - source_header, *action.join_params, max_block_size); - + if (action.table_join) + return action.table_join; return {}; } @@ -1216,7 +1207,7 @@ UInt128 ExpressionAction::ActionHash::operator()(const ExpressionAction & action hash.update(col); break; case JOIN: - for (const auto & col : action.join_params->columnsAddedByJoin()) + for (const auto & col : action.table_join->columnsAddedByJoin()) hash.update(col.name); break; case PROJECT: @@ -1274,8 +1265,7 @@ bool ExpressionAction::operator==(const ExpressionAction & other) const && argument_names == other.argument_names && array_joined_columns == other.array_joined_columns && array_join_is_left == other.array_join_is_left - && join == other.join - && AnalyzedJoin::sameJoin(join_params.get(), other.join_params.get()) + && AnalyzedJoin::sameJoin(table_join.get(), other.table_join.get()) && projection == other.projection && is_function_compiled == other.is_function_compiled; } diff --git a/dbms/src/Interpreters/ExpressionActions.h b/dbms/src/Interpreters/ExpressionActions.h index 90638d86368..6997c3ef759 100644 --- a/dbms/src/Interpreters/ExpressionActions.h +++ b/dbms/src/Interpreters/ExpressionActions.h @@ -4,9 +4,7 @@ #include #include #include -#include #include -#include #include #include "config_core.h" #include @@ -25,7 +23,7 @@ namespace ErrorCodes using NameWithAlias = std::pair; using NamesWithAliases = std::vector; -class Join; +class AnalyzedJoin; class IPreparedFunction; using PreparedFunctionPtr = std::shared_ptr; @@ -105,8 +103,7 @@ public: bool unaligned_array_join = false; /// For JOIN - std::shared_ptr join_params = nullptr; - std::shared_ptr join; + std::shared_ptr table_join; /// For PROJECT. NamesWithAliases projection; @@ -122,7 +119,7 @@ public: static ExpressionAction project(const Names & projected_columns_); static ExpressionAction addAliases(const NamesWithAliases & aliased_columns_); static ExpressionAction arrayJoin(const NameSet & array_joined_columns, bool array_join_is_left, const Context & context); - static ExpressionAction ordinaryJoin(std::shared_ptr join_params, std::shared_ptr hash_join); + static ExpressionAction ordinaryJoin(std::shared_ptr join); /// Which columns necessary to perform this action. Names getNeededColumns() const; @@ -238,7 +235,7 @@ public: static std::string getSmallestColumn(const NamesAndTypesList & columns); - BlockInputStreamPtr createStreamWithNonJoinedDataIfFullOrRightJoin(const Block & source_header, UInt64 max_block_size) const; + std::shared_ptr getTableJoin() const; const Settings & getSettings() const { return settings; } diff --git a/dbms/src/Interpreters/ExpressionAnalyzer.cpp b/dbms/src/Interpreters/ExpressionAnalyzer.cpp index e452d62ffca..d82169cf8e4 100644 --- a/dbms/src/Interpreters/ExpressionAnalyzer.cpp +++ b/dbms/src/Interpreters/ExpressionAnalyzer.cpp @@ -406,10 +406,9 @@ bool SelectQueryExpressionAnalyzer::appendArrayJoin(ExpressionActionsChain & cha return true; } -/// It's possible to set nullptr as join for only_types mode -void ExpressionAnalyzer::addJoinAction(ExpressionActionsPtr & actions, JoinPtr join) const +void ExpressionAnalyzer::addJoinAction(ExpressionActionsPtr & actions) const { - actions->add(ExpressionAction::ordinaryJoin(syntax->analyzed_join, join)); + actions->add(ExpressionAction::ordinaryJoin(syntax->analyzed_join)); } bool SelectQueryExpressionAnalyzer::appendJoin(ExpressionActionsChain & chain, bool only_types) @@ -419,13 +418,13 @@ bool SelectQueryExpressionAnalyzer::appendJoin(ExpressionActionsChain & chain, b return false; SubqueryForSet & subquery_for_set = getSubqueryForJoin(*ast_join); + syntax->analyzed_join->setHashJoin(subquery_for_set.join); initChain(chain, sourceColumns()); ExpressionActionsChain::Step & step = chain.steps.back(); getRootActions(analyzedJoin().leftKeysList(), only_types, step.actions); - addJoinAction(step.actions, subquery_for_set.join); - + addJoinAction(step.actions); return true; } diff --git a/dbms/src/Interpreters/ExpressionAnalyzer.h b/dbms/src/Interpreters/ExpressionAnalyzer.h index a28f54210b2..aebbaf038cc 100644 --- a/dbms/src/Interpreters/ExpressionAnalyzer.h +++ b/dbms/src/Interpreters/ExpressionAnalyzer.h @@ -130,7 +130,7 @@ protected: void addMultipleArrayJoinAction(ExpressionActionsPtr & actions, bool is_left) const; - void addJoinAction(ExpressionActionsPtr & actions, JoinPtr join = {}) const; + void addJoinAction(ExpressionActionsPtr & actions) const; void getRootActions(const ASTPtr & ast, bool no_subqueries, ExpressionActionsPtr & actions, bool only_consts = false); diff --git a/dbms/src/Interpreters/InterpreterSelectQuery.cpp b/dbms/src/Interpreters/InterpreterSelectQuery.cpp index 69613c73705..f18b368959e 100644 --- a/dbms/src/Interpreters/InterpreterSelectQuery.cpp +++ b/dbms/src/Interpreters/InterpreterSelectQuery.cpp @@ -1037,20 +1037,18 @@ void InterpreterSelectQuery::executeImpl(TPipeline & pipeline, const BlockInputS stream = std::make_shared(stream, expressions.before_join); } - const auto & join = query.join()->table_join->as(); - if (isRightOrFull(join.kind)) + if (auto join = expressions.before_join->getTableJoin()) { - auto stream = expressions.before_join->createStreamWithNonJoinedDataIfFullOrRightJoin( - header_before_join, settings.max_block_size); - - if constexpr (pipeline_with_processors) + if (auto stream = join->createStreamWithNonJoinedDataIfFullOrRightJoin(header_before_join, settings.max_block_size)) { - auto source = std::make_shared(std::move(stream)); - pipeline.addDelayedStream(source); + if constexpr (pipeline_with_processors) + { + auto source = std::make_shared(std::move(stream)); + pipeline.addDelayedStream(source); + } + else + pipeline.stream_with_non_joined_data = std::move(stream); } - else - pipeline.stream_with_non_joined_data = std::move(stream); - } } diff --git a/dbms/src/Interpreters/Join.h b/dbms/src/Interpreters/Join.h index 1a85481cf39..6ae69155920 100644 --- a/dbms/src/Interpreters/Join.h +++ b/dbms/src/Interpreters/Join.h @@ -26,7 +26,7 @@ namespace DB { -struct AnalyzedJoin; +class AnalyzedJoin; namespace JoinStuff {