Refactor join: make IJoin implementations independent from TableJoin

This commit is contained in:
vdimir 2021-03-05 16:38:49 +03:00
parent b7c7c97d10
commit e6406c3f4c
No known key found for this signature in database
GPG Key ID: F57B3E10A21DBB31
19 changed files with 475 additions and 431 deletions

View File

@ -95,6 +95,17 @@ bool allowEarlyConstantFolding(const ActionsDAG & actions, const Settings & sett
return true;
}
bool allowMergeJoin(ASTTableJoin::Kind kind, ASTTableJoin::Strictness strictness)
{
bool is_any = (strictness == ASTTableJoin::Strictness::Any);
bool is_all = (strictness == ASTTableJoin::Strictness::All);
bool is_semi = (strictness == ASTTableJoin::Strictness::Semi);
bool all_join = is_all && (isInner(kind) || isLeft(kind) || isRight(kind) || isFull(kind));
bool special_left = isLeft(kind) && (is_any || is_semi);
return all_join || special_left;
}
}
bool sanitizeBlock(Block & block, bool throw_if_cannot_create_column)
@ -720,14 +731,14 @@ bool SelectQueryExpressionAnalyzer::appendJoinLeftKeys(ExpressionActionsChain &
return true;
}
JoinPtr SelectQueryExpressionAnalyzer::appendJoin(ExpressionActionsChain & chain)
JoinPtr SelectQueryExpressionAnalyzer::appendJoin(ExpressionActionsChain & chain, ActionsDAGPtr & left_actions, ActionsDAGPtr & right_actions)
{
const ColumnsWithTypeAndName & left_sample_columns = chain.getLastStep().getResultColumns();
JoinPtr table_join = makeTableJoin(*syntax->ast_join, left_sample_columns);
JoinPtr table_join = makeTableJoin(*syntax->ast_join, left_sample_columns, left_actions, right_actions);
if (syntax->analyzed_join->needConvert())
if (left_actions)
{
chain.steps.push_back(std::make_unique<ExpressionActionsChain::ExpressionActionsStep>(syntax->analyzed_join->leftConvertingActions()));
chain.steps.push_back(std::make_unique<ExpressionActionsChain::ExpressionActionsStep>(left_actions));
chain.addStep();
}
@ -741,7 +752,7 @@ static JoinPtr tryGetStorageJoin(std::shared_ptr<TableJoin> analyzed_join)
{
if (auto * table = analyzed_join->joined_storage.get())
if (auto * storage_join = dynamic_cast<StorageJoin *>(table))
return storage_join->getJoinLocked(analyzed_join);
return storage_join->getJoinLocked(analyzed_join->getJoinInfo());
return {};
}
@ -752,7 +763,7 @@ static ExpressionActionsPtr createJoinedBlockActions(const Context & context, co
return ExpressionAnalyzer(expression_list, syntax_result, context).getActions(true, false);
}
static bool allowDictJoin(StoragePtr joined_storage, const Context & context, String & dict_name, String & key_name)
static bool allowDictJoin(const StoragePtr joined_storage, const Context & context, String & dict_name, String & key_name)
{
const auto * dict = dynamic_cast<const StorageDictionary *>(joined_storage.get());
if (!dict)
@ -772,33 +783,39 @@ static bool allowDictJoin(StoragePtr joined_storage, const Context & context, St
return false;
}
static std::shared_ptr<IJoin> makeJoin(std::shared_ptr<TableJoin> analyzed_join, const Block & sample_block, const Context & context)
/// HashJoin with Dictionary optimisation
static std::shared_ptr<IJoin> tryMakeDictJoin(const TableJoin & analyzed_join, const Block & sample_block, const Context & context)
{
bool allow_merge_join = analyzed_join->allowMergeJoin();
/// HashJoin with Dictionary optimisation
String dict_name;
String key_name;
if (analyzed_join->joined_storage && allowDictJoin(analyzed_join->joined_storage, context, dict_name, key_name))
if (analyzed_join.joined_storage && allowDictJoin(analyzed_join.joined_storage, context, dict_name, key_name))
{
Names original_names;
NamesAndTypesList result_columns;
if (analyzed_join->allowDictJoin(key_name, sample_block, original_names, result_columns))
if (analyzed_join.allowDictJoin(key_name, sample_block, original_names, result_columns))
{
analyzed_join->dictionary_reader = std::make_shared<DictionaryReader>(dict_name, original_names, result_columns, context);
return std::make_shared<HashJoin>(analyzed_join, sample_block);
auto dictionary_reader = std::make_shared<DictionaryReader>(dict_name, original_names, result_columns, context);
return std::make_shared<HashJoin>(analyzed_join.getJoinInfo(), sample_block, dictionary_reader);
}
}
return {};
}
if (analyzed_join->forceHashJoin() || (analyzed_join->preferMergeJoin() && !allow_merge_join))
return std::make_shared<HashJoin>(analyzed_join, sample_block);
else if (analyzed_join->forceMergeJoin() || (analyzed_join->preferMergeJoin() && allow_merge_join))
return std::make_shared<MergeJoin>(analyzed_join, sample_block);
return std::make_shared<JoinSwitcher>(analyzed_join, sample_block);
static std::shared_ptr<IJoin> makeJoin(const TableJoin & analyzed_join, const Block & sample_block)
{
auto join_info = analyzed_join.getJoinInfo();
bool allow_merge_join = allowMergeJoin(join_info.kind, join_info.strictness);
if (join_info.forceHashJoin() || (join_info.preferMergeJoin() && !allow_merge_join))
return std::make_shared<HashJoin>(std::move(join_info), sample_block);
else if (join_info.forceMergeJoin() || (join_info.preferMergeJoin() && allow_merge_join))
return std::make_shared<MergeJoin>(std::move(join_info), sample_block, analyzed_join.getTemporaryVolume());
return std::make_shared<JoinSwitcher>(std::move(join_info), sample_block, analyzed_join.getTemporaryVolume());
}
JoinPtr SelectQueryExpressionAnalyzer::makeTableJoin(
const ASTTablesInSelectQueryElement & join_element, const ColumnsWithTypeAndName & left_sample_columns)
const ASTTablesInSelectQueryElement & join_element, const ColumnsWithTypeAndName & left_sample_columns,
ActionsDAGPtr & left_converting_actions, ActionsDAGPtr & right_converting_actions)
{
/// Two JOINs are not supported with the same subquery, but different USINGs.
auto join_hash = join_element.getTreeHash();
@ -818,8 +835,8 @@ JoinPtr SelectQueryExpressionAnalyzer::makeTableJoin(
Names original_right_columns;
if (!subquery_for_join.source)
{
NamesWithAliases required_columns_with_aliases = analyzedJoin().getRequiredColumns(
joined_block_actions->getSampleBlock(), joined_block_actions->getRequiredColumns());
NamesWithAliases required_columns_with_aliases
= analyzedJoin().getRequiredColumns(joined_block_actions->getSampleBlock(), joined_block_actions->getRequiredColumns());
for (auto & pr : required_columns_with_aliases)
original_right_columns.push_back(pr.first);
@ -837,11 +854,15 @@ JoinPtr SelectQueryExpressionAnalyzer::makeTableJoin(
subquery_for_join.addJoinActions(joined_block_actions); /// changes subquery_for_join.sample_block inside
const ColumnsWithTypeAndName & right_sample_columns = subquery_for_join.sample_block.getColumnsWithTypeAndName();
bool need_convert = syntax->analyzed_join->applyJoinKeyConvert(left_sample_columns, right_sample_columns);
if (need_convert)
subquery_for_join.addJoinActions(std::make_shared<ExpressionActions>(syntax->analyzed_join->rightConvertingActions()));
subquery_for_join.join = makeJoin(syntax->analyzed_join, subquery_for_join.sample_block, context);
bool need_convert = syntax->analyzed_join->applyJoinKeyConvert(
left_sample_columns, right_sample_columns, left_converting_actions, right_converting_actions);
if (need_convert)
subquery_for_join.addJoinActions(std::make_shared<ExpressionActions>(right_converting_actions));
subquery_for_join.join = tryMakeDictJoin(*syntax->analyzed_join, subquery_for_join.sample_block, context);
if (!subquery_for_join.join)
subquery_for_join.join = makeJoin(*syntax->analyzed_join, subquery_for_join.sample_block);
/// Do not make subquery for join over dictionary.
if (syntax->analyzed_join->dictionary_reader)
@ -1436,8 +1457,7 @@ ExpressionAnalysisResult::ExpressionAnalysisResult(
{
query_analyzer.appendJoinLeftKeys(chain, only_types || !first_stage);
before_join = chain.getLastActions();
join = query_analyzer.appendJoin(chain);
converting_join_columns = query_analyzer.analyzedJoin().leftConvertingActions();
join = query_analyzer.appendJoin(chain, converting_join_left_columns, converting_join_right_columns);
chain.addStep();
}

View File

@ -200,7 +200,8 @@ struct ExpressionAnalysisResult
ActionsDAGPtr before_array_join;
ArrayJoinActionPtr array_join;
ActionsDAGPtr before_join;
ActionsDAGPtr converting_join_columns;
ActionsDAGPtr converting_join_left_columns;
ActionsDAGPtr converting_join_right_columns;
JoinPtr join;
ActionsDAGPtr before_where;
ActionsDAGPtr before_aggregation;
@ -317,7 +318,9 @@ private:
JoinPtr makeTableJoin(
const ASTTablesInSelectQueryElement & join_element,
const ColumnsWithTypeAndName & left_sample_columns);
const ColumnsWithTypeAndName & left_sample_columns,
ActionsDAGPtr & left_converting_actions,
ActionsDAGPtr & right_converting_actions);
const ASTSelectQuery * getAggregatingQuery() const;
@ -338,7 +341,7 @@ private:
/// Before aggregation:
ArrayJoinActionPtr appendArrayJoin(ExpressionActionsChain & chain, ActionsDAGPtr & before_array_join, bool only_types);
bool appendJoinLeftKeys(ExpressionActionsChain & chain, bool only_types);
JoinPtr appendJoin(ExpressionActionsChain & chain);
JoinPtr appendJoin(ExpressionActionsChain & chain, ActionsDAGPtr & left_actions, ActionsDAGPtr & right_actions);
/// Add preliminary rows filtration. Actions are created in other expression analyzer to prevent any possible alias injection.
void appendPreliminaryFilter(ExpressionActionsChain & chain, ActionsDAGPtr actions_dag, String column_name);
/// remove_filter is set in ExpressionActionsChain::finalize();

View File

@ -53,6 +53,12 @@ struct NotProcessedCrossJoin : public ExtraBlock
size_t right_block;
};
struct DictJoinData
{
std::shared_ptr<DictionaryReader> dictionary_reader;
bool nullable_right;
};
}
namespace JoinStuff
@ -175,34 +181,37 @@ static ColumnWithTypeAndName correctNullability(ColumnWithTypeAndName && column,
}
HashJoin::HashJoin(std::shared_ptr<TableJoin> table_join_, const Block & right_sample_block_, bool any_take_last_row_)
: table_join(table_join_)
, kind(table_join->kind())
, strictness(table_join->strictness())
, key_names_right(table_join->keyNamesRight())
, nullable_right_side(table_join->forceNullableRight())
, nullable_left_side(table_join->forceNullableLeft())
HashJoin::HashJoin(JoinInfo join_info_, const Block & right_sample_block_,
std::shared_ptr<DictionaryReader> dictionary_reader_,
bool any_take_last_row_)
: join_info(join_info_)
, kind(join_info_.kind)
, strictness(join_info_.strictness)
, any_take_last_row(any_take_last_row_)
, asof_inequality(table_join->getAsofInequality())
, dictionary_reader(dictionary_reader_)
, data(std::make_shared<RightTableData>())
, right_sample_block(right_sample_block_)
, log(&Poco::Logger::get("HashJoin"))
{
LOG_DEBUG(log, "Right sample block: {}", right_sample_block.dumpStructure());
table_join->splitAdditionalColumns(right_sample_block, right_table_keys, sample_block_with_columns_to_add);
required_right_keys = table_join->getRequiredRightKeys(right_table_keys, required_right_keys_sources);
JoinCommon::splitAdditionalColumns(join_info.key_names_right, right_sample_block, right_table_keys, sample_block_with_columns_to_add);
required_right_keys = JoinCommon::getRequiredRightKeys(
join_info.key_names_left, join_info.key_names_right, join_info.required_right_keys, right_table_keys, required_right_keys_sources);
JoinCommon::removeLowCardinalityInplace(right_table_keys);
initRightBlockStructure(data->sample_block);
LOG_DEBUG(log, "Right sample block: {}, join on keys: {}. Left keys: {}",
right_sample_block.dumpStructure(),
fmt::join(key_names_right, ", "),
fmt::join(key_names_left, ", "));
ColumnRawPtrs key_columns = JoinCommon::extractKeysForJoin(right_table_keys, key_names_right);
JoinCommon::createMissedColumns(sample_block_with_columns_to_add);
if (nullable_right_side)
if (join_info.forceNullableRight())
JoinCommon::convertColumnsToNullable(sample_block_with_columns_to_add);
if (table_join->dictionary_reader)
if (dictionary_reader)
{
data->type = Type::DICT;
std::get<MapsOne>(data->maps).create(Type::DICT);
@ -323,17 +332,16 @@ public:
: key_columns(key_columns_)
{}
FindResult findKey(const TableJoin & table_join, size_t row, const Arena &)
FindResult findKey(const DictJoinData & join_data, size_t row, const Arena &)
{
const DictionaryReader & reader = *table_join.dictionary_reader;
if (!read_result)
{
reader.readKeys(*key_columns[0], read_result, found, positions);
join_data.dictionary_reader->readKeys(*key_columns[0], read_result, found, positions);
result.block = &read_result;
if (table_join.forceNullableRight())
if (join_data.nullable_right)
for (auto & column : read_result)
if (table_join.rightBecomeNullable(column.type))
if (column.type->canBeInsideNullable())
JoinCommon::convertColumnToNullable(column);
}
@ -423,7 +431,7 @@ bool HashJoin::empty() const
bool HashJoin::alwaysReturnsEmptySet() const
{
return isInnerOrRight(getKind()) && data->empty && !overDictionary();
return isInnerOrRight(join_info.kind) && data->empty && !overDictionary();
}
size_t HashJoin::getTotalRowCount() const
@ -571,7 +579,7 @@ namespace
void HashJoin::initRightBlockStructure(Block & saved_block_sample)
{
/// We could remove key columns for LEFT | INNER HashJoin but we should keep them for JoinSwitcher (if any).
bool save_key_columns = !table_join->forceHashJoin() || isRightOrFull(kind);
bool save_key_columns = !join_info.forceHashJoin() || isRightOrFull(kind);
if (save_key_columns)
{
saved_block_sample = right_table_keys.cloneEmpty();
@ -586,7 +594,7 @@ void HashJoin::initRightBlockStructure(Block & saved_block_sample)
for (auto & column : sample_block_with_columns_to_add)
saved_block_sample.insert(column);
if (nullable_right_side)
if (join_info.forceNullableRight())
JoinCommon::convertColumnsToNullable(saved_block_sample, (isFull(kind) ? right_table_keys.columns() : 0));
}
@ -670,7 +678,7 @@ bool HashJoin::addJoinedBlock(const Block & source_block, bool check_limits)
total_bytes = getTotalByteCount();
}
return table_join->sizeLimits().check(total_rows, total_bytes, "JOIN", ErrorCodes::SET_SIZE_LIMIT_EXCEEDED);
return join_info.size_limits.check(total_rows, total_bytes, "JOIN", ErrorCodes::SET_SIZE_LIMIT_EXCEEDED);
}
@ -984,7 +992,7 @@ IColumn::Filter switchJoinRightColumns(
}
template <ASTTableJoin::Kind KIND, ASTTableJoin::Strictness STRICTNESS>
IColumn::Filter dictionaryJoinRightColumns(const TableJoin & table_join, AddedColumns & added_columns, const ConstNullMapPtr & null_map)
IColumn::Filter dictionaryJoinRightColumns(const DictJoinData & join_data, AddedColumns & added_columns, const ConstNullMapPtr & null_map)
{
if constexpr (KIND == ASTTableJoin::Kind::Left &&
(STRICTNESS == ASTTableJoin::Strictness::Any ||
@ -992,7 +1000,7 @@ IColumn::Filter dictionaryJoinRightColumns(const TableJoin & table_join, AddedCo
STRICTNESS == ASTTableJoin::Strictness::Anti))
{
JoinStuff::JoinUsedFlags flags;
return joinRightColumnsSwitchNullability<KIND, STRICTNESS, KeyGetterForDict>(table_join, added_columns, null_map, flags);
return joinRightColumnsSwitchNullability<KIND, STRICTNESS, KeyGetterForDict>(join_data, added_columns, null_map, flags);
}
throw Exception("Logical error: wrong JOIN combination", ErrorCodes::LOGICAL_ERROR);
@ -1004,7 +1012,7 @@ IColumn::Filter dictionaryJoinRightColumns(const TableJoin & table_join, AddedCo
template <ASTTableJoin::Kind KIND, ASTTableJoin::Strictness STRICTNESS, typename Maps>
void HashJoin::joinBlockImpl(
Block & block,
const Names & key_names_left,
const Names & key_names,
const Block & block_with_columns_to_add,
const Maps & maps_) const
{
@ -1023,7 +1031,7 @@ void HashJoin::joinBlockImpl(
constexpr bool need_filter = !need_replication && (inner || right || (is_semi_join && left) || (is_anti_join && left));
/// Rare case, when keys are constant or low cardinality. To avoid code bloat, simply materialize them.
Columns materialized_keys = JoinCommon::materializeColumns(block, key_names_left);
Columns materialized_keys = JoinCommon::materializeColumns(block, key_names);
ColumnRawPtrs left_key_columns = JoinCommon::getRawPointers(materialized_keys);
/// Keys with NULL value in any column won't join to anything.
@ -1040,7 +1048,7 @@ void HashJoin::joinBlockImpl(
{
materializeBlockInplace(block);
if (nullable_left_side)
if (join_info.forceNullableLeft())
JoinCommon::convertColumnsToNullable(block);
}
@ -1055,7 +1063,7 @@ void HashJoin::joinBlockImpl(
added_columns.need_filter = need_filter || has_required_right_keys;
IColumn::Filter row_filter = overDictionary() ?
dictionaryJoinRightColumns<KIND, STRICTNESS>(*table_join, added_columns, null_map) :
dictionaryJoinRightColumns<KIND, STRICTNESS>({dictionary_reader, join_info.forceNullableRight()}, added_columns, null_map) :
switchJoinRightColumns<KIND, STRICTNESS>(maps_, added_columns, data->type, null_map, used_flags);
for (size_t i = 0; i < added_columns.size(); ++i)
@ -1080,7 +1088,7 @@ void HashJoin::joinBlockImpl(
continue;
const auto & col = block.getByName(left_name);
bool is_nullable = nullable_right_side || right_key.type->isNullable();
bool is_nullable = join_info.forceNullableRight() || right_key.type->isNullable();
block.insert(correctNullability({col.column, col.type, right_key.name}, is_nullable));
}
}
@ -1103,7 +1111,7 @@ void HashJoin::joinBlockImpl(
continue;
const auto & col = block.getByName(left_name);
bool is_nullable = nullable_right_side || right_key.type->isNullable();
bool is_nullable = join_info.forceNullableRight() || right_key.type->isNullable();
ColumnPtr thin_column = filterWithBlanks(col.column, filter);
block.insert(correctNullability({thin_column, col.type, right_key.name}, is_nullable, null_map_filter));
@ -1129,7 +1137,7 @@ void HashJoin::joinBlockImpl(
void HashJoin::joinBlockImplCross(Block & block, ExtraBlockPtr & not_processed) const
{
size_t max_joined_block_rows = table_join->maxJoinedBlockRows();
size_t max_joined_block_rows = join_info.max_joined_block_rows;
size_t start_left_row = 0;
size_t start_right_block = 0;
if (not_processed)
@ -1263,7 +1271,6 @@ ColumnWithTypeAndName HashJoin::joinGet(const Block & block, const Block & block
void HashJoin::joinBlock(Block & block, ExtraBlockPtr & not_processed)
{
const Names & key_names_left = table_join->keyNamesLeft();
JoinCommon::checkTypesOfKeys(block, key_names_left, right_table_keys, key_names_right);
if (overDictionary())
@ -1311,7 +1318,7 @@ void HashJoin::joinBlock(Block & block, ExtraBlockPtr & not_processed)
void HashJoin::joinTotals(Block & block) const
{
JoinCommon::joinTotals(totals, sample_block_with_columns_to_add, *table_join, block);
JoinCommon::joinTotals(totals, sample_block_with_columns_to_add, join_info, block);
}
@ -1359,7 +1366,7 @@ class NonJoinedBlockInputStream : private NotJoined, public IBlockInputStream
{
public:
NonJoinedBlockInputStream(const HashJoin & parent_, const Block & result_sample_block_, UInt64 max_block_size_)
: NotJoined(*parent_.table_join,
: NotJoined(parent_.join_info,
parent_.savedBlockSample(),
parent_.right_sample_block,
result_sample_block_)
@ -1492,11 +1499,11 @@ private:
BlockInputStreamPtr HashJoin::createStreamWithNonJoinedRows(const Block & result_sample_block, UInt64 max_block_size) const
{
if (table_join->strictness() == ASTTableJoin::Strictness::Asof ||
table_join->strictness() == ASTTableJoin::Strictness::Semi)
if (join_info.strictness == ASTTableJoin::Strictness::Asof ||
join_info.strictness == ASTTableJoin::Strictness::Semi)
return {};
if (isRightOrFull(table_join->kind()))
if (isRightOrFull(join_info.kind))
return std::make_shared<NonJoinedBlockInputStream>(*this, result_sample_block, max_block_size);
return {};
}

View File

@ -10,6 +10,7 @@
#include <Interpreters/IJoin.h>
#include <Interpreters/AggregationCommon.h>
#include <Interpreters/RowRefs.h>
#include <Interpreters/TableJoin.h>
#include <Common/Arena.h>
#include <Common/ColumnsHashing.h>
@ -27,7 +28,6 @@
namespace DB
{
class TableJoin;
class DictionaryReader;
namespace JoinStuff
@ -132,7 +132,9 @@ public:
class HashJoin : public IJoin
{
public:
HashJoin(std::shared_ptr<TableJoin> table_join_, const Block & right_sample_block, bool any_take_last_row_ = false);
HashJoin(JoinInfo join_info_, const Block & right_sample_block_,
std::shared_ptr<DictionaryReader> dictionary_reader_ = nullptr,
bool any_take_last_row_ = false);
/** Add block of data from right hand of JOIN to the map.
* Returns false, if some limit was exceeded and you should not insert more data.
@ -171,10 +173,8 @@ public:
bool alwaysReturnsEmptySet() const final;
ASTTableJoin::Kind getKind() const { return kind; }
ASTTableJoin::Strictness getStrictness() const { return strictness; }
const std::optional<TypeIndex> & getAsofType() const { return asof_type; }
ASOF::Inequality getAsofInequality() const { return asof_inequality; }
ASOF::Inequality getAsofInequality() const { return join_info.asof_inequality; }
bool anyTakeLastRow() const { return any_take_last_row; }
const ColumnWithTypeAndName & rightAsofKeyColumn() const
@ -338,18 +338,18 @@ private:
friend class NonJoinedBlockInputStream;
friend class JoinSource;
std::shared_ptr<TableJoin> table_join;
JoinInfo join_info;
ASTTableJoin::Kind kind;
ASTTableJoin::Strictness strictness;
/// Names of key columns in right-side table (in the order they appear in ON/USING clause). @note It could contain duplicates.
const Names & key_names_right;
const Names key_names_left;
const Names key_names_right;
bool nullable_right_side; /// In case of LEFT and FULL joins, if use_nulls, convert right-side columns to Nullable.
bool nullable_left_side; /// In case of RIGHT and FULL joins, if use_nulls, convert left-side columns to Nullable.
bool any_take_last_row; /// Overwrite existing values when encountering the same key again
std::optional<TypeIndex> asof_type;
ASOF::Inequality asof_inequality;
std::shared_ptr<DictionaryReader> dictionary_reader;
/// Right table data. StorageJoin shares it between many Join objects.
std::shared_ptr<RightTableData> data;

View File

@ -349,8 +349,6 @@ InterpreterSelectQuery::InterpreterSelectQuery(
joined_tables.rewriteDistributedInAndJoins(query_ptr);
max_streams = settings.max_threads;
ASTSelectQuery & query = getSelectQuery();
std::shared_ptr<TableJoin> table_join = joined_tables.makeTableJoin(query);
ASTPtr row_policy_filter;
if (storage)
@ -362,6 +360,8 @@ InterpreterSelectQuery::InterpreterSelectQuery(
SubqueriesForSets subquery_for_sets;
ASTSelectQuery & query = getSelectQuery();
auto analyze = [&] (bool try_move_to_prewhere)
{
/// Allow push down and other optimizations for VIEW: replace with subquery and rewrite it.
@ -372,7 +372,7 @@ InterpreterSelectQuery::InterpreterSelectQuery(
syntax_analyzer_result = TreeRewriter(*context).analyzeSelect(
query_ptr,
TreeRewriterResult(source_header.getNamesAndTypesList(), storage, metadata_snapshot),
options, joined_tables.tablesWithColumns(), required_result_column_names, table_join);
options, required_result_column_names, &joined_tables);
/// Save scalar sub queries's results in the query context
if (!options.only_analyze && context->hasQueryContext())

View File

@ -17,16 +17,14 @@ static ColumnWithTypeAndName correctNullability(ColumnWithTypeAndName && column,
return std::move(column);
}
JoinSwitcher::JoinSwitcher(std::shared_ptr<TableJoin> table_join_, const Block & right_sample_block_)
: limits(table_join_->sizeLimits())
JoinSwitcher::JoinSwitcher(JoinInfo join_info_, const Block & right_sample_block_,
const MergeJoin::TemporaryVolumeSettings & temp_vol_settings_)
: join_info(join_info_)
, switched(false)
, table_join(table_join_)
, right_sample_block(right_sample_block_.cloneEmpty())
, temp_vol_settings(temp_vol_settings_)
{
join = std::make_shared<HashJoin>(table_join, right_sample_block);
if (!limits.hasLimits())
limits.max_bytes = table_join->defaultMaxBytes();
join = std::make_shared<HashJoin>(join_info, right_sample_block);
}
bool JoinSwitcher::addJoinedBlock(const Block & block, bool)
@ -42,7 +40,7 @@ bool JoinSwitcher::addJoinedBlock(const Block & block, bool)
size_t rows = join->getTotalRowCount();
size_t bytes = join->getTotalByteCount();
if (!limits.softCheck(rows, bytes))
if (!join_info.size_limits.softCheck(rows, bytes))
switchJoin();
return true;
@ -54,7 +52,7 @@ void JoinSwitcher::switchJoin()
BlocksList right_blocks = std::move(joined_data->blocks);
/// Destroy old join & create new one. Early destroy for memory saving.
join = std::make_shared<MergeJoin>(table_join, right_sample_block);
join = std::make_shared<MergeJoin>(join_info, right_sample_block, temp_vol_settings);
/// names to positions optimization
std::vector<size_t> positions;

View File

@ -5,6 +5,7 @@
#include <Core/Block.h>
#include <Interpreters/IJoin.h>
#include <Interpreters/TableJoin.h>
#include <Interpreters/MergeJoin.h>
#include <DataStreams/IBlockInputStream.h>
@ -17,7 +18,8 @@ namespace DB
class JoinSwitcher : public IJoin
{
public:
JoinSwitcher(std::shared_ptr<TableJoin> table_join_, const Block & right_sample_block_);
JoinSwitcher(JoinInfo join_info_, const Block & right_sample_block_,
const MergeJoin::TemporaryVolumeSettings & temp_vol_settings_);
/// Add block of data from right hand of JOIN into current join object.
/// If join-in-memory memory limit exceeded switches to join-on-disk and continue with it.
@ -66,11 +68,12 @@ public:
private:
JoinPtr join;
SizeLimits limits;
JoinInfo join_info;
bool switched;
mutable std::mutex switch_mutex;
std::shared_ptr<TableJoin> table_join;
const Block right_sample_block;
const MergeJoin::TemporaryVolumeSettings & temp_vol_settings;
/// Change join-in-memory to join-on-disk moving right hand JOIN data from one to another.
/// Throws an error if join-on-disk do not support JOIN kind or strictness.

View File

@ -249,17 +249,17 @@ void JoinedTables::rewriteDistributedInAndJoins(ASTPtr & query)
}
}
std::shared_ptr<TableJoin> JoinedTables::makeTableJoin(const ASTSelectQuery & select_query)
std::shared_ptr<TableJoin> JoinedTables::makeTableJoin(const ASTSelectQuery & select_query) const
{
if (tables_with_columns.size() < 2)
return {};
return std::make_shared<TableJoin>();
auto settings = context.getSettingsRef();
auto table_join = std::make_shared<TableJoin>(settings, context.getTemporaryVolume());
const ASTTablesInSelectQueryElement * ast_join = select_query.join();
const auto & table_to_join = ast_join->table_expression->as<ASTTableExpression &>();
auto table_join = std::make_shared<TableJoin>(ast_join->table_join->as<ASTTableJoin &>(), settings, context.getTemporaryVolume());
/// TODO This syntax does not support specifying a database name.
if (table_to_join.database_and_table_name)
{

View File

@ -34,10 +34,9 @@ public:
/// Make fake tables_with_columns[0] in case we have predefined input in InterpreterSelectQuery
void makeFakeTable(StoragePtr storage, const StorageMetadataPtr & metadata_snapshot, const Block & source_header);
std::shared_ptr<TableJoin> makeTableJoin(const ASTSelectQuery & select_query);
std::shared_ptr<TableJoin> makeTableJoin(const ASTSelectQuery & select_query) const;
const TablesWithColumns & tablesWithColumns() const { return tables_with_columns; }
TablesWithColumns moveTablesWithColumns() { return std::move(tables_with_columns); }
bool isLeftTableSubquery() const;
bool isLeftTableFunction() const;

View File

@ -418,24 +418,22 @@ void joinInequalsLeft(const Block & left_block, MutableColumns & left_columns,
}
MergeJoin::MergeJoin(std::shared_ptr<TableJoin> table_join_, const Block & right_sample_block_)
: table_join(table_join_)
, size_limits(table_join->sizeLimits())
MergeJoin::MergeJoin(JoinInfo join_info_, const Block & right_sample_block_, const TemporaryVolumeSettings & temp_vol_settings_)
: join_info(join_info_)
, right_sample_block(right_sample_block_)
, nullable_right_side(table_join->forceNullableRight())
, nullable_left_side(table_join->forceNullableLeft())
, is_any_join(table_join->strictness() == ASTTableJoin::Strictness::Any)
, is_all_join(table_join->strictness() == ASTTableJoin::Strictness::All)
, is_semi_join(table_join->strictness() == ASTTableJoin::Strictness::Semi)
, is_inner(isInner(table_join->kind()))
, is_left(isLeft(table_join->kind()))
, is_right(isRight(table_join->kind()))
, is_full(isFull(table_join->kind()))
, max_joined_block_rows(table_join->maxJoinedBlockRows())
, max_rows_in_right_block(table_join->maxRowsInRightBlock())
, max_files_to_merge(table_join->maxFilesToMerge())
, nullable_right_side(join_info.forceNullableRight())
, nullable_left_side(join_info.forceNullableLeft())
, is_any_join(join_info.strictness == ASTTableJoin::Strictness::Any)
, is_all_join(join_info.strictness == ASTTableJoin::Strictness::All)
, is_semi_join(join_info.strictness == ASTTableJoin::Strictness::Semi)
, is_inner(isInner(join_info.kind))
, is_left(isLeft(join_info.kind))
, is_right(isRight(join_info.kind))
, is_full(isFull(join_info.kind))
, max_rows_in_right_block(join_info.partial_merge_join_rows_in_right_blocks)
, temp_vol_settings(temp_vol_settings_)
{
switch (table_join->strictness())
switch (join_info.strictness)
{
case ASTTableJoin::Strictness::All:
break;
@ -452,23 +450,20 @@ MergeJoin::MergeJoin(std::shared_ptr<TableJoin> table_join_, const Block & right
if (!max_rows_in_right_block)
throw Exception("partial_merge_join_rows_in_right_blocks cannot be zero", ErrorCodes::PARAMETER_OUT_OF_BOUND);
if (max_files_to_merge < 2)
if (join_info.max_files_to_merge < 2)
throw Exception("max_files_to_merge cannot be less than 2", ErrorCodes::PARAMETER_OUT_OF_BOUND);
if (!size_limits.hasLimits())
if (!join_info.size_limits.hasLimits())
{
size_limits.max_bytes = table_join->defaultMaxBytes();
if (!size_limits.max_bytes)
throw Exception("No limit for MergeJoin (max_rows_in_join, max_bytes_in_join or default_max_bytes_in_join have to be set)",
ErrorCodes::PARAMETER_OUT_OF_BOUND);
throw Exception("No limit for MergeJoin (max_rows_in_join, max_bytes_in_join or default_max_bytes_in_join have to be set)",
ErrorCodes::PARAMETER_OUT_OF_BOUND);
}
table_join->splitAdditionalColumns(right_sample_block, right_table_keys, right_columns_to_add);
JoinCommon::splitAdditionalColumns(join_info.key_names_right, right_sample_block, right_table_keys, right_columns_to_add);
JoinCommon::removeLowCardinalityInplace(right_table_keys);
const NameSet required_right_keys = table_join->requiredRightKeys();
for (const auto & column : right_table_keys)
if (required_right_keys.count(column.name))
if (join_info.required_right_keys.count(column.name))
right_columns_to_add.insert(ColumnWithTypeAndName{nullptr, column.type, column.name});
JoinCommon::createMissedColumns(right_columns_to_add);
@ -476,12 +471,12 @@ MergeJoin::MergeJoin(std::shared_ptr<TableJoin> table_join_, const Block & right
if (nullable_right_side)
JoinCommon::convertColumnsToNullable(right_columns_to_add);
makeSortAndMerge(table_join->keyNamesLeft(), left_sort_description, left_merge_description);
makeSortAndMerge(table_join->keyNamesRight(), right_sort_description, right_merge_description);
makeSortAndMerge(join_info.key_names_left, left_sort_description, left_merge_description);
makeSortAndMerge(join_info.key_names_right, right_sort_description, right_merge_description);
/// Temporary disable 'partial_merge_join_left_table_buffer_bytes' without 'partial_merge_join_optimizations'
if (table_join->enablePartialMergeJoinOptimizations())
if (size_t max_bytes = table_join->maxBytesInLeftBuffer())
if (join_info.partial_merge_join_optimizations)
if (size_t max_bytes = join_info.partial_merge_join_left_table_buffer_bytes)
left_blocks_buffer = std::make_shared<SortedBlocksBuffer>(left_sort_description, max_bytes);
}
@ -496,7 +491,7 @@ void MergeJoin::setTotals(const Block & totals_block)
void MergeJoin::joinTotals(Block & block) const
{
JoinCommon::joinTotals(totals, right_columns_to_add, *table_join, block);
JoinCommon::joinTotals(totals, right_columns_to_add, join_info, block);
}
void MergeJoin::mergeRightBlocks()
@ -522,7 +517,8 @@ void MergeJoin::mergeInMemoryRightBlocks()
pipeline.init(std::move(source));
/// TODO: there should be no split keys by blocks for RIGHT|FULL JOIN
pipeline.addTransform(std::make_shared<MergeSortingTransform>(pipeline.getHeader(), right_sort_description, max_rows_in_right_block, 0, 0, 0, 0, nullptr, 0));
pipeline.addTransform(
std::make_shared<MergeSortingTransform>(pipeline.getHeader(), right_sort_description, max_rows_in_right_block, 0, 0, 0, 0, nullptr, 0));
auto sorted_input = PipelineExecutingBlockInputStream(std::move(pipeline));
@ -590,7 +586,7 @@ bool MergeJoin::saveRightBlock(Block && block)
Block MergeJoin::modifyRightBlock(const Block & src_block) const
{
Block block = materializeBlock(src_block);
JoinCommon::removeLowCardinalityInplace(block, table_join->keyNamesRight());
JoinCommon::removeLowCardinalityInplace(block, join_info.key_names_right);
return block;
}
@ -606,9 +602,9 @@ void MergeJoin::joinBlock(Block & block, ExtraBlockPtr & not_processed)
{
if (block)
{
JoinCommon::checkTypesOfKeys(block, table_join->keyNamesLeft(), right_table_keys, table_join->keyNamesRight());
JoinCommon::checkTypesOfKeys(block, join_info.key_names_left, right_table_keys, join_info.key_names_right);
materializeBlockInplace(block);
JoinCommon::removeLowCardinalityInplace(block, table_join->keyNamesLeft(), false);
JoinCommon::removeLowCardinalityInplace(block, join_info.key_names_left, false);
sortBlock(block, left_sort_description);
@ -679,7 +675,7 @@ void MergeJoin::joinSortedBlock(Block & block, ExtraBlockPtr & not_processed)
if (skip_not_intersected)
{
int intersection = left_cursor.intersect(min_max_right_blocks[i], table_join->keyNamesRight());
int intersection = left_cursor.intersect(min_max_right_blocks[i], join_info.key_names_right);
if (intersection < 0)
break; /// (left) ... (right)
if (intersection > 0)
@ -713,7 +709,7 @@ void MergeJoin::joinSortedBlock(Block & block, ExtraBlockPtr & not_processed)
if (skip_not_intersected)
{
int intersection = left_cursor.intersect(min_max_right_blocks[i], table_join->keyNamesRight());
int intersection = left_cursor.intersect(min_max_right_blocks[i], join_info.key_names_right);
if (intersection < 0)
break; /// (left) ... (right)
if (intersection > 0)
@ -784,7 +780,7 @@ bool MergeJoin::leftJoin(MergeJoinCursor & left_cursor, const Block & left_block
{
right_block_info.setUsed(range.right_start, range.right_length);
size_t max_rows = maxRangeRows(left_columns[0]->size(), max_joined_block_rows);
size_t max_rows = maxRangeRows(left_columns[0]->size(), join_info.max_joined_block_rows);
if (!joinEquals<true>(left_block, right_block, right_columns_to_add, left_columns, right_columns, range, max_rows))
{
@ -832,7 +828,7 @@ bool MergeJoin::allInnerJoin(MergeJoinCursor & left_cursor, const Block & left_b
right_block_info.setUsed(range.right_start, range.right_length);
size_t max_rows = maxRangeRows(left_columns[0]->size(), max_joined_block_rows);
size_t max_rows = maxRangeRows(left_columns[0]->size(), join_info.max_joined_block_rows);
if (!joinEquals<true>(left_block, right_block, right_columns_to_add, left_columns, right_columns, range, max_rows))
{
@ -942,9 +938,9 @@ std::shared_ptr<Block> MergeJoin::loadRightBlock(size_t pos) const
void MergeJoin::initRightTableWriter()
{
disk_writer = std::make_unique<SortedBlocksWriter>(size_limits, table_join->getTemporaryVolume(),
right_sample_block, right_sort_description, max_rows_in_right_block, max_files_to_merge,
table_join->temporaryFilesCodec());
disk_writer = std::make_unique<SortedBlocksWriter>(size_limits, temp_vol_settings.first,
right_sample_block, right_sort_description, max_rows_in_right_block, join_info.max_files_to_merge,
temp_vol_settings.second);
disk_writer->addBlocks(right_blocks);
right_blocks.clear();
}
@ -954,7 +950,7 @@ class NonMergeJoinedBlockInputStream : private NotJoined, public IBlockInputStre
{
public:
NonMergeJoinedBlockInputStream(const MergeJoin & parent_, const Block & result_sample_block_, UInt64 max_block_size_)
: NotJoined(*parent_.table_join,
: NotJoined(parent_.join_info,
parent_.modifyRightBlock(parent_.right_sample_block),
parent_.right_sample_block,
result_sample_block_)
@ -1041,7 +1037,7 @@ private:
BlockInputStreamPtr MergeJoin::createStreamWithNonJoinedRows(const Block & result_sample_block, UInt64 max_block_size) const
{
if (table_join->strictness() == ASTTableJoin::Strictness::All && (is_right || is_full))
if (join_info.strictness == ASTTableJoin::Strictness::All && (is_right || is_full))
return std::make_shared<NonMergeJoinedBlockInputStream>(*this, result_sample_block, max_block_size);
return {};
}

View File

@ -8,6 +8,7 @@
#include <Interpreters/IJoin.h>
#include <Interpreters/SortedBlocksWriter.h>
#include <DataStreams/SizeLimits.h>
#include <Interpreters/TableJoin.h>
namespace DB
{
@ -21,7 +22,9 @@ class RowBitmaps;
class MergeJoin : public IJoin
{
public:
MergeJoin(std::shared_ptr<TableJoin> table_join_, const Block & right_sample_block);
using TemporaryVolumeSettings = std::pair<VolumePtr, String>;
MergeJoin(JoinInfo join_info_, const Block & right_sample_block, const TemporaryVolumeSettings & temp_vol_settings_);
bool addJoinedBlock(const Block & block, bool check_limits) override;
void joinBlock(Block &, ExtraBlockPtr & not_processed) override;
@ -66,7 +69,7 @@ private:
using Cache = LRUCache<size_t, Block, std::hash<size_t>, BlockByteWeight>;
mutable std::shared_mutex rwlock;
std::shared_ptr<TableJoin> table_join;
JoinInfo join_info;
SizeLimits size_limits;
SortDescription left_sort_description;
SortDescription right_sort_description;
@ -95,9 +98,9 @@ private:
const bool is_right;
const bool is_full;
static constexpr const bool skip_not_intersected = true; /// skip index for right blocks
const size_t max_joined_block_rows;
const size_t max_rows_in_right_block;
const size_t max_files_to_merge;
TemporaryVolumeSettings temp_vol_settings;
void changeLeftColumns(Block & block, MutableColumns && columns) const;
void addRightColumns(Block & block, MutableColumns && columns);

View File

@ -9,8 +9,6 @@
#include <Common/StringUtils/StringUtils.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataStreams/materializeBlock.h>
namespace DB
{
@ -20,33 +18,13 @@ namespace ErrorCodes
extern const int TYPE_MISMATCH;
}
TableJoin::TableJoin(const Settings & settings, VolumePtr tmp_volume_)
: size_limits(SizeLimits{settings.max_rows_in_join, settings.max_bytes_in_join, settings.join_overflow_mode})
, default_max_bytes(settings.default_max_bytes_in_join)
, join_use_nulls(settings.join_use_nulls)
, max_joined_block_rows(settings.max_joined_block_size_rows)
, join_algorithm(settings.join_algorithm)
, partial_merge_join_optimizations(settings.partial_merge_join_optimizations)
, partial_merge_join_rows_in_right_blocks(settings.partial_merge_join_rows_in_right_blocks)
, partial_merge_join_left_table_buffer_bytes(settings.partial_merge_join_left_table_buffer_bytes)
, max_files_to_merge(settings.join_on_disk_max_files_to_merge)
TableJoin::TableJoin(const ASTTableJoin & table_join_ast, const Settings & settings, VolumePtr tmp_volume_)
: join_info(table_join_ast, settings)
, temporary_files_codec(settings.temporary_files_codec)
, tmp_volume(tmp_volume_)
{
}
void TableJoin::resetCollected()
{
key_names_left.clear();
key_names_right.clear();
key_asts_left.clear();
key_asts_right.clear();
columns_from_joined_table.clear();
columns_added_by_join.clear();
original_names.clear();
renames.clear();
}
void TableJoin::addUsingKey(const ASTPtr & ast)
{
key_names_left.push_back(ast->getColumnName());
@ -72,7 +50,7 @@ void TableJoin::addOnKeys(ASTPtr & left_table_ast, ASTPtr & right_table_ast)
/// @return how many times right key appears in ON section.
size_t TableJoin::rightKeyInclusion(const String & name) const
{
if (hasUsing())
if (!hasOn())
return 0;
size_t count = 0;
@ -167,66 +145,18 @@ NamesWithAliases TableJoin::getRequiredColumns(const Block & sample, const Names
return getNamesWithAliases(required_columns);
}
void TableJoin::splitAdditionalColumns(const Block & sample_block, Block & block_keys, Block & block_others) const
{
block_others = materializeBlock(sample_block);
for (const String & column_name : key_names_right)
{
/// Extract right keys with correct keys order. There could be the same key names.
if (!block_keys.has(column_name))
{
auto & col = block_others.getByName(column_name);
block_keys.insert(col);
block_others.erase(column_name);
}
}
}
Block TableJoin::getRequiredRightKeys(const Block & right_table_keys, std::vector<String> & keys_sources) const
{
const Names & left_keys = keyNamesLeft();
const Names & right_keys = keyNamesRight();
NameSet required_keys(requiredRightKeys().begin(), requiredRightKeys().end());
Block required_right_keys;
for (size_t i = 0; i < right_keys.size(); ++i)
{
const String & right_key_name = right_keys[i];
if (required_keys.count(right_key_name) && !required_right_keys.has(right_key_name))
{
const auto & right_key = right_table_keys.getByName(right_key_name);
required_right_keys.insert(right_key);
keys_sources.push_back(left_keys[i]);
}
}
return required_right_keys;
}
bool TableJoin::leftBecomeNullable(const DataTypePtr & column_type) const
{
return forceNullableLeft() && column_type->canBeInsideNullable();
}
bool TableJoin::rightBecomeNullable(const DataTypePtr & column_type) const
{
return forceNullableRight() && column_type->canBeInsideNullable();
}
void TableJoin::addJoinedColumn(const NameAndTypePair & joined_column)
{
DataTypePtr type = joined_column.type;
if (hasUsing())
if (!hasOn())
{
if (auto it = right_type_map.find(joined_column.name); it != right_type_map.end())
type = it->second;
}
if (rightBecomeNullable(type))
if (join_info.forceNullableRight() && type->canBeInsideNullable())
type = makeNullable(joined_column.type);
columns_added_by_join.emplace_back(joined_column.name, type);
@ -249,12 +179,12 @@ void TableJoin::addJoinedColumnsAndCorrectTypes(ColumnsWithTypeAndName & columns
{
for (auto & col : columns)
{
if (hasUsing())
if (!hasOn())
{
if (auto it = left_type_map.find(col.name); it != left_type_map.end())
col.type = it->second;
}
if (correct_nullability && leftBecomeNullable(col.type))
if (correct_nullability && join_info.forceNullableLeft() && col.type->canBeInsideNullable())
{
/// No need to nullify constants
bool is_column_const = col.column && isColumnConst(*col.column);
@ -268,53 +198,26 @@ void TableJoin::addJoinedColumnsAndCorrectTypes(ColumnsWithTypeAndName & columns
columns.emplace_back(nullptr, col.type, col.name);
}
bool TableJoin::sameStrictnessAndKind(ASTTableJoin::Strictness strictness_, ASTTableJoin::Kind kind_) const
{
if (strictness_ == strictness() && kind_ == kind())
return true;
/// Compatibility: old ANY INNER == new SEMI LEFT
if (strictness_ == ASTTableJoin::Strictness::Semi && isLeft(kind_) &&
strictness() == ASTTableJoin::Strictness::RightAny && isInner(kind()))
return true;
if (strictness() == ASTTableJoin::Strictness::Semi && isLeft(kind()) &&
strictness_ == ASTTableJoin::Strictness::RightAny && isInner(kind_))
return true;
return false;
}
bool TableJoin::allowMergeJoin() const
{
bool is_any = (strictness() == ASTTableJoin::Strictness::Any);
bool is_all = (strictness() == ASTTableJoin::Strictness::All);
bool is_semi = (strictness() == ASTTableJoin::Strictness::Semi);
bool all_join = is_all && (isInner(kind()) || isLeft(kind()) || isRight(kind()) || isFull(kind()));
bool special_left = isLeft(kind()) && (is_any || is_semi);
return all_join || special_left;
}
bool TableJoin::needStreamWithNonJoinedRows() const
{
if (strictness() == ASTTableJoin::Strictness::Asof ||
strictness() == ASTTableJoin::Strictness::Semi)
if (join_info.strictness == ASTTableJoin::Strictness::Asof ||
join_info.strictness == ASTTableJoin::Strictness::Semi)
return false;
return isRightOrFull(kind());
return isRightOrFull(join_info.kind);
}
bool TableJoin::allowDictJoin(const String & dict_key, const Block & sample_block, Names & src_names, NamesAndTypesList & dst_columns) const
{
/// Support ALL INNER, [ANY | ALL | SEMI | ANTI] LEFT
if (!isLeft(kind()) && !(isInner(kind()) && strictness() == ASTTableJoin::Strictness::All))
if (!isLeft(join_info.kind) && !(isInner(join_info.kind) && join_info.strictness == ASTTableJoin::Strictness::All))
return false;
const Names & right_keys = keyNamesRight();
if (right_keys.size() != 1)
if (key_names_right.size() != 1)
return false;
/// TODO: support 'JOIN ... ON expr(dict_key) = table_key'
auto it_key = original_names.find(right_keys[0]);
auto it_key = original_names.find(key_names_right[0]);
if (it_key == original_names.end())
return false;
@ -323,7 +226,7 @@ bool TableJoin::allowDictJoin(const String & dict_key, const Block & sample_bloc
for (const auto & col : sample_block)
{
if (col.name == right_keys[0])
if (col.name == key_names_right[0])
continue; /// do not extract key column
auto it = original_names.find(col.name);
@ -338,10 +241,13 @@ bool TableJoin::allowDictJoin(const String & dict_key, const Block & sample_bloc
return true;
}
bool TableJoin::applyJoinKeyConvert(const ColumnsWithTypeAndName & left_sample_columns, const ColumnsWithTypeAndName & right_sample_columns)
bool TableJoin::applyJoinKeyConvert(const ColumnsWithTypeAndName & left_sample_columns,
const ColumnsWithTypeAndName & right_sample_columns,
ActionsDAGPtr & left_converting_actions,
ActionsDAGPtr & right_converting_actions)
{
bool need_convert = needConvert();
if (!need_convert && !hasUsing())
bool need_convert = !left_type_map.empty();
if (hasOn() && !need_convert)
{
/// For `USING` we already inferred common type an syntax analyzer stage
NamesAndTypesList left_list;
@ -356,8 +262,8 @@ bool TableJoin::applyJoinKeyConvert(const ColumnsWithTypeAndName & left_sample_c
if (need_convert)
{
left_converting_actions = applyKeyConvertToTable(left_sample_columns, left_type_map, key_names_left);
right_converting_actions = applyKeyConvertToTable(right_sample_columns, right_type_map, key_names_right);
left_converting_actions = JoinCommon::applyKeyConvertToTable(left_sample_columns, left_type_map, hasOn(), key_names_left);
right_converting_actions = JoinCommon::applyKeyConvertToTable(right_sample_columns, right_type_map, hasOn(), key_names_right);
}
return need_convert;
@ -415,31 +321,38 @@ bool TableJoin::inferJoinKeyCommonType(const NamesAndTypesList & left, const Nam
return !left_type_map.empty();
}
ActionsDAGPtr TableJoin::applyKeyConvertToTable(
const ColumnsWithTypeAndName & cols_src, const NameToTypeMap & type_mapping, Names & names_to_rename) const
JoinInfo TableJoin::getJoinInfo() const
{
ColumnsWithTypeAndName cols_dst = cols_src;
for (auto & col : cols_dst)
{
if (auto it = type_mapping.find(col.name); it != type_mapping.end())
{
col.type = it->second;
col.column = nullptr;
}
}
JoinInfo res = join_info;
res.key_names_right = key_names_right;
res.key_names_left = key_names_left;
res.required_right_keys = requiredRightKeys();
return res;
}
NameToNameMap key_column_rename;
/// Returns converting actions for tables that need to be performed before join
auto dag = ActionsDAG::makeConvertingActions(
cols_src, cols_dst, ActionsDAG::MatchColumnsMode::Name, true, !hasUsing(), &key_column_rename);
for (auto & name : names_to_rename)
{
const auto it = key_column_rename.find(name);
if (it != key_column_rename.end())
name = it->second;
}
return dag;
JoinInfo::JoinInfo(const ASTTableJoin & table_join_ast, const Settings & settings)
: join_use_nulls(settings.join_use_nulls)
, join_algorithm(settings.join_algorithm)
, max_joined_block_rows(settings.max_joined_block_size_rows)
, partial_merge_join_optimizations(settings.partial_merge_join_optimizations)
, partial_merge_join_rows_in_right_blocks(settings.partial_merge_join_rows_in_right_blocks)
, partial_merge_join_left_table_buffer_bytes(settings.partial_merge_join_left_table_buffer_bytes)
, max_files_to_merge(settings.join_on_disk_max_files_to_merge)
, size_limits(SizeLimits{settings.max_rows_in_join, settings.max_bytes_in_join, settings.join_overflow_mode})
{
kind = table_join_ast.kind;
strictness = table_join_ast.strictness;
if (table_join_ast.using_expression_list)
match_expression = JoinInfo::MatchExpressionType::JoinUsing;
if (table_join_ast.on_expression)
match_expression = JoinInfo::MatchExpressionType::JoinOn;
if (!size_limits.hasLimits())
size_limits.max_bytes = settings.default_max_bytes_in_join;
}
}

View File

@ -32,6 +32,67 @@ struct Settings;
class IVolume;
using VolumePtr = std::shared_ptr<IVolume>;
struct JoinInfo
{
JoinInfo() = default;
JoinInfo(const ASTTableJoin & table_join_ast, const Settings & settings);
JoinInfo(SizeLimits limits, bool use_nulls, ASTTableJoin::Kind kind_, ASTTableJoin::Strictness strictness_)
: kind(kind_), strictness(strictness_), join_use_nulls(use_nulls), join_algorithm(JoinAlgorithm::HASH), size_limits(limits)
{
}
/// for StorageJoin
JoinInfo(
SizeLimits limits, bool use_nulls, ASTTableJoin::Kind kind_, ASTTableJoin::Strictness strictness_, const Names & key_names_right_)
: kind(kind_)
, strictness(strictness_)
, key_names_right(key_names_right_)
, join_use_nulls(use_nulls)
, join_algorithm(JoinAlgorithm::HASH)
, size_limits(limits)
{
}
enum class MatchExpressionType
{
JoinUsing,
JoinOn
};
ASTTableJoin::Kind kind;
ASTTableJoin::Strictness strictness;
MatchExpressionType match_expression;
Names key_names_left;
Names key_names_right;
NameSet required_right_keys;
const bool join_use_nulls = false;
JoinAlgorithm join_algorithm = JoinAlgorithm::AUTO;
ASOF::Inequality asof_inequality = ASOF::Inequality::GreaterOrEquals;
/// Settings
const size_t max_joined_block_rows = 0;
const bool partial_merge_join_optimizations = false;
const size_t partial_merge_join_rows_in_right_blocks = 0;
const size_t partial_merge_join_left_table_buffer_bytes = 0;
const size_t max_files_to_merge = 0;
SizeLimits size_limits;
bool forceNullableRight() const { return join_use_nulls && isLeftOrFull(kind); }
bool forceNullableLeft() const { return join_use_nulls && isRightOrFull(kind); }
bool preferMergeJoin() const { return join_algorithm == JoinAlgorithm::PREFER_PARTIAL_MERGE; }
bool forceMergeJoin() const { return join_algorithm == JoinAlgorithm::PARTIAL_MERGE; }
bool forceHashJoin() const { return join_algorithm == JoinAlgorithm::HASH; }
bool hasUsing() const { return match_expression == JoinInfo::MatchExpressionType::JoinUsing; }
bool hasOn() const { return match_expression == JoinInfo::MatchExpressionType::JoinOn; }
};
class TableJoin
{
@ -53,24 +114,13 @@ private:
friend class TreeRewriter;
const SizeLimits size_limits;
const size_t default_max_bytes = 0;
const bool join_use_nulls = false;
const size_t max_joined_block_rows = 0;
JoinAlgorithm join_algorithm = JoinAlgorithm::AUTO;
const bool partial_merge_join_optimizations = false;
const size_t partial_merge_join_rows_in_right_blocks = 0;
const size_t partial_merge_join_left_table_buffer_bytes = 0;
const size_t max_files_to_merge = 0;
const String temporary_files_codec = "LZ4";
JoinInfo join_info;
Names key_names_left;
Names key_names_right; /// Duplicating names are qualified.
ASTs key_asts_left;
ASTs key_asts_right;
ASTTableJoin table_join;
ASOF::Inequality asof_inequality = ASOF::Inequality::GreaterOrEquals;
/// All columns which can be read from joined table. Duplicating names are qualified.
NamesAndTypesList columns_from_joined_table;
@ -82,70 +132,32 @@ private:
NameToTypeMap left_type_map;
NameToTypeMap right_type_map;
ActionsDAGPtr left_converting_actions;
ActionsDAGPtr right_converting_actions;
/// Name -> original name. Names are the same as in columns_from_joined_table list.
std::unordered_map<String, String> original_names;
/// Original name -> name. Only renamed columns.
std::unordered_map<String, String> renames;
const String temporary_files_codec = "LZ4";
VolumePtr tmp_volume;
Names requiredJoinedNames() const;
/// Create converting actions and change key column names if required
ActionsDAGPtr applyKeyConvertToTable(
const ColumnsWithTypeAndName & cols_src, const NameToTypeMap & type_mapping, Names & names_to_rename) const;
public:
TableJoin() = default;
TableJoin(const Settings &, VolumePtr tmp_volume);
/// for StorageJoin
TableJoin(SizeLimits limits, bool use_nulls, ASTTableJoin::Kind kind, ASTTableJoin::Strictness strictness,
const Names & key_names_right_)
: size_limits(limits)
, default_max_bytes(0)
, join_use_nulls(use_nulls)
, join_algorithm(JoinAlgorithm::HASH)
, key_names_right(key_names_right_)
{
table_join.kind = kind;
table_join.strictness = strictness;
}
TableJoin(const ASTTableJoin & table_join_ast, const Settings & settings, VolumePtr tmp_volume_);
StoragePtr joined_storage;
std::shared_ptr<DictionaryReader> dictionary_reader;
ASTTableJoin::Kind kind() const { return table_join.kind; }
ASTTableJoin::Strictness strictness() const { return table_join.strictness; }
bool sameStrictnessAndKind(ASTTableJoin::Strictness, ASTTableJoin::Kind) const;
const SizeLimits & sizeLimits() const { return size_limits; }
VolumePtr getTemporaryVolume() { return tmp_volume; }
bool allowMergeJoin() const;
std::pair<VolumePtr, String> getTemporaryVolume() const { return std::make_pair(tmp_volume, temporary_files_codec) ; }
bool allowDictJoin(const String & dict_key, const Block & sample_block, Names &, NamesAndTypesList &) const;
bool preferMergeJoin() const { return join_algorithm == JoinAlgorithm::PREFER_PARTIAL_MERGE; }
bool forceMergeJoin() const { return join_algorithm == JoinAlgorithm::PARTIAL_MERGE; }
bool forceHashJoin() const { return join_algorithm == JoinAlgorithm::HASH; }
bool forceNullableRight() const { return join_use_nulls && isLeftOrFull(table_join.kind); }
bool forceNullableLeft() const { return join_use_nulls && isRightOrFull(table_join.kind); }
size_t defaultMaxBytes() const { return default_max_bytes; }
size_t maxJoinedBlockRows() const { return max_joined_block_rows; }
size_t maxRowsInRightBlock() const { return partial_merge_join_rows_in_right_blocks; }
size_t maxBytesInLeftBuffer() const { return partial_merge_join_left_table_buffer_bytes; }
size_t maxFilesToMerge() const { return max_files_to_merge; }
const String & temporaryFilesCodec() const { return temporary_files_codec; }
bool enablePartialMergeJoinOptimizations() const { return partial_merge_join_optimizations; }
bool needStreamWithNonJoinedRows() const;
void resetCollected();
void addUsingKey(const ASTPtr & ast);
void addOnKeys(ASTPtr & left_table_ast, ASTPtr & right_table_ast);
bool hasUsing() const { return table_join.using_expression_list != nullptr; }
bool hasOn() const { return table_join.on_expression != nullptr; }
bool hasOn() const { return join_info.hasOn(); }
NamesWithAliases getNamesWithAliases(const NameSet & required_columns) const;
NamesWithAliases getRequiredColumns(const Block & sample, const Names & action_required_columns) const;
@ -154,8 +166,6 @@ public:
size_t rightKeyInclusion(const String & name) const;
NameSet requiredRightKeys() const;
bool leftBecomeNullable(const DataTypePtr & column_type) const;
bool rightBecomeNullable(const DataTypePtr & column_type) const;
void addJoinedColumn(const NameAndTypePair & joined_column);
void addJoinedColumnsAndCorrectTypes(NamesAndTypesList & names_and_types, bool correct_nullability = true) const;
@ -164,25 +174,19 @@ public:
/// Calculates common supertypes for corresponding join key columns.
bool inferJoinKeyCommonType(const NamesAndTypesList & left, const NamesAndTypesList & right);
/// Calculate converting actions, rename key columns in required
/// For `USING` join we will convert key columns inplace and affect into types in the result table
/// For `JOIN ON` we will create new columns with converted keys to join by.
bool applyJoinKeyConvert(const ColumnsWithTypeAndName & left_sample_columns, const ColumnsWithTypeAndName & right_sample_columns);
bool applyJoinKeyConvert(const ColumnsWithTypeAndName & left_sample_columns,
const ColumnsWithTypeAndName & right_sample_columns,
ActionsDAGPtr & left_converting_actions,
ActionsDAGPtr & right_converting_actions);
bool needConvert() const { return !left_type_map.empty(); }
JoinInfo getJoinInfo() const;
/// Key columns should be converted before join.
ActionsDAGPtr leftConvertingActions() const { return left_converting_actions; }
ActionsDAGPtr rightConvertingActions() const { return right_converting_actions; }
void setAsofInequality(ASOF::Inequality inequality) { asof_inequality = inequality; }
ASOF::Inequality getAsofInequality() { return asof_inequality; }
void setAsofInequality(ASOF::Inequality inequality) { join_info.asof_inequality = inequality; }
ASTPtr leftKeysList() const;
ASTPtr rightKeysList() const; /// For ON syntax only
const Names & keyNamesLeft() const { return key_names_left; }
const Names & keyNamesRight() const { return key_names_right; }
const NamesAndTypesList & columnsFromJoinedTable() const { return columns_from_joined_table; }
Names columnsAddedByJoin() const
{
@ -191,13 +195,6 @@ public:
res.push_back(col.name);
return res;
}
/// StorageJoin overrides key names (cause of different names qualification)
void setRightKeys(const Names & keys) { key_names_right = keys; }
/// Split key and other columns by keys name list
void splitAdditionalColumns(const Block & sample_block, Block & block_keys, Block & block_others) const;
Block getRequiredRightKeys(const Block & right_table_keys, std::vector<String> & keys_sources) const;
};
}

View File

@ -16,6 +16,7 @@
#include <Interpreters/RequiredSourceColumnsVisitor.h>
#include <Interpreters/GetAggregatesVisitor.h>
#include <Interpreters/TableJoin.h>
#include <Interpreters/JoinedTables.h>
#include <Interpreters/ExpressionActions.h> /// getSmallestColumn()
#include <Interpreters/getTableExpressions.h>
#include <Interpreters/TreeOptimizer.h>
@ -361,21 +362,17 @@ void getArrayJoinedColumns(ASTPtr & query, TreeRewriterResult & result, const AS
}
}
void setJoinStrictness(ASTSelectQuery & select_query, JoinStrictness join_default_strictness, bool old_any, ASTTableJoin & out_table_join)
std::pair<ASTTableJoin::Kind, ASTTableJoin::Strictness>
getJoinStrictness(ASTTableJoin::Kind kind, ASTTableJoin::Strictness strictness, JoinStrictness join_default_strictness, bool old_any)
{
const ASTTablesInSelectQueryElement * node = select_query.join();
if (!node)
return;
auto & table_join = const_cast<ASTTablesInSelectQueryElement *>(node)->table_join->as<ASTTableJoin &>();
if (table_join.strictness == ASTTableJoin::Strictness::Unspecified &&
table_join.kind != ASTTableJoin::Kind::Cross)
if (strictness == ASTTableJoin::Strictness::Unspecified &&
kind != ASTTableJoin::Kind::Cross)
{
if (join_default_strictness == JoinStrictness::ANY)
table_join.strictness = ASTTableJoin::Strictness::Any;
strictness = ASTTableJoin::Strictness::Any;
else if (join_default_strictness == JoinStrictness::ALL)
table_join.strictness = ASTTableJoin::Strictness::All;
strictness = ASTTableJoin::Strictness::All;
else
throw Exception("Expected ANY or ALL in JOIN section, because setting (join_default_strictness) is empty",
DB::ErrorCodes::EXPECTED_ALL_OR_ANY);
@ -383,24 +380,22 @@ void setJoinStrictness(ASTSelectQuery & select_query, JoinStrictness join_defaul
if (old_any)
{
if (table_join.strictness == ASTTableJoin::Strictness::Any &&
table_join.kind == ASTTableJoin::Kind::Inner)
if (strictness == ASTTableJoin::Strictness::Any && kind == ASTTableJoin::Kind::Inner)
{
table_join.strictness = ASTTableJoin::Strictness::Semi;
table_join.kind = ASTTableJoin::Kind::Left;
strictness = ASTTableJoin::Strictness::Semi;
kind = ASTTableJoin::Kind::Left;
}
if (table_join.strictness == ASTTableJoin::Strictness::Any)
table_join.strictness = ASTTableJoin::Strictness::RightAny;
if (strictness == ASTTableJoin::Strictness::Any)
strictness = ASTTableJoin::Strictness::RightAny;
}
else
{
if (table_join.strictness == ASTTableJoin::Strictness::Any)
if (table_join.kind == ASTTableJoin::Kind::Full)
if (strictness == ASTTableJoin::Strictness::Any)
if (kind == ASTTableJoin::Kind::Full)
throw Exception("ANY FULL JOINs are not implemented.", ErrorCodes::NOT_IMPLEMENTED);
}
out_table_join = table_join;
return std::make_pair(kind, strictness);
}
/// Find the columns that are obtained by JOIN.
@ -775,13 +770,14 @@ TreeRewriterResultPtr TreeRewriter::analyzeSelect(
ASTPtr & query,
TreeRewriterResult && result,
const SelectQueryOptions & select_options,
const std::vector<TableWithColumnNamesAndTypes> & tables_with_columns,
const Names & required_result_columns,
std::shared_ptr<TableJoin> table_join) const
const JoinedTables * joined_tables) const
{
auto * select_query = query->as<ASTSelectQuery>();
if (!select_query)
throw Exception("Select analyze for not select asts.", ErrorCodes::LOGICAL_ERROR);
throw Exception("Select analyze for not select ASTs.", ErrorCodes::LOGICAL_ERROR);
const auto & tables_with_columns = joined_tables ? joined_tables->tablesWithColumns() : std::vector<TableWithColumnNamesAndTypes>{};
size_t subquery_depth = select_options.subquery_depth;
bool remove_duplicates = select_options.remove_duplicates;
@ -790,24 +786,9 @@ TreeRewriterResultPtr TreeRewriter::analyzeSelect(
const NameSet & source_columns_set = result.source_columns_set;
if (table_join)
{
result.analyzed_join = table_join;
result.analyzed_join->resetCollected();
}
else /// TODO: remove. For now ExpressionAnalyzer expects some not empty object here
result.analyzed_join = std::make_shared<TableJoin>();
if (remove_duplicates)
renameDuplicatedColumns(select_query);
if (tables_with_columns.size() > 1)
{
result.analyzed_join->columns_from_joined_table = tables_with_columns[1].columns;
result.analyzed_join->deduplicateAndQualifyColumnNames(
source_columns_set, tables_with_columns[1].table.getQualifiedNamePrefix());
}
translateQualifiedNames(query, *select_query, source_columns_set, tables_with_columns);
/// Optimizes logical expressions.
@ -829,8 +810,24 @@ TreeRewriterResultPtr TreeRewriter::analyzeSelect(
/// array_join_alias_to_name, array_join_result_to_source.
getArrayJoinedColumns(query, result, select_query, result.source_columns, source_columns_set);
setJoinStrictness(*select_query, settings.join_default_strictness, settings.any_join_distinct_right_table_keys,
result.analyzed_join->table_join);
if (const ASTTablesInSelectQueryElement * join_node = select_query->join())
{
auto & table_join_ast = join_node->table_join->as<ASTTableJoin &>();
std::tie(table_join_ast.kind, table_join_ast.strictness) = getJoinStrictness(
table_join_ast.kind, table_join_ast.strictness, settings.join_default_strictness, settings.any_join_distinct_right_table_keys);
}
if (joined_tables)
result.analyzed_join = joined_tables->makeTableJoin(*select_query);
else
result.analyzed_join = std::make_shared<TableJoin>();
if (tables_with_columns.size() > 1)
{
result.analyzed_join->columns_from_joined_table = tables_with_columns[1].columns;
result.analyzed_join->deduplicateAndQualifyColumnNames(
source_columns_set, tables_with_columns[1].table.getQualifiedNamePrefix());
}
ASTPtr new_where_condition = nullptr;
collectJoinedColumns(*result.analyzed_join, *select_query, tables_with_columns, result.aliases, new_where_condition);
@ -852,7 +849,7 @@ TreeRewriterResultPtr TreeRewriter::analyzeSelect(
result.optimize_trivial_count = settings.optimize_trivial_count_query &&
!select_query->groupBy() && !select_query->having() &&
!select_query->sampleSize() && !select_query->sampleOffset() && !select_query->final() &&
(tables_with_columns.size() < 2 || isLeft(result.analyzed_join->kind()));
(tables_with_columns.size() < 2 || isLeft(result.analyzed_join->join_info.kind));
return std::make_shared<const TreeRewriterResult>(result);
}

View File

@ -13,6 +13,7 @@ namespace DB
class ASTFunction;
struct ASTTablesInSelectQueryElement;
class TableJoin;
class JoinedTables;
class Context;
struct Settings;
struct SelectQueryOptions;
@ -112,9 +113,8 @@ public:
ASTPtr & query,
TreeRewriterResult && result,
const SelectQueryOptions & select_options = {},
const std::vector<TableWithColumnNamesAndTypes> & tables_with_columns = {},
const Names & required_result_columns = {},
std::shared_ptr<TableJoin> table_join = {}) const;
const JoinedTables * joined_tables = nullptr) const;
private:
const Context & context;

View File

@ -254,20 +254,23 @@ void createMissedColumns(Block & block)
}
/// Append totals from right to left block, correct types if needed
void joinTotals(const Block & totals, const Block & columns_to_add, const TableJoin & table_join, Block & block)
void joinTotals(const Block & totals, const Block & columns_to_add, const JoinInfo & join_info, Block & block)
{
if (table_join.forceNullableLeft())
if (join_info.forceNullableLeft())
convertColumnsToNullable(block);
if (Block totals_without_keys = totals)
{
for (const auto & name : table_join.keyNamesRight())
for (const auto & name : join_info.key_names_right)
totals_without_keys.erase(totals_without_keys.getPositionByName(name));
for (auto & col : totals_without_keys)
if (join_info.forceNullableRight())
{
if (table_join.rightBecomeNullable(col.type))
JoinCommon::convertColumnToNullable(col);
for (auto & col : totals_without_keys)
{
if (col.type->canBeInsideNullable())
JoinCommon::convertColumnToNullable(col);
}
}
for (size_t i = 0; i < totals_without_keys.columns(); ++i)
@ -302,10 +305,77 @@ bool typesEqualUpToNullability(DataTypePtr left_type, DataTypePtr right_type)
return left_type_strict->equals(*right_type_strict);
}
ActionsDAGPtr applyKeyConvertToTable(
const ColumnsWithTypeAndName & cols_src, const NameToTypeMap & type_mapping, bool replace_columns, Names & names_to_rename)
{
ColumnsWithTypeAndName cols_dst = cols_src;
for (auto & col : cols_dst)
{
if (auto it = type_mapping.find(col.name); it != type_mapping.end())
{
col.type = it->second;
col.column = nullptr;
}
}
NameToNameMap key_column_rename;
/// Returns converting actions for tables that need to be performed before join
auto dag = ActionsDAG::makeConvertingActions(
cols_src, cols_dst, ActionsDAG::MatchColumnsMode::Name, true, !replace_columns, &key_column_rename);
for (auto & name : names_to_rename)
{
const auto it = key_column_rename.find(name);
if (it != key_column_rename.end())
name = it->second;
}
return dag;
}
void splitAdditionalColumns(const Names & key_names_right, const Block & sample_block, Block & block_keys, Block & block_others)
{
block_others = materializeBlock(sample_block);
for (const String & column_name : key_names_right)
{
/// Extract right keys with correct keys order. There could be the same key names.
if (!block_keys.has(column_name))
{
auto & col = block_others.getByName(column_name);
block_keys.insert(col);
block_others.erase(column_name);
}
}
}
Block getRequiredRightKeys(
const Names & left_keys,
const Names & right_keys,
const NameSet & required_keys,
const Block & right_table_keys,
std::vector<String> & keys_sources)
{
Block required_right_keys;
for (size_t i = 0; i < right_keys.size(); ++i)
{
const String & right_key_name = right_keys[i];
if (required_keys.count(right_key_name) && !required_right_keys.has(right_key_name))
{
const auto & right_key = right_table_keys.getByName(right_key_name);
required_right_keys.insert(right_key);
keys_sources.push_back(left_keys[i]);
}
}
return required_right_keys;
}
}
NotJoined::NotJoined(const TableJoin & table_join, const Block & saved_block_sample_, const Block & right_sample_block,
NotJoined::NotJoined(const JoinInfo & join_info, const Block & saved_block_sample_, const Block & right_sample_block,
const Block & result_sample_block_)
: saved_block_sample(saved_block_sample_)
, result_sample_block(materializeBlock(result_sample_block_))
@ -313,17 +383,19 @@ NotJoined::NotJoined(const TableJoin & table_join, const Block & saved_block_sam
std::vector<String> tmp;
Block right_table_keys;
Block sample_block_with_columns_to_add;
table_join.splitAdditionalColumns(right_sample_block, right_table_keys, sample_block_with_columns_to_add);
Block required_right_keys = table_join.getRequiredRightKeys(right_table_keys, tmp);
JoinCommon::splitAdditionalColumns(join_info.key_names_right, right_sample_block, right_table_keys, sample_block_with_columns_to_add);
Block required_right_keys = JoinCommon::getRequiredRightKeys(
join_info.key_names_left, join_info.key_names_right, join_info.required_right_keys, right_table_keys, tmp);
std::unordered_map<size_t, size_t> left_to_right_key_remap;
if (table_join.hasUsing())
if (join_info.hasUsing())
{
for (size_t i = 0; i < table_join.keyNamesLeft().size(); ++i)
for (size_t i = 0; i < join_info.key_names_left.size(); ++i)
{
const String & left_key_name = table_join.keyNamesLeft()[i];
const String & right_key_name = table_join.keyNamesRight()[i];
const String & left_key_name = join_info.key_names_left[i];
const String & right_key_name = join_info.key_names_right[i];
size_t left_key_pos = result_sample_block.getPositionByName(left_key_name);
size_t right_key_pos = saved_block_sample.getPositionByName(right_key_name);

View File

@ -4,14 +4,17 @@
#include <Interpreters/IJoin.h>
#include <Interpreters/ActionsDAG.h>
#include <Interpreters/ExpressionActions.h>
#include <DataStreams/materializeBlock.h>
namespace DB
{
struct ColumnWithTypeAndName;
class TableJoin;
struct JoinInfo;
class IColumn;
using ColumnRawPtrs = std::vector<const IColumn *>;
using NameToTypeMap = std::unordered_map<String, DataTypePtr>;
namespace JoinCommon
{
@ -34,19 +37,31 @@ ColumnRawPtrs extractKeysForJoin(const Block & block_keys, const Names & key_nam
void checkTypesOfKeys(const Block & block_left, const Names & key_names_left, const Block & block_right, const Names & key_names_right);
void createMissedColumns(Block & block);
void joinTotals(const Block & totals, const Block & columns_to_add, const TableJoin & table_join, Block & block);
void joinTotals(const Block & totals, const Block & columns_to_add, const JoinInfo & table_join, Block & block);
void addDefaultValues(IColumn & column, const DataTypePtr & type, size_t count);
bool typesEqualUpToNullability(DataTypePtr left_type, DataTypePtr right_type);
/// Calculate converting actions, rename key columns in required
/// For `USING` join we will convert key columns inplace and affect into types in the result table
/// For `JOIN ON` we will create new columns with converted keys to join by.
ActionsDAGPtr applyKeyConvertToTable(
const ColumnsWithTypeAndName & cols_src, const NameToTypeMap & type_mapping, bool replace_columns, Names & names_to_rename);
void splitAdditionalColumns(const Names & key_names_right, const Block & sample_block, Block & block_keys, Block & block_others);
Block getRequiredRightKeys(const Names & left_keys, const Names & right_keys,
const NameSet & required_keys, const Block & right_table_keys, std::vector<String> & keys_sources);
}
/// Creates result from right table data in RIGHT and FULL JOIN when keys are not present in left table.
class NotJoined
{
public:
NotJoined(const TableJoin & table_join, const Block & saved_block_sample_, const Block & right_sample_block,
NotJoined(const JoinInfo & join_info, const Block & saved_block_sample_, const Block & right_sample_block,
const Block & result_sample_block_);
void correctLowcardAndNullability(MutableColumns & columns_right);

View File

@ -35,6 +35,28 @@ namespace ErrorCodes
extern const int BAD_ARGUMENTS;
}
namespace
{
bool sameStrictnessAndKind(ASTTableJoin::Strictness strictness_left, ASTTableJoin::Kind kind_left,
ASTTableJoin::Strictness strictness_right, ASTTableJoin::Kind kind_right)
{
if (strictness_right == strictness_left && kind_right == kind_left)
return true;
/// Compatibility: old ANY INNER == new SEMI LEFT
if (strictness_right == ASTTableJoin::Strictness::Semi && isLeft(kind_right) &&
strictness_left == ASTTableJoin::Strictness::RightAny && isInner(kind_left))
return true;
if (strictness_left == ASTTableJoin::Strictness::Semi && isLeft(kind_left) &&
strictness_right == ASTTableJoin::Strictness::RightAny && isInner(kind_right))
return true;
return false;
}
}
StorageJoin::StorageJoin(
DiskPtr disk_,
const String & relative_path_,
@ -55,14 +77,14 @@ StorageJoin::StorageJoin(
, kind(kind_)
, strictness(strictness_)
, overwrite(overwrite_)
, join_info(limits, use_nulls, kind, strictness, key_names)
{
auto metadata_snapshot = getInMemoryMetadataPtr();
for (const auto & key : key_names)
if (!metadata_snapshot->getColumns().hasPhysical(key))
throw Exception{"Key column (" + key + ") does not exist in table declaration.", ErrorCodes::NO_SUCH_COLUMN_IN_TABLE};
table_join = std::make_shared<TableJoin>(limits, use_nulls, kind, strictness, key_names);
join = std::make_shared<HashJoin>(table_join, metadata_snapshot->getSampleBlock().sortColumns(), overwrite);
join = std::make_shared<HashJoin>(join_info, metadata_snapshot->getSampleBlock().sortColumns(), nullptr, overwrite);
restore();
}
@ -75,27 +97,26 @@ void StorageJoin::truncate(
disk->createDirectories(path + "tmp/");
increment = 0;
join = std::make_shared<HashJoin>(table_join, metadata_snapshot->getSampleBlock().sortColumns(), overwrite);
join = std::make_shared<HashJoin>(join_info, metadata_snapshot->getSampleBlock().sortColumns(), nullptr, overwrite);
}
HashJoinPtr StorageJoin::getJoinLocked(std::shared_ptr<TableJoin> analyzed_join) const
HashJoinPtr StorageJoin::getJoinLocked(JoinInfo info) const
{
auto metadata_snapshot = getInMemoryMetadataPtr();
if (!analyzed_join->sameStrictnessAndKind(strictness, kind))
if (!sameStrictnessAndKind(info.strictness, info.kind, strictness, kind))
throw Exception("Table " + getStorageID().getNameForLogs() + " has incompatible type of JOIN.", ErrorCodes::INCOMPATIBLE_TYPE_OF_JOIN);
if ((analyzed_join->forceNullableRight() && !use_nulls) ||
(!analyzed_join->forceNullableRight() && isLeftOrFull(analyzed_join->kind()) && use_nulls))
if ((info.forceNullableRight() && !use_nulls) ||
(!info.forceNullableRight() && isLeftOrFull(info.kind) && use_nulls))
throw Exception("Table " + getStorageID().getNameForLogs() + " needs the same join_use_nulls setting as present in LEFT or FULL JOIN.",
ErrorCodes::INCOMPATIBLE_TYPE_OF_JOIN);
/// TODO: check key columns
/// Some HACK to remove wrong names qualifiers: table.column -> column.
analyzed_join->setRightKeys(key_names);
info.key_names_right = key_names;
HashJoinPtr join_clone = std::make_shared<HashJoin>(analyzed_join, metadata_snapshot->getSampleBlock().sortColumns());
HashJoinPtr join_clone = std::make_shared<HashJoin>(std::move(info), metadata_snapshot->getSampleBlock().sortColumns());
join_clone->setLock(rwlock);
join_clone->reuseJoinedData(*join);

View File

@ -3,6 +3,7 @@
#include <ext/shared_ptr_helper.h>
#include <Storages/StorageSet.h>
#include <Interpreters/TableJoin.h>
#include <Storages/JoinSettings.h>
#include <Parsers/ASTTablesInSelectQuery.h>
@ -31,7 +32,7 @@ public:
/// Return instance of HashJoin holding lock that protects from insertions to StorageJoin.
/// HashJoin relies on structure of hash table that's why we need to return it with locked mutex.
HashJoinPtr getJoinLocked(std::shared_ptr<TableJoin> analyzed_join) const;
HashJoinPtr getJoinLocked(JoinInfo join_info) const;
/// Get result type for function "joinGet(OrNull)"
DataTypePtr joinGetCheckAndGetReturnType(const DataTypes & data_types, const String & column_name, bool or_null) const;
@ -62,9 +63,8 @@ private:
ASTTableJoin::Strictness strictness; /// ANY | ALL
bool overwrite;
std::shared_ptr<TableJoin> table_join;
HashJoinPtr join;
JoinInfo join_info;
/// Protect state for concurrent use in insertFromBlock and joinBlock.
/// Lock is stored in HashJoin instance during query and blocks concurrent insertions.
mutable std::shared_mutex rwlock;