Clean up TableJoin storage join

This commit is contained in:
vdimir 2021-06-29 12:22:53 +03:00
parent 13bf141e1d
commit ed8c156190
No known key found for this signature in database
GPG Key ID: F57B3E10A21DBB31
5 changed files with 104 additions and 68 deletions

View File

@ -850,14 +850,6 @@ JoinPtr SelectQueryExpressionAnalyzer::appendJoin(ExpressionActionsChain & chain
return table_join;
}
static JoinPtr tryGetStorageJoin(std::shared_ptr<TableJoin> analyzed_join)
{
if (auto * table = analyzed_join->joined_storage.get())
if (auto * storage_join = dynamic_cast<StorageJoin *>(table))
return storage_join->getJoinLocked(analyzed_join);
return {};
}
static ActionsDAGPtr createJoinedBlockActions(ContextPtr context, const TableJoin & analyzed_join)
{
ASTPtr expression_list = analyzed_join.rightKeysList();
@ -865,44 +857,13 @@ static ActionsDAGPtr createJoinedBlockActions(ContextPtr context, const TableJoi
return ExpressionAnalyzer(expression_list, syntax_result, context).getActionsDAG(true, false);
}
static bool allowDictJoin(StoragePtr joined_storage, ContextPtr context, String & dict_name, String & key_name)
static std::shared_ptr<IJoin> chooseJoinAlgorithm(std::shared_ptr<TableJoin> analyzed_join, const Block & sample_block, ContextPtr context)
{
if (!joined_storage->isDictionary())
return false;
StorageDictionary & storage_dictionary = static_cast<StorageDictionary &>(*joined_storage);
dict_name = storage_dictionary.getDictionaryName();
auto dictionary = context->getExternalDictionariesLoader().getDictionary(dict_name, context);
if (!dictionary)
return false;
const DictionaryStructure & structure = dictionary->getStructure();
if (structure.id)
{
key_name = structure.id->name;
return true;
}
return false;
}
static std::shared_ptr<IJoin> makeJoin(std::shared_ptr<TableJoin> analyzed_join, const Block & sample_block, ContextPtr context)
{
bool allow_merge_join = analyzed_join->allowMergeJoin();
/// HashJoin with Dictionary optimisation
String dict_name;
String key_name;
if (analyzed_join->joined_storage && allowDictJoin(analyzed_join->joined_storage, context, dict_name, key_name))
{
Names original_names;
NamesAndTypesList result_columns;
if (analyzed_join->allowDictJoin(key_name, sample_block, original_names, result_columns))
{
analyzed_join->dictionary_reader = std::make_shared<DictionaryReader>(dict_name, original_names, result_columns, context);
return std::make_shared<HashJoin>(analyzed_join, sample_block);
}
}
if (analyzed_join->tryInitDictJoin(sample_block, context))
return std::make_shared<HashJoin>(analyzed_join, sample_block);
bool allow_merge_join = analyzed_join->allowMergeJoin();
if (analyzed_join->forceHashJoin() || (analyzed_join->preferMergeJoin() && !allow_merge_join))
return std::make_shared<HashJoin>(analyzed_join, sample_block);
else if (analyzed_join->forceMergeJoin() || (analyzed_join->preferMergeJoin() && allow_merge_join))
@ -963,7 +924,7 @@ std::unique_ptr<QueryPlan> buildJoinedPlan(
if (auto right_actions = analyzed_join.rightConvertingActions())
{
auto converting_step = std::make_unique<ExpressionStep>(joined_plan->getCurrentDataStream(), analyzed_join.rightConvertingActions());
auto converting_step = std::make_unique<ExpressionStep>(joined_plan->getCurrentDataStream(), right_actions);
converting_step->setStepDescription("Convert joined columns");
joined_plan->addStep(std::move(converting_step));
}
@ -979,21 +940,18 @@ JoinPtr SelectQueryExpressionAnalyzer::makeTableJoin(
if (joined_plan)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Table join was already created for query");
/// Use StorageJoin if any.
JoinPtr join = tryGetStorageJoin(syntax->analyzed_join);
if (join)
if (auto storage = syntax->analyzed_join->getStorageJoin())
{
syntax->analyzed_join->createConvertingActions(left_sample_columns, {});
return join;
return storage->getJoinLocked(syntax->analyzed_join);
}
joined_plan = buildJoinedPlan(getContext(), join_element, left_sample_columns, *syntax->analyzed_join, query_options);
join = makeJoin(syntax->analyzed_join, joined_plan->getCurrentDataStream().header, getContext());
JoinPtr join = chooseJoinAlgorithm(syntax->analyzed_join, joined_plan->getCurrentDataStream().header, getContext());
/// Do not make subquery for join over dictionary.
if (syntax->analyzed_join->dictionary_reader)
if (syntax->analyzed_join->getDictionaryReader())
joined_plan.reset();
return join;

View File

@ -211,7 +211,7 @@ HashJoin::HashJoin(std::shared_ptr<TableJoin> table_join_, const Block & right_s
if (nullable_right_side)
JoinCommon::convertColumnsToNullable(sample_block_with_columns_to_add);
if (table_join->dictionary_reader)
if (table_join->getDictionaryReader())
{
LOG_DEBUG(log, "Performing join over dict");
data->type = Type::DICT;
@ -331,7 +331,8 @@ public:
KeyGetterForDict(const TableJoin & table_join, const ColumnRawPtrs & key_columns)
{
table_join.dictionary_reader->readKeys(*key_columns[0], read_result, found, positions);
assert(table_join.getDictionaryReader());
table_join.getDictionaryReader()->readKeys(*key_columns[0], read_result, found, positions);
for (ColumnWithTypeAndName & column : read_result)
if (table_join.rightBecomeNullable(column.type))

View File

@ -299,16 +299,17 @@ std::shared_ptr<TableJoin> JoinedTables::makeTableJoin(const ASTSelectQuery & se
if (table_to_join.database_and_table_name)
{
auto joined_table_id = context->resolveStorageID(table_to_join.database_and_table_name);
StoragePtr table = DatabaseCatalog::instance().tryGetTable(joined_table_id, context);
if (table)
StoragePtr storage = DatabaseCatalog::instance().tryGetTable(joined_table_id, context);
if (storage)
{
if (dynamic_cast<StorageJoin *>(table.get()) ||
dynamic_cast<StorageDictionary *>(table.get()))
table_join->joined_storage = table;
if (auto storage_join = std::dynamic_pointer_cast<StorageJoin>(storage); storage_join)
table_join->setStorageJoin(storage_join);
else if (auto storage_dict = std::dynamic_pointer_cast<StorageDictionary>(storage); storage_dict)
table_join->setStorageJoin(storage_dict);
}
}
if (!table_join->joined_storage &&
if (!table_join->isSpecialStorage() &&
settings.enable_optimize_predicate_expression)
replaceJoinedTable(select_query);

View File

@ -1,5 +1,6 @@
#include <Interpreters/TableJoin.h>
#include <Common/StringUtils/StringUtils.h>
#include <Core/Block.h>
@ -8,12 +9,23 @@
#include <DataTypes/DataTypeNullable.h>
#include <Dictionaries/DictionaryStructure.h>
#include <Interpreters/DictionaryReader.h>
#include <Interpreters/ExternalDictionariesLoader.h>
#include <Interpreters/TableJoin.h>
#include <Parsers/ASTExpressionList.h>
#include <Parsers/ASTExpressionList.h>
#include <Parsers/ASTFunction.h>
#include <Parsers/queryToString.h>
#include <common/logger_useful.h>
#include <Storages/IStorage.h>
#include <Storages/StorageDictionary.h>
#include <Storages/StorageJoin.h>
#include <common/logger_useful.h>
namespace DB
{
@ -21,6 +33,7 @@ namespace DB
namespace ErrorCodes
{
extern const int TYPE_MISMATCH;
extern const int LOGICAL_ERROR;
}
namespace
@ -269,7 +282,7 @@ void TableJoin::addJoinedColumnsAndCorrectTypes(NamesAndTypesList & left_columns
* For `JOIN ON expr1 == expr2` we will infer common type later in makeTableJoin,
* when part of plan built and types of expression will be known.
*/
inferJoinKeyCommonType(left_columns, columns_from_joined_table, joined_storage != nullptr);
inferJoinKeyCommonType(left_columns, columns_from_joined_table, !isSpecialStorage());
if (auto it = left_type_map.find(col.name); it != left_type_map.end())
col.type = it->second;
@ -318,7 +331,18 @@ bool TableJoin::needStreamWithNonJoinedRows() const
return isRightOrFull(kind());
}
bool TableJoin::allowDictJoin(const String & dict_key, const Block & sample_block, Names & src_names, NamesAndTypesList & dst_columns) const
static std::optional<String> getDictKeyName(const String & dict_name , ContextPtr context)
{
auto dictionary = context->getExternalDictionariesLoader().getDictionary(dict_name, context);
if (!dictionary)
return {};
if (const auto & structure = dictionary->getStructure(); structure.id)
return structure.id->name;
return {};
}
bool TableJoin::tryInitDictJoin(const Block & sample_block, ContextPtr context)
{
/// Support ALL INNER, [ANY | ALL | SEMI | ANTI] LEFT
if (!isLeft(kind()) && !(isInner(kind()) && strictness() == ASTTableJoin::Strictness::All))
@ -333,9 +357,17 @@ bool TableJoin::allowDictJoin(const String & dict_key, const Block & sample_bloc
if (it_key == original_names.end())
return false;
if (dict_key != it_key->second)
if (!right_storage_dictionary)
return false;
auto dict_name = right_storage_dictionary->getName();
auto dict_key = getDictKeyName(dict_name, context);
if (!dict_key.has_value() || *dict_key != it_key->second)
return false; /// JOIN key != Dictionary key
Names src_names;
NamesAndTypesList dst_columns;
for (const auto & col : sample_block)
{
if (col.name == right_keys[0])
@ -349,6 +381,7 @@ bool TableJoin::allowDictJoin(const String & dict_key, const Block & sample_bloc
dst_columns.push_back({col.name, col.type});
}
}
dictionary_reader = std::make_shared<DictionaryReader>(dict_name, src_names, dst_columns, context);
return true;
}
@ -356,7 +389,7 @@ bool TableJoin::allowDictJoin(const String & dict_key, const Block & sample_bloc
bool TableJoin::createConvertingActions(const ColumnsWithTypeAndName & left_sample_columns, const ColumnsWithTypeAndName & right_sample_columns)
{
bool need_convert = false;
need_convert = inferJoinKeyCommonType(left_sample_columns, right_sample_columns, joined_storage == nullptr);
need_convert = inferJoinKeyCommonType(left_sample_columns, right_sample_columns, !isSpecialStorage());
left_converting_actions = applyKeyConvertToTable(left_sample_columns, left_type_map, key_names_left);
right_converting_actions = applyKeyConvertToTable(right_sample_columns, right_type_map, key_names_right);
@ -458,6 +491,26 @@ ActionsDAGPtr TableJoin::applyKeyConvertToTable(
return dag;
}
void TableJoin::setStorageJoin(std::shared_ptr<StorageJoin> storage)
{
if (right_storage_dictionary)
throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "StorageJoin and Dictionary join are mutually exclusive");
right_storage_join = storage;
}
void TableJoin::setStorageJoin(std::shared_ptr<StorageDictionary> storage)
{
if (right_storage_join)
throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "StorageJoin and Dictionary join are mutually exclusive");
right_storage_dictionary = storage;
}
std::shared_ptr<StorageJoin> TableJoin::getStorageJoin()
{
return right_storage_join;
}
String TableJoin::renamedRightColumnName(const String & name) const
{
if (const auto it = renames.find(name); it != renames.end())
@ -527,4 +580,14 @@ std::pair<String, String> TableJoin::joinConditionColumnNames() const
return res;
}
bool TableJoin::isSpecialStorage() const
{
return right_storage_dictionary || right_storage_join;
}
const DictionaryReader * TableJoin::getDictionaryReader() const
{
return dictionary_reader.get();
}
}

View File

@ -24,6 +24,8 @@ class ASTSelectQuery;
struct DatabaseAndTableWithAlias;
class Block;
class DictionaryReader;
class StorageJoin;
class StorageDictionary;
struct ColumnWithTypeAndName;
using ColumnsWithTypeAndName = std::vector<ColumnWithTypeAndName>;
@ -104,6 +106,11 @@ private:
VolumePtr tmp_volume;
std::shared_ptr<StorageJoin> right_storage_join;
std::shared_ptr<StorageDictionary> right_storage_dictionary;
std::shared_ptr<DictionaryReader> dictionary_reader;
Names requiredJoinedNames() const;
/// Create converting actions and change key column names if required
@ -133,16 +140,12 @@ public:
table_join.strictness = strictness;
}
StoragePtr joined_storage;
std::shared_ptr<DictionaryReader> dictionary_reader;
ASTTableJoin::Kind kind() const { return table_join.kind; }
ASTTableJoin::Strictness strictness() const { return table_join.strictness; }
bool sameStrictnessAndKind(ASTTableJoin::Strictness, ASTTableJoin::Kind) const;
const SizeLimits & sizeLimits() const { return size_limits; }
VolumePtr getTemporaryVolume() { return tmp_volume; }
bool allowMergeJoin() const;
bool allowDictJoin(const String & dict_key, const Block & sample_block, Names &, NamesAndTypesList &) const;
bool preferMergeJoin() const { return join_algorithm == JoinAlgorithm::PREFER_PARTIAL_MERGE; }
bool forceMergeJoin() const { return join_algorithm == JoinAlgorithm::PARTIAL_MERGE; }
bool forceHashJoin() const
@ -233,6 +236,16 @@ public:
String renamedRightColumnName(const String & name) const;
std::unordered_map<String, String> leftToRightKeyRemap() const;
void setStorageJoin(std::shared_ptr<StorageJoin> storage);
void setStorageJoin(std::shared_ptr<StorageDictionary> storage);
std::shared_ptr<StorageJoin> getStorageJoin();
bool tryInitDictJoin(const Block & sample_block, ContextPtr context);
bool isSpecialStorage() const;
const DictionaryReader * getDictionaryReader() const;
};
}