Merge pull request #9082 from 4ertus2/joins

Switch JOIN algo on the fly
This commit is contained in:
alexey-milovidov 2020-02-21 02:01:23 +03:00 committed by GitHub
commit 219f94ca97
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 264 additions and 49 deletions

View File

@ -390,7 +390,7 @@ namespace ErrorCodes
extern const int ALL_REPLICAS_LOST = 415;
extern const int REPLICA_STATUS_CHANGED = 416;
extern const int EXPECTED_ALL_OR_ANY = 417;
extern const int UNKNOWN_JOIN_STRICTNESS = 418;
extern const int UNKNOWN_JOIN = 418;
extern const int MULTIPLE_ASSIGNMENTS_TO_COLUMN = 419;
extern const int CANNOT_UPDATE_COLUMN = 420;
extern const int CANNOT_ADD_DIFFERENT_AGGREGATE_STATES = 421;

View File

@ -316,9 +316,10 @@ struct Settings : public SettingsCollection<Settings>
M(SettingUInt64, max_bytes_in_join, 0, "Maximum size of the hash table for JOIN (in number of bytes in memory).", 0) \
M(SettingOverflowMode, join_overflow_mode, OverflowMode::THROW, "What to do when the limit is exceeded.", 0) \
M(SettingBool, join_any_take_last_row, false, "When disabled (default) ANY JOIN will take the first found row for a key. When enabled, it will take the last row seen if there are multiple rows for the same key.", IMPORTANT) \
M(SettingBool, partial_merge_join, false, "Use partial merge join instead of hash join for joins (ANY|ALL|SEMI LEFT and ALL INNER are supported for now).", 0) \
M(SettingBool, partial_merge_join, false, "Obsolete. Use join_algorithm='prefer_partial_merge' instead.", 0) \
M(SettingJoinAlgorithm, join_algorithm, JoinAlgorithm::HASH, "Specify join algorithm: 'auto', 'hash', 'partial_merge', 'prefer_partial_merge'. 'auto' tries to change HashJoin to MergeJoin on the fly to avoid out of memory.", 0) \
M(SettingBool, partial_merge_join_optimizations, false, "Enable optimizations in partial merge join", 0) \
M(SettingUInt64, default_max_bytes_in_join, 100000000, "Maximum size of right-side table if limit is required but max_bytes_in_join is not set.", 0) \
M(SettingUInt64, default_max_bytes_in_join, 1000000000, "Maximum size of right-side table if limit is required but max_bytes_in_join is not set.", 0) \
M(SettingUInt64, partial_merge_join_rows_in_right_blocks, 10000, "Split right-hand joining data in blocks of specified size. It's a portion of data indexed by min-max values and possibly unloaded on disk.", 0) \
\
M(SettingUInt64, max_rows_to_transfer, 0, "Maximum size (in rows) of the transmitted external table obtained when the GLOBAL IN/JOIN section is executed.", 0) \

View File

@ -22,7 +22,7 @@ namespace ErrorCodes
extern const int UNKNOWN_COMPRESSION_METHOD;
extern const int UNKNOWN_DISTRIBUTED_PRODUCT_MODE;
extern const int UNKNOWN_GLOBAL_SUBQUERIES_METHOD;
extern const int UNKNOWN_JOIN_STRICTNESS;
extern const int UNKNOWN_JOIN;
extern const int UNKNOWN_LOG_LEVEL;
extern const int SIZE_OF_FIXED_STRING_DOESNT_MATCH;
extern const int BAD_ARGUMENTS;
@ -495,8 +495,14 @@ IMPLEMENT_SETTING_ENUM(LoadBalancing, LOAD_BALANCING_LIST_OF_NAMES, ErrorCodes::
M(Unspecified, "") \
M(ALL, "ALL") \
M(ANY, "ANY")
IMPLEMENT_SETTING_ENUM(JoinStrictness, JOIN_STRICTNESS_LIST_OF_NAMES, ErrorCodes::UNKNOWN_JOIN_STRICTNESS)
IMPLEMENT_SETTING_ENUM(JoinStrictness, JOIN_STRICTNESS_LIST_OF_NAMES, ErrorCodes::UNKNOWN_JOIN)
#define JOIN_ALGORITHM_NAMES(M) \
M(AUTO, "auto") \
M(HASH, "hash") \
M(PARTIAL_MERGE, "partial_merge") \
M(PREFER_PARTIAL_MERGE, "prefer_partial_merge")
IMPLEMENT_SETTING_ENUM(JoinAlgorithm, JOIN_ALGORITHM_NAMES, ErrorCodes::UNKNOWN_JOIN)
#define TOTALS_MODE_LIST_OF_NAMES(M) \
M(BEFORE_HAVING, "before_having") \

View File

@ -242,6 +242,14 @@ enum class JoinStrictness
};
using SettingJoinStrictness = SettingEnum<JoinStrictness>;
/// Algorithm used to execute JOIN (value of the 'join_algorithm' setting).
enum class JoinAlgorithm
{
    /// 'auto': start with hash join and try to change it to merge join on the fly to avoid out of memory.
    AUTO = 0,
    /// 'hash': always use hash join.
    HASH = 1,
    /// 'partial_merge': always use partial merge join.
    PARTIAL_MERGE = 2,
    /// 'prefer_partial_merge': use partial merge join when it is allowed for the JOIN, hash join otherwise.
    PREFER_PARTIAL_MERGE = 3,
};
using SettingJoinAlgorithm = SettingEnum<JoinAlgorithm>;
/// Which rows should be included in TOTALS.
enum class TotalsMode

View File

@ -1,6 +1,4 @@
#include <Interpreters/AnalyzedJoin.h>
#include <Interpreters/Join.h>
#include <Interpreters/MergeJoin.h>
#include <Parsers/ASTExpressionList.h>
@ -24,11 +22,14 @@ AnalyzedJoin::AnalyzedJoin(const Settings & settings, VolumePtr tmp_volume_)
, default_max_bytes(settings.default_max_bytes_in_join)
, join_use_nulls(settings.join_use_nulls)
, max_joined_block_rows(settings.max_joined_block_size_rows)
, partial_merge_join(settings.partial_merge_join)
, join_algorithm(settings.join_algorithm)
, partial_merge_join_optimizations(settings.partial_merge_join_optimizations)
, partial_merge_join_rows_in_right_blocks(settings.partial_merge_join_rows_in_right_blocks)
, tmp_volume(tmp_volume_)
{}
{
if (settings.partial_merge_join)
join_algorithm = JoinAlgorithm::PREFER_PARTIAL_MERGE;
}
void AnalyzedJoin::addUsingKey(const ASTPtr & ast)
{
@ -229,27 +230,14 @@ bool AnalyzedJoin::sameStrictnessAndKind(ASTTableJoin::Strictness strictness_, A
return false;
}
JoinPtr makeJoin(std::shared_ptr<AnalyzedJoin> table_join, const Block & right_sample_block)
bool AnalyzedJoin::allowMergeJoin() const
{
auto kind = table_join->kind();
auto strictness = table_join->strictness();
bool is_any = (strictness() == ASTTableJoin::Strictness::Any);
bool is_all = (strictness() == ASTTableJoin::Strictness::All);
bool is_semi = (strictness() == ASTTableJoin::Strictness::Semi);
bool is_any = (strictness == ASTTableJoin::Strictness::Any);
bool is_all = (strictness == ASTTableJoin::Strictness::All);
bool is_semi = (strictness == ASTTableJoin::Strictness::Semi);
bool allow_merge_join = (isLeft(kind) && (is_any || is_all || is_semi)) || (isInner(kind) && is_all);
if (table_join->partial_merge_join && allow_merge_join)
return std::make_shared<MergeJoin>(table_join, right_sample_block);
return std::make_shared<Join>(table_join, right_sample_block);
}
bool isMergeJoin(const JoinPtr & join)
{
if (join)
return typeid_cast<const MergeJoin *>(join.get());
return false;
bool allow_merge_join = (isLeft(kind()) && (is_any || is_all || is_semi)) || (isInner(kind()) && is_all);
return allow_merge_join;
}
}

View File

@ -2,6 +2,7 @@
#include <Core/Names.h>
#include <Core/NamesAndTypes.h>
#include <Core/SettingsCollection.h>
#include <Parsers/ASTTablesInSelectQuery.h>
#include <Interpreters/IJoin.h>
#include <Interpreters/asof.h>
@ -44,7 +45,7 @@ class AnalyzedJoin
const size_t default_max_bytes;
const bool join_use_nulls;
const size_t max_joined_block_rows = 0;
const bool partial_merge_join = false;
JoinAlgorithm join_algorithm;
const bool partial_merge_join_optimizations = false;
const size_t partial_merge_join_rows_in_right_blocks = 0;
@ -76,6 +77,7 @@ public:
: size_limits(limits)
, default_max_bytes(0)
, join_use_nulls(use_nulls)
, join_algorithm(JoinAlgorithm::HASH)
, key_names_right(key_names_right_)
{
table_join.kind = kind;
@ -87,6 +89,10 @@ public:
bool sameStrictnessAndKind(ASTTableJoin::Strictness, ASTTableJoin::Kind) const;
const SizeLimits & sizeLimits() const { return size_limits; }
VolumePtr getTemporaryVolume() { return tmp_volume; }
bool allowMergeJoin() const;
bool preferMergeJoin() const { return join_algorithm == JoinAlgorithm::PREFER_PARTIAL_MERGE; }
bool forceMergeJoin() const { return join_algorithm == JoinAlgorithm::PARTIAL_MERGE; }
bool forceHashJoin() const { return join_algorithm == JoinAlgorithm::HASH; }
bool forceNullableRight() const { return join_use_nulls && isLeftOrFull(table_join.kind); }
bool forceNullableLeft() const { return join_use_nulls && isRightOrFull(table_join.kind); }
@ -128,9 +134,6 @@ public:
void setRightKeys(const Names & keys) { key_names_right = keys; }
static bool sameJoin(const AnalyzedJoin * x, const AnalyzedJoin * y);
friend JoinPtr makeJoin(std::shared_ptr<AnalyzedJoin> table_join, const Block & right_sample_block);
};
bool isMergeJoin(const JoinPtr &);
}

View File

@ -29,7 +29,9 @@
#include <Interpreters/ExternalDictionariesLoader.h>
#include <Interpreters/Set.h>
#include <Interpreters/AnalyzedJoin.h>
#include <Interpreters/JoinSwitcher.h>
#include <Interpreters/Join.h>
#include <Interpreters/MergeJoin.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/parseAggregateFunctionParameters.h>
@ -538,6 +540,17 @@ static ExpressionActionsPtr createJoinedBlockActions(const Context & context, co
return ExpressionAnalyzer(expression_list, syntax_result, context).getActions(true, false);
}
/// Create the JOIN implementation selected by the join algorithm stored in analyzed_join:
///  - hash join is forced, or merge join is preferred but not allowed for this JOIN -> Join;
///  - merge join is forced, or merge join is preferred and allowed                  -> MergeJoin;
///  - anything else                                                                 -> JoinSwitcher.
static std::shared_ptr<IJoin> makeJoin(std::shared_ptr<AnalyzedJoin> analyzed_join, const Block & sample_block)
{
    const bool merge_join_allowed = analyzed_join->allowMergeJoin();

    if (analyzed_join->forceHashJoin())
        return std::make_shared<Join>(analyzed_join, sample_block);

    if (analyzed_join->forceMergeJoin())
        return std::make_shared<MergeJoin>(analyzed_join, sample_block);

    if (analyzed_join->preferMergeJoin())
    {
        /// Preference only: fall back to hash join when merge join can't handle this JOIN.
        if (merge_join_allowed)
            return std::make_shared<MergeJoin>(analyzed_join, sample_block);
        return std::make_shared<Join>(analyzed_join, sample_block);
    }

    return std::make_shared<JoinSwitcher>(analyzed_join, sample_block);
}
JoinPtr SelectQueryExpressionAnalyzer::makeTableJoin(const ASTTablesInSelectQueryElement & join_element)
{
/// Two JOINs are not supported with the same subquery, but different USINGs.

View File

@ -20,7 +20,7 @@ public:
/// Add block of data from right hand of JOIN.
/// @returns false, if some limit was exceeded and you should not insert more data.
virtual bool addJoinedBlock(const Block & block) = 0;
virtual bool addJoinedBlock(const Block & block, bool check_limits = true) = 0;
/// Join the block with data from left hand of JOIN to the right hand data (that was previously built by calls to addJoinedBlock).
/// Could be called from different threads in parallel.
@ -31,6 +31,7 @@ public:
virtual void joinTotals(Block & block) const = 0;
virtual size_t getTotalRowCount() const = 0;
virtual size_t getTotalByteCount() const = 0;
virtual bool alwaysReturnsEmptySet() const { return false; }
virtual BlockInputStreamPtr createStreamWithNonJoinedRows(const Block &, UInt64) const { return {}; }

View File

@ -50,6 +50,7 @@
#include <Interpreters/JoinToSubqueryTransformVisitor.h>
#include <Interpreters/CrossToInnerJoinVisitor.h>
#include <Interpreters/AnalyzedJoin.h>
#include <Interpreters/Join.h>
#include <Storages/MergeTree/MergeTreeData.h>
#include <Storages/MergeTree/MergeTreeWhereOptimizer.h>
@ -862,6 +863,8 @@ void InterpreterSelectQuery::executeImpl(TPipeline & pipeline, const BlockInputS
if (expressions.hasJoin())
{
Block header_before_join;
JoinPtr join = expressions.before_join->getTableJoinAlgo();
bool inflating_join = join && !typeid_cast<Join *>(join.get());
if constexpr (pipeline_with_processors)
{
@ -879,10 +882,11 @@ void InterpreterSelectQuery::executeImpl(TPipeline & pipeline, const BlockInputS
{
bool on_totals = type == QueryPipeline::StreamType::Totals;
std::shared_ptr<IProcessor> ret;
if (settings.partial_merge_join)
if (inflating_join)
ret = std::make_shared<InflatingExpressionTransform>(header, expressions.before_join, on_totals, default_totals);
else
ret = std::make_shared<ExpressionTransform>(header, expressions.before_join, on_totals, default_totals);
return ret;
});
}
@ -894,7 +898,7 @@ void InterpreterSelectQuery::executeImpl(TPipeline & pipeline, const BlockInputS
stream = std::make_shared<InflatingExpressionBlockInputStream>(stream, expressions.before_join);
}
if (JoinPtr join = expressions.before_join->getTableJoinAlgo())
if (join)
{
Block join_result_sample = ExpressionBlockInputStream(
std::make_shared<OneBlockInputStream>(header_before_join), expressions.before_join).getHeader();

View File

@ -316,7 +316,7 @@ void Join::setSampleBlock(const Block & block)
ColumnRawPtrs key_columns = JoinCommon::extractKeysForJoin(key_names_right, block, right_table_keys, sample_block_with_columns_to_add);
initRightBlockStructure();
initRightBlockStructure(data->sample_block);
initRequiredRightKeys();
JoinCommon::createMissedColumns(sample_block_with_columns_to_add);
@ -481,13 +481,12 @@ void Join::initRequiredRightKeys()
}
}
void Join::initRightBlockStructure()
void Join::initRightBlockStructure(Block & saved_block_sample)
{
auto & saved_block_sample = data->sample_block;
if (isRightOrFull(kind))
/// We could remove key columns for LEFT | INNER HashJoin but we should keep them for JoinSwitcher (if any).
bool save_key_columns = !table_join->forceHashJoin() || isRightOrFull(kind);
if (save_key_columns)
{
/// Save keys for NonJoinedBlockInputStream
saved_block_sample = right_table_keys.cloneEmpty();
}
else if (strictness == ASTTableJoin::Strictness::Asof)
@ -518,7 +517,7 @@ Block Join::structureRightBlock(const Block & block) const
return structured_block;
}
bool Join::addJoinedBlock(const Block & source_block)
bool Join::addJoinedBlock(const Block & source_block, bool check_limits)
{
if (empty())
throw Exception("Logical error: Join was not initialized", ErrorCodes::LOGICAL_ERROR);
@ -565,6 +564,9 @@ bool Join::addJoinedBlock(const Block & source_block)
if (save_nullmap)
data->blocks_nullmaps.emplace_back(stored_block, null_map_holder);
if (!check_limits)
return true;
/// TODO: Do not calculate them every time
total_rows = getTotalRowCount();
total_bytes = getTotalByteCount();

View File

@ -153,7 +153,7 @@ public:
/** Add block of data from right hand of JOIN to the map.
* Returns false, if some limit was exceeded and you should not insert more data.
*/
bool addJoinedBlock(const Block & block) override;
bool addJoinedBlock(const Block & block, bool check_limits = true) override;
/** Join data from the map (that was previously built by calls to addJoinedBlock) to the block with data from "left" table.
* Could be called from different threads in parallel.
@ -184,7 +184,7 @@ public:
/// Number of keys in all built JOIN maps.
size_t getTotalRowCount() const final;
/// Sum size in bytes of all buffers, used for JOIN maps and for all memory pools.
size_t getTotalByteCount() const;
size_t getTotalByteCount() const final;
bool alwaysReturnsEmptySet() const final { return isInnerOrRight(getKind()) && data->empty; }
@ -320,6 +320,11 @@ public:
data = join.data;
}
std::shared_ptr<RightTableData> getJoinedData() const
{
return data;
}
private:
friend class NonJoinedBlockInputStream;
friend class JoinBlockInputStream;
@ -364,7 +369,7 @@ private:
/// Modify (structure) right block to save it in block list
Block structureRightBlock(const Block & stored_block) const;
void initRightBlockStructure();
void initRightBlockStructure(Block & saved_block_sample);
void initRequiredRightKeys();
template <ASTTableJoin::Kind KIND, ASTTableJoin::Strictness STRICTNESS, typename Maps>

View File

@ -0,0 +1,87 @@
#include <Common/typeid_cast.h>
#include <Interpreters/JoinSwitcher.h>
#include <Interpreters/Join.h>
#include <Interpreters/MergeJoin.h>
#include <Interpreters/join_common.h>
namespace DB
{
/// Make the column Nullable or strip its nullability according to the flag and pass the column on to the caller.
static ColumnWithTypeAndName correctNullability(ColumnWithTypeAndName && column, bool nullable)
{
    if (!nullable)
        JoinCommon::removeColumnNullability(column);
    else
        JoinCommon::convertColumnToNullable(column);

    /// 'column' is a named rvalue reference, so an explicit move is needed to avoid a copy.
    return std::move(column);
}
/// Starts with a Join built over an empty clone of the right sample block.
/// Size limits are taken from table_join_; if none are set, the default max bytes of table_join_ is used as the byte limit.
JoinSwitcher::JoinSwitcher(std::shared_ptr<AnalyzedJoin> table_join_, const Block & right_sample_block_)
    : limits(table_join_->sizeLimits())
    , switched(false)
    , table_join(table_join_)
    , right_sample_block(right_sample_block_.cloneEmpty())
{
    join = std::make_shared<Join>(table_join, right_sample_block);

    const bool limits_are_set = limits.hasLimits();
    if (!limits_are_set)
        limits.max_bytes = table_join->defaultMaxBytes();
}
/// Add a right-hand block to the current join under switch_mutex. The second argument is ignored.
/// Before the switch the block is added without the join's own limit check; limits are checked here instead,
/// and exceeding them triggers switchJoin(). In that mode the function always returns true.
/// After the switch the result of the current join's addJoinedBlock() is returned.
bool JoinSwitcher::addJoinedBlock(const Block & block, bool)
{
    std::lock_guard lock(switch_mutex);

    if (!switched)
    {
        /// HashJoin with external limits check
        join->addJoinedBlock(block, false);

        const size_t total_rows = join->getTotalRowCount();
        const size_t total_bytes = join->getTotalByteCount();

        const bool fits_into_limits = limits.softCheck(total_rows, total_bytes);
        if (!fits_into_limits)
            switchJoin();

        return true;
    }

    return join->addJoinedBlock(block);
}
void JoinSwitcher::switchJoin()
{
std::shared_ptr<Join::RightTableData> joined_data = static_cast<const Join &>(*join).getJoinedData();
BlocksList right_blocks = std::move(joined_data->blocks);
/// Destroy old join & create new one. Early destroy for memory saving.
join = std::make_shared<MergeJoin>(table_join, right_sample_block);
/// names to positions optimization
std::vector<size_t> positions;
std::vector<bool> is_nullable;
if (right_blocks.size())
{
positions.reserve(right_sample_block.columns());
const Block & tmp_block = *right_blocks.begin();
for (const auto & sample_column : right_sample_block)
{
positions.emplace_back(tmp_block.getPositionByName(sample_column.name));
is_nullable.emplace_back(sample_column.type->isNullable());
}
}
for (Block & saved_block : right_blocks)
{
Block restored_block;
for (size_t i = 0; i < positions.size(); ++i)
{
auto & column = saved_block.getByPosition(positions[i]);
restored_block.insert(correctNullability(std::move(column), is_nullable[i]));
}
join->addJoinedBlock(restored_block);
}
switched = true;
}
}

View File

@ -0,0 +1,83 @@
#pragma once
#include <mutex>
#include <Core/Block.h>
#include <Interpreters/IJoin.h>
#include <Interpreters/AnalyzedJoin.h>
namespace DB
{
/// Used when the 'join_algorithm' setting is set to JoinAlgorithm::AUTO.
/// Starts JOIN with the join-in-memory algorithm and switches to join-on-disk on the fly if there's no memory to place the right table.
/// Currently join-in-memory and join-on-disk are JoinAlgorithm::HASH and JoinAlgorithm::PARTIAL_MERGE joins respectively.
class JoinSwitcher : public IJoin
{
public:
    JoinSwitcher(std::shared_ptr<AnalyzedJoin> table_join_, const Block & right_sample_block_);

    /// Add a block of data from the right hand of JOIN into the current join object.
    /// If the join-in-memory limit is exceeded, switches to join-on-disk and continues with it.
    /// @returns false, if the join-on-disk limit is exceeded.
    bool addJoinedBlock(const Block & block, bool check_limits = true) override;

    /// Everything below is forwarded to the current join object as is.

    void joinBlock(Block & block, std::shared_ptr<ExtraBlock> & not_processed) override { join->joinBlock(block, not_processed); }

    bool hasTotals() const override { return join->hasTotals(); }
    void setTotals(const Block & block) override { join->setTotals(block); }
    void joinTotals(Block & block) const override { join->joinTotals(block); }

    size_t getTotalRowCount() const override { return join->getTotalRowCount(); }
    size_t getTotalByteCount() const override { return join->getTotalByteCount(); }

    bool alwaysReturnsEmptySet() const override { return join->alwaysReturnsEmptySet(); }

    BlockInputStreamPtr createStreamWithNonJoinedRows(const Block & block, UInt64 max_block_size) const override
    {
        return join->createStreamWithNonJoinedRows(block, max_block_size);
    }

    bool hasStreamWithNonJoinedRows() const override { return join->hasStreamWithNonJoinedRows(); }

private:
    /// Current join object: replaced by switchJoin().
    JoinPtr join;
    /// Limits checked in addJoinedBlock() before the switch.
    SizeLimits limits;
    bool switched;
    mutable std::mutex switch_mutex;
    std::shared_ptr<AnalyzedJoin> table_join;
    const Block right_sample_block;

    /// Change join-in-memory to join-on-disk, moving the right hand JOIN data from one to another.
    /// Throws an error if join-on-disk does not support the JOIN kind or strictness.
    void switchJoin();
};
}

View File

@ -585,7 +585,7 @@ bool MergeJoin::saveRightBlock(Block && block)
return true;
}
bool MergeJoin::addJoinedBlock(const Block & src_block)
bool MergeJoin::addJoinedBlock(const Block & src_block, bool)
{
Block block = materializeBlock(src_block);
JoinCommon::removeLowCardinalityInplace(block);

View File

@ -50,12 +50,13 @@ class MergeJoin : public IJoin
public:
MergeJoin(std::shared_ptr<AnalyzedJoin> table_join_, const Block & right_sample_block);
bool addJoinedBlock(const Block & block) override;
bool addJoinedBlock(const Block & block, bool check_limits = true) override;
void joinBlock(Block &, ExtraBlockPtr & not_processed) override;
void joinTotals(Block &) const override;
void setTotals(const Block &) override;
bool hasTotals() const override { return totals; }
size_t getTotalRowCount() const override { return right_blocks_row_count; }
size_t getTotalByteCount() const override { return right_blocks_bytes; }
private:
struct NotProcessed : public ExtraBlock

View File

@ -10,3 +10,7 @@
1 0
2 11
3 0
0 10
1 0
2 11
3 0

View File

@ -1,4 +1,4 @@
SET partial_merge_join = 0;
SET join_algorithm = 'hash';
SELECT number as n, j FROM numbers(4)
ANY LEFT JOIN (
@ -16,7 +16,7 @@ ANY LEFT JOIN (
) js2
USING n; -- { serverError 191 }
SET partial_merge_join = 1;
SET join_algorithm = 'partial_merge';
SELECT number as n, j FROM numbers(4)
ANY LEFT JOIN (
@ -33,3 +33,12 @@ ANY LEFT JOIN (
FROM numbers(4000)
) js2
USING n;
SET join_algorithm = 'auto';
SELECT number as n, j FROM numbers(4)
ANY LEFT JOIN (
SELECT number * 2 AS n, number + 10 AS j
FROM numbers(4000)
) js2
USING n;