move Join object from ExpressionAction into AnalyzedJoin

This commit is contained in:
chertus 2019-09-03 17:36:02 +03:00
parent b831f2f636
commit bb3dedf1dc
9 changed files with 73 additions and 59 deletions

View File

@ -209,7 +209,15 @@ bool AnalyzedJoin::sameJoin(const AnalyzedJoin * x, const AnalyzedJoin * y)
&& x->table_join.strictness == y->table_join.strictness
&& x->key_names_left == y->key_names_left
&& x->key_names_right == y->key_names_right
&& x->columns_added_by_join == y->columns_added_by_join;
&& x->columns_added_by_join == y->columns_added_by_join
&& x->hash_join == y->hash_join;
}
BlockInputStreamPtr AnalyzedJoin::createStreamWithNonJoinedDataIfFullOrRightJoin(const Block & source_header, UInt64 max_block_size) const
{
if (isRightOrFull(table_join.kind))
return hash_join->createStreamWithNonJoinedRows(source_header, *this, max_block_size);
return {};
}
JoinPtr AnalyzedJoin::makeHashJoin(const Block & sample_block, const SizeLimits & size_limits_for_join) const
@ -219,6 +227,21 @@ JoinPtr AnalyzedJoin::makeHashJoin(const Block & sample_block, const SizeLimits
return join;
}
void AnalyzedJoin::joinBlock(Block & block) const
{
hash_join->joinBlock(block, *this);
}
void AnalyzedJoin::joinTotals(Block & block) const
{
hash_join->joinTotals(block);
}
bool AnalyzedJoin::hasTotals() const
{
return hash_join->hasTotals();
}
NamesAndTypesList getNamesAndTypeListFromTableExpression(const ASTTableExpression & table_expression, const Context & context)
{
NamesAndTypesList names_and_type_list;

View File

@ -4,6 +4,7 @@
#include <Core/NamesAndTypes.h>
#include <Core/SettingsCommon.h>
#include <Parsers/ASTTablesInSelectQuery.h>
#include <DataStreams/IBlockStream_fwd.h>
#include <utility>
#include <memory>
@ -19,7 +20,7 @@ class Block;
class Join;
using JoinPtr = std::shared_ptr<Join>;
struct AnalyzedJoin
class AnalyzedJoin
{
/** Query of the form `SELECT expr(x) AS k FROM t1 ANY LEFT JOIN (SELECT expr(x) AS k FROM t2) USING k`
* The join is made by column k.
@ -33,7 +34,6 @@ struct AnalyzedJoin
* It's possible to use name `expr(t2 columns)`.
*/
private:
friend class SyntaxAnalyzer;
Names key_names_left;
@ -53,6 +53,8 @@ private:
/// Original name -> name. Only ranamed columns.
std::unordered_map<String, String> renames;
JoinPtr hash_join;
public:
void addUsingKey(const ASTPtr & ast);
void addOnKeys(ASTPtr & left_table_ast, ASTPtr & right_table_ast);
@ -79,7 +81,12 @@ public:
const NamesAndTypesList & columnsFromJoinedTable() const { return columns_from_joined_table; }
const NamesAndTypesList & columnsAddedByJoin() const { return columns_added_by_join; }
void setHashJoin(JoinPtr join) { hash_join = join; }
JoinPtr makeHashJoin(const Block & sample_block, const SizeLimits & size_limits_for_join) const;
BlockInputStreamPtr createStreamWithNonJoinedDataIfFullOrRightJoin(const Block & source_header, UInt64 max_block_size) const;
void joinBlock(Block & block) const;
void joinTotals(Block & block) const;
bool hasTotals() const;
static bool sameJoin(const AnalyzedJoin * x, const AnalyzedJoin * y);
};

View File

@ -10,7 +10,7 @@ namespace DB
{
class ASTIdentifier;
struct AnalyzedJoin;
class AnalyzedJoin;
class CollectJoinOnKeysMatcher
{

View File

@ -3,11 +3,12 @@
#include <Common/SipHash.h>
#include <Interpreters/ExpressionActions.h>
#include <Interpreters/ExpressionJIT.h>
#include <Interpreters/Join.h>
#include <Interpreters/AnalyzedJoin.h>
#include <Columns/ColumnsNumber.h>
#include <Columns/ColumnArray.h>
#include <Common/typeid_cast.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionFactory.h>
#include <Functions/IFunction.h>
#include <set>
@ -44,8 +45,8 @@ Names ExpressionAction::getNeededColumns() const
res.insert(res.end(), array_joined_columns.begin(), array_joined_columns.end());
if (join_params)
res.insert(res.end(), join_params->keyNamesLeft().begin(), join_params->keyNamesLeft().end());
if (table_join)
res.insert(res.end(), table_join->keyNamesLeft().begin(), table_join->keyNamesLeft().end());
for (const auto & column : projection)
res.push_back(column.first);
@ -159,12 +160,11 @@ ExpressionAction ExpressionAction::arrayJoin(const NameSet & array_joined_column
return a;
}
ExpressionAction ExpressionAction::ordinaryJoin(std::shared_ptr<AnalyzedJoin> join_params, std::shared_ptr<const Join> hash_join)
ExpressionAction ExpressionAction::ordinaryJoin(std::shared_ptr<AnalyzedJoin> table_join)
{
ExpressionAction a;
a.type = JOIN;
a.join_params = join_params;
a.join = hash_join;
a.table_join = table_join;
return a;
}
@ -269,7 +269,7 @@ void ExpressionAction::prepare(Block & sample_block, const Settings & settings,
case JOIN:
{
join_params->addJoinedColumnsAndCorrectNullability(sample_block);
table_join->addJoinedColumnsAndCorrectNullability(sample_block);
break;
}
@ -475,7 +475,7 @@ void ExpressionAction::execute(Block & block, bool dry_run) const
case JOIN:
{
join->joinBlock(block, *join_params);
table_join->joinBlock(block);
break;
}
@ -543,7 +543,7 @@ void ExpressionAction::executeOnTotals(Block & block) const
if (type != JOIN)
execute(block, false);
else
join->joinTotals(block);
table_join->joinTotals(block);
}
@ -593,10 +593,10 @@ std::string ExpressionAction::toString() const
case JOIN:
ss << "JOIN ";
for (NamesAndTypesList::const_iterator it = join_params->columnsAddedByJoin().begin();
it != join_params->columnsAddedByJoin().end(); ++it)
for (NamesAndTypesList::const_iterator it = table_join->columnsAddedByJoin().begin();
it != table_join->columnsAddedByJoin().end(); ++it)
{
if (it != join_params->columnsAddedByJoin().begin())
if (it != table_join->columnsAddedByJoin().begin())
ss << ", ";
ss << it->name;
}
@ -762,17 +762,10 @@ void ExpressionActions::execute(Block & block, bool dry_run) const
bool ExpressionActions::hasTotalsInJoin() const
{
bool has_totals_in_join = false;
for (const auto & action : actions)
{
if (action.join && action.join->hasTotals())
{
has_totals_in_join = true;
break;
}
}
return has_totals_in_join;
if (action.table_join && action.table_join->hasTotals())
return true;
return false;
}
void ExpressionActions::executeOnTotals(Block & block) const
@ -1164,13 +1157,11 @@ void ExpressionActions::optimizeArrayJoin()
}
BlockInputStreamPtr ExpressionActions::createStreamWithNonJoinedDataIfFullOrRightJoin(const Block & source_header, UInt64 max_block_size) const
std::shared_ptr<const AnalyzedJoin> ExpressionActions::getTableJoin() const
{
for (const auto & action : actions)
if (action.join && isRightOrFull(action.join->getKind()))
return action.join->createStreamWithNonJoinedRows(
source_header, *action.join_params, max_block_size);
if (action.table_join)
return action.table_join;
return {};
}
@ -1216,7 +1207,7 @@ UInt128 ExpressionAction::ActionHash::operator()(const ExpressionAction & action
hash.update(col);
break;
case JOIN:
for (const auto & col : action.join_params->columnsAddedByJoin())
for (const auto & col : action.table_join->columnsAddedByJoin())
hash.update(col.name);
break;
case PROJECT:
@ -1274,8 +1265,7 @@ bool ExpressionAction::operator==(const ExpressionAction & other) const
&& argument_names == other.argument_names
&& array_joined_columns == other.array_joined_columns
&& array_join_is_left == other.array_join_is_left
&& join == other.join
&& AnalyzedJoin::sameJoin(join_params.get(), other.join_params.get())
&& AnalyzedJoin::sameJoin(table_join.get(), other.table_join.get())
&& projection == other.projection
&& is_function_compiled == other.is_function_compiled;
}

View File

@ -4,9 +4,7 @@
#include <Core/ColumnWithTypeAndName.h>
#include <Core/Names.h>
#include <Core/Settings.h>
#include <DataStreams/IBlockStream_fwd.h>
#include <Interpreters/Context.h>
#include <Interpreters/AnalyzedJoin.h>
#include <Common/SipHash.h>
#include "config_core.h"
#include <unordered_map>
@ -25,7 +23,7 @@ namespace ErrorCodes
using NameWithAlias = std::pair<std::string, std::string>;
using NamesWithAliases = std::vector<NameWithAlias>;
class Join;
class AnalyzedJoin;
class IPreparedFunction;
using PreparedFunctionPtr = std::shared_ptr<IPreparedFunction>;
@ -105,8 +103,7 @@ public:
bool unaligned_array_join = false;
/// For JOIN
std::shared_ptr<AnalyzedJoin> join_params = nullptr;
std::shared_ptr<const Join> join;
std::shared_ptr<const AnalyzedJoin> table_join;
/// For PROJECT.
NamesWithAliases projection;
@ -122,7 +119,7 @@ public:
static ExpressionAction project(const Names & projected_columns_);
static ExpressionAction addAliases(const NamesWithAliases & aliased_columns_);
static ExpressionAction arrayJoin(const NameSet & array_joined_columns, bool array_join_is_left, const Context & context);
static ExpressionAction ordinaryJoin(std::shared_ptr<AnalyzedJoin> join_params, std::shared_ptr<const Join> hash_join);
static ExpressionAction ordinaryJoin(std::shared_ptr<AnalyzedJoin> join);
/// Which columns necessary to perform this action.
Names getNeededColumns() const;
@ -238,7 +235,7 @@ public:
static std::string getSmallestColumn(const NamesAndTypesList & columns);
BlockInputStreamPtr createStreamWithNonJoinedDataIfFullOrRightJoin(const Block & source_header, UInt64 max_block_size) const;
std::shared_ptr<const AnalyzedJoin> getTableJoin() const;
const Settings & getSettings() const { return settings; }

View File

@ -406,10 +406,9 @@ bool SelectQueryExpressionAnalyzer::appendArrayJoin(ExpressionActionsChain & cha
return true;
}
/// It's possible to set nullptr as join for only_types mode
void ExpressionAnalyzer::addJoinAction(ExpressionActionsPtr & actions, JoinPtr join) const
void ExpressionAnalyzer::addJoinAction(ExpressionActionsPtr & actions) const
{
actions->add(ExpressionAction::ordinaryJoin(syntax->analyzed_join, join));
actions->add(ExpressionAction::ordinaryJoin(syntax->analyzed_join));
}
bool SelectQueryExpressionAnalyzer::appendJoin(ExpressionActionsChain & chain, bool only_types)
@ -419,13 +418,13 @@ bool SelectQueryExpressionAnalyzer::appendJoin(ExpressionActionsChain & chain, b
return false;
SubqueryForSet & subquery_for_set = getSubqueryForJoin(*ast_join);
syntax->analyzed_join->setHashJoin(subquery_for_set.join);
initChain(chain, sourceColumns());
ExpressionActionsChain::Step & step = chain.steps.back();
getRootActions(analyzedJoin().leftKeysList(), only_types, step.actions);
addJoinAction(step.actions, subquery_for_set.join);
addJoinAction(step.actions);
return true;
}

View File

@ -130,7 +130,7 @@ protected:
void addMultipleArrayJoinAction(ExpressionActionsPtr & actions, bool is_left) const;
void addJoinAction(ExpressionActionsPtr & actions, JoinPtr join = {}) const;
void addJoinAction(ExpressionActionsPtr & actions) const;
void getRootActions(const ASTPtr & ast, bool no_subqueries, ExpressionActionsPtr & actions, bool only_consts = false);

View File

@ -1037,20 +1037,18 @@ void InterpreterSelectQuery::executeImpl(TPipeline & pipeline, const BlockInputS
stream = std::make_shared<ExpressionBlockInputStream>(stream, expressions.before_join);
}
const auto & join = query.join()->table_join->as<ASTTableJoin &>();
if (isRightOrFull(join.kind))
if (auto join = expressions.before_join->getTableJoin())
{
auto stream = expressions.before_join->createStreamWithNonJoinedDataIfFullOrRightJoin(
header_before_join, settings.max_block_size);
if constexpr (pipeline_with_processors)
if (auto stream = join->createStreamWithNonJoinedDataIfFullOrRightJoin(header_before_join, settings.max_block_size))
{
auto source = std::make_shared<SourceFromInputStream>(std::move(stream));
pipeline.addDelayedStream(source);
if constexpr (pipeline_with_processors)
{
auto source = std::make_shared<SourceFromInputStream>(std::move(stream));
pipeline.addDelayedStream(source);
}
else
pipeline.stream_with_non_joined_data = std::move(stream);
}
else
pipeline.stream_with_non_joined_data = std::move(stream);
}
}

View File

@ -26,7 +26,7 @@
namespace DB
{
struct AnalyzedJoin;
class AnalyzedJoin;
namespace JoinStuff
{