add JoinSwitcher

This commit is contained in:
chertus 2020-02-11 21:27:52 +03:00
parent 840d4c5f4a
commit d2d4118730
4 changed files with 81 additions and 26 deletions

View File

@ -1,6 +1,4 @@
#include <Interpreters/AnalyzedJoin.h>
#include <Interpreters/Join.h>
#include <Interpreters/MergeJoin.h>
#include <Parsers/ASTExpressionList.h>
@ -229,27 +227,14 @@ bool AnalyzedJoin::sameStrictnessAndKind(ASTTableJoin::Strictness strictness_, A
return false;
}
JoinPtr makeJoin(std::shared_ptr<AnalyzedJoin> table_join, const Block & right_sample_block)
bool AnalyzedJoin::allowMergeJoin() const
{
auto kind = table_join->kind();
auto strictness = table_join->strictness();
bool is_any = (strictness() == ASTTableJoin::Strictness::Any);
bool is_all = (strictness() == ASTTableJoin::Strictness::All);
bool is_semi = (strictness() == ASTTableJoin::Strictness::Semi);
bool is_any = (strictness == ASTTableJoin::Strictness::Any);
bool is_all = (strictness == ASTTableJoin::Strictness::All);
bool is_semi = (strictness == ASTTableJoin::Strictness::Semi);
bool allow_merge_join = (isLeft(kind) && (is_any || is_all || is_semi)) || (isInner(kind) && is_all);
if (table_join->partial_merge_join && allow_merge_join)
return std::make_shared<MergeJoin>(table_join, right_sample_block);
return std::make_shared<Join>(table_join, right_sample_block);
}
bool isMergeJoin(const JoinPtr & join)
{
if (join)
return typeid_cast<const MergeJoin *>(join.get());
return false;
bool allow_merge_join = (isLeft(kind()) && (is_any || is_all || is_semi)) || (isInner(kind()) && is_all);
return allow_merge_join && partial_merge_join;
}
}

View File

@ -87,6 +87,7 @@ public:
bool sameStrictnessAndKind(ASTTableJoin::Strictness, ASTTableJoin::Kind) const;
const SizeLimits & sizeLimits() const { return size_limits; }
VolumePtr getTemporaryVolume() { return tmp_volume; }
bool allowMergeJoin() const;
bool forceNullableRight() const { return join_use_nulls && isLeftOrFull(table_join.kind); }
bool forceNullableLeft() const { return join_use_nulls && isRightOrFull(table_join.kind); }
@ -128,9 +129,6 @@ public:
void setRightKeys(const Names & keys) { key_names_right = keys; }
static bool sameJoin(const AnalyzedJoin * x, const AnalyzedJoin * y);
friend JoinPtr makeJoin(std::shared_ptr<AnalyzedJoin> table_join, const Block & right_sample_block);
};
bool isMergeJoin(const JoinPtr &);
}

View File

@ -29,7 +29,7 @@
#include <Interpreters/ExternalDictionariesLoader.h>
#include <Interpreters/Set.h>
#include <Interpreters/AnalyzedJoin.h>
#include <Interpreters/Join.h>
#include <Interpreters/JoinSwitcher.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/parseAggregateFunctionParameters.h>
@ -564,7 +564,7 @@ JoinPtr SelectQueryExpressionAnalyzer::makeTableJoin(const ASTTablesInSelectQuer
/// TODO You do not need to set this up when JOIN is only needed on remote servers.
subquery_for_join.setJoinActions(joined_block_actions); /// changes subquery_for_join.sample_block inside
subquery_for_join.join = makeJoin(syntax->analyzed_join, subquery_for_join.sample_block);
subquery_for_join.join = std::make_shared<JoinSwitcher>(syntax->analyzed_join, subquery_for_join.sample_block);
}
return subquery_for_join.join;

View File

@ -0,0 +1,72 @@
#pragma once
#include <Interpreters/IJoin.h>
#include <Interpreters/Join.h>
#include <Interpreters/MergeJoin.h>
#include <Interpreters/AnalyzedJoin.h>
namespace DB
{
class JoinSwitcher : public IJoin
{
public:
JoinSwitcher(std::shared_ptr<AnalyzedJoin> table_join, const Block & right_sample_block)
{
if (table_join->allowMergeJoin())
join = std::make_shared<MergeJoin>(table_join, right_sample_block);
else
join = std::make_shared<Join>(table_join, right_sample_block);
}
bool addJoinedBlock(const Block & block) override
{
/// TODO: switch Join -> MergeJoin
return join->addJoinedBlock(block);
}
void joinBlock(Block & block, std::shared_ptr<ExtraBlock> & not_processed) override
{
join->joinBlock(block, not_processed);
}
bool hasTotals() const override
{
return join->hasTotals();
}
void setTotals(const Block & block) override
{
join->setTotals(block);
}
void joinTotals(Block & block) const override
{
join->joinTotals(block);
}
size_t getTotalRowCount() const override
{
return join->getTotalRowCount();
}
bool alwaysReturnsEmptySet() const override
{
return join->alwaysReturnsEmptySet();
}
BlockInputStreamPtr createStreamWithNonJoinedRows(const Block & block, UInt64 max_block_size) const override
{
return join->createStreamWithNonJoinedRows(block, max_block_size);
}
bool hasStreamWithNonJoinedRows() const override
{
return join->hasStreamWithNonJoinedRows();
}
private:
JoinPtr join;
};
}