Store all related to one join disjunct in JoinOnClause, pt1

This commit is contained in:
vdimir 2021-09-02 14:40:04 +03:00 committed by Ilya Golshtein
parent 8339cfc8e1
commit 8e2637aab2
8 changed files with 130 additions and 109 deletions

View File

@ -41,11 +41,6 @@ void CollectJoinOnKeysMatcher::Data::setDisjuncts(const ASTPtr & or_func_ast)
analyzed_join.setDisjuncts(std::move(v));
}
void CollectJoinOnKeysMatcher::Data::addDisjunct(const ASTPtr & ast)
{
analyzed_join.addDisjunct(std::move(ast));
}
void CollectJoinOnKeysMatcher::Data::addJoinKeys(const ASTPtr & left_ast, const ASTPtr & right_ast, JoinIdentifierPosPair table_pos)
{
ASTPtr left = left_ast->clone();
@ -107,7 +102,7 @@ void CollectJoinOnKeysMatcher::visit(const ASTFunction & func, const ASTPtr & as
return;
}
data.addDisjunct(ast);
data.analyzed_join.addDisjunct(ast);
if (func.name == "and")
return; /// go into children

View File

@ -52,7 +52,6 @@ public:
void addAsofJoinKeys(const ASTPtr & left_ast, const ASTPtr & right_ast, JoinIdentifierPosPair table_pos,
const ASOF::Inequality & asof_inequality);
void setDisjuncts(const ASTPtr & or_func_ast);
void addDisjunct(const ASTPtr & ast);
void asofToJoinKeys();
};

View File

@ -1,4 +1,3 @@
#include <common/logger_useful.h>
#include <Interpreters/LogicalExpressionsOptimizer.h>
#include <Core/Settings.h>

View File

@ -507,7 +507,9 @@ MergeJoin::MergeJoin(std::shared_ptr<TableJoin> table_join_, const Block & right
ErrorCodes::PARAMETER_OUT_OF_BOUND);
}
if (table_join->keyNamesLeft().size() > 1)
const auto & key_names_left_all = table_join->keyNamesLeft();
const auto & key_names_right_all = table_join->keyNamesRight();
if (key_names_left_all.size() != 1 || key_names_right_all.size() != 1)
throw Exception("MergeJoin does not support OR", ErrorCodes::NOT_IMPLEMENTED);
std::tie(mask_column_name_left, mask_column_name_right) = table_join->joinConditionColumnNames(0);
@ -522,8 +524,8 @@ MergeJoin::MergeJoin(std::shared_ptr<TableJoin> table_join_, const Block & right
key_names_right.push_back(deriveTempName(mask_column_name_right));
}
key_names_left.insert(key_names_left.end(), table_join->keyNamesLeft().front().begin(), table_join->keyNamesLeft().front().end());
key_names_right.insert(key_names_right.end(), table_join->keyNamesRight().front().begin(), table_join->keyNamesRight().front().end());
key_names_left.insert(key_names_left.end(), key_names_left_all.front().begin(), key_names_left_all.front().end());
key_names_right.insert(key_names_right.end(), key_names_right_all.front().begin(), key_names_right_all.front().end());
addConditionJoinColumn(right_sample_block, JoinTableSide::Right);
JoinCommon::splitAdditionalColumns(NamesVector{key_names_right}, right_sample_block, right_table_keys, right_columns_to_add);

View File

@ -62,22 +62,17 @@ TableJoin::TableJoin(const Settings & settings, VolumePtr tmp_volume_)
, partial_merge_join_left_table_buffer_bytes(settings.partial_merge_join_left_table_buffer_bytes)
, max_files_to_merge(settings.join_on_disk_max_files_to_merge)
, temporary_files_codec(settings.temporary_files_codec)
, key_names_left(1)
, key_names_right(1)
, on_filter_condition_asts_left(1)
, on_filter_condition_asts_right(1)
, left_clauses(1)
, right_clauses(1)
, tmp_volume(tmp_volume_)
{
}
void TableJoin::resetCollected()
{
key_names_left.clear();
key_names_right.clear();
key_asts_left.clear();
key_asts_right.clear();
on_filter_condition_asts_left.clear();
on_filter_condition_asts_right.clear();
left_clauses = std::vector<JoinOnClause>(1);
right_clauses = std::vector<JoinOnClause>(1);
columns_from_joined_table.clear();
columns_added_by_join.clear();
original_names.clear();
@ -92,33 +87,21 @@ void TableJoin::resetCollected()
void TableJoin::addUsingKey(const ASTPtr & ast)
{
key_names_left.front().push_back(ast->getColumnName());
key_names_right.front().push_back(ast->getAliasOrColumnName());
key_asts_left.push_back(ast);
key_asts_right.push_back(ast);
auto & right_key = key_names_right.front().back();
if (renames.count(right_key))
right_key = renames[right_key];
left_clauses.back().addKey(ast->getColumnName(), ast);
right_clauses.back().addKey(renamedRightColumnName(ast->getAliasOrColumnName()), ast);
}
/// create new disjunct when see a child of a previously discovered OR
/// create new disjunct when see a direct child of a previously discovered OR
void TableJoin::addDisjunct(const ASTPtr & ast)
{
const IAST * addr = ast.get();
if (std::find_if(disjuncts.begin(), disjuncts.end(), [addr](const ASTPtr & ast_){return ast_.get() == addr;}) != disjuncts.end())
{
assert(key_names_left.size() == disjunct_num + 1);
if (!key_names_left[disjunct_num].empty() || !on_filter_condition_asts_left[disjunct_num].empty() || !on_filter_condition_asts_right[disjunct_num].empty())
if (!left_clauses.back().key_names.empty() || !left_clauses.back().on_filter_conditions.empty() || !right_clauses.back().on_filter_conditions.empty())
{
disjunct_num++;
key_names_left.resize(disjunct_num + 1);
key_names_right.resize(disjunct_num + 1);
on_filter_condition_asts_left.resize(disjunct_num + 1);
on_filter_condition_asts_right.resize(disjunct_num + 1);
left_clauses.emplace_back();
right_clauses.emplace_back();
}
}
}
@ -131,11 +114,8 @@ void TableJoin::setDisjuncts(Disjuncts&& disjuncts_)
void TableJoin::addOnKeys(ASTPtr & left_table_ast, ASTPtr & right_table_ast)
{
key_names_left[disjunct_num].push_back(left_table_ast->getColumnName());
key_names_right[disjunct_num].push_back(right_table_ast->getAliasOrColumnName());
key_asts_left.push_back(left_table_ast);
key_asts_right.push_back(right_table_ast);
left_clauses.back().addKey(left_table_ast->getColumnName(), left_table_ast);
right_clauses.back().addKey(right_table_ast->getAliasOrColumnName(), right_table_ast);
}
/// @return how many times right key appears in ON section.
@ -145,9 +125,8 @@ size_t TableJoin::rightKeyInclusion(const String & name) const
return 0;
size_t count = 0;
for (const auto & key_names : key_names_right)
count += std::count(key_names.begin(), key_names.end(), name);
for (const auto & clause : right_clauses)
count += std::count(clause.key_names.begin(), clause.key_names.end(), name);
return count;
}
@ -194,31 +173,39 @@ NamesWithAliases TableJoin::getNamesWithAliases(const NameSet & required_columns
ASTPtr TableJoin::leftKeysList() const
{
ASTPtr keys_list = std::make_shared<ASTExpressionList>();
keys_list->children = key_asts_left;
const size_t disjuncts_num = key_names_left.size();
for (size_t d = 0; d < disjuncts_num; ++d)
if (ASTPtr extra_cond = joinConditionColumn(JoinTableSide::Left, d))
for (size_t i = 0; i < left_clauses.size(); ++i)
{
const auto & clause = left_clauses[i];
keys_list->children.insert(keys_list->children.end(), clause.key_asts.begin(), clause.key_asts.end());
if (ASTPtr extra_cond = joinConditionColumn(JoinTableSide::Left, i))
keys_list->children.push_back(extra_cond);
}
return keys_list;
}
ASTPtr TableJoin::rightKeysList() const
{
ASTPtr keys_list = std::make_shared<ASTExpressionList>();
for (size_t i = 0; i < right_clauses.size(); ++i)
{
if (hasOn())
keys_list->children = key_asts_right;
const size_t disjuncts_num = key_names_left.size();
for (size_t d = 0; d < disjuncts_num; ++d)
if (ASTPtr extra_cond = joinConditionColumn(JoinTableSide::Right, d))
{
const auto & clause = right_clauses[i];
keys_list->children.insert(keys_list->children.end(), clause.key_asts.begin(), clause.key_asts.end());
}
if (ASTPtr extra_cond = joinConditionColumn(JoinTableSide::Right, i))
keys_list->children.push_back(extra_cond);
}
return keys_list;
}
Names TableJoin::requiredJoinedNames() const
{
NameSet required_columns_set;
for (const auto& key_names_right_part : key_names_right)
required_columns_set.insert(key_names_right_part.begin(), key_names_right_part.end());
for (const auto & clause : right_clauses)
required_columns_set.insert(clause.key_names.begin(), clause.key_names.end());
for (const auto & joined_column : columns_added_by_join)
required_columns_set.insert(joined_column.name);
@ -228,9 +215,9 @@ Names TableJoin::requiredJoinedNames() const
NameSet TableJoin::requiredRightKeys() const
{
NameSet required;
for (const auto & key_names_right_part : key_names_right)
for (const auto & clause : right_clauses)
{
for (const auto & name : key_names_right_part)
for (const auto & name : clause.key_names)
{
auto rename = renamedRightColumnName(name);
for (const auto & column : columns_added_by_join)
@ -369,7 +356,7 @@ bool TableJoin::allowMergeJoin() const
bool all_join = is_all && (isInner(kind()) || isLeft(kind()) || isRight(kind()) || isFull(kind()));
bool special_left = isLeft(kind()) && (is_any || is_semi);
bool no_ors = (key_names_right.size() == 1);
bool no_ors = (left_clauses.size() == 1);
return (all_join || special_left) && no_ors;
}
@ -407,7 +394,7 @@ bool TableJoin::tryInitDictJoin(const Block & sample_block, ContextPtr context)
if (!allowed_inner && !allowed_left)
return false;
const Names & right_keys = keyNamesRight().front();
const Names & right_keys = right_clauses.front().key_names;
if (right_keys.size() != 1)
return false;
@ -470,12 +457,14 @@ bool TableJoin::inferJoinKeyCommonType(const LeftNamesAndTypes & left, const Rig
for (const auto & col : right)
right_types[renamedRightColumnName(col.name)] = col.type;
for (size_t d = 0; d < key_names_left.size(); ++d)
for (size_t d = 0; d < left_clauses.size(); ++d)
{
for (size_t i = 0; i < key_names_left[d].size(); ++i)
auto & key_names_left = left_clauses[d].key_names;
auto & key_names_right = right_clauses[d].key_names;
for (size_t i = 0; i < key_names_left.size(); ++i)
{
auto ltype = left_types.find(key_names_left[d][i]);
auto rtype = right_types.find(key_names_right[d][i]);
auto ltype = left_types.find(key_names_left[i]);
auto rtype = right_types.find(key_names_right[i]);
if (ltype == left_types.end() || rtype == right_types.end())
{
/// Name mismatch, give up
@ -495,13 +484,14 @@ bool TableJoin::inferJoinKeyCommonType(const LeftNamesAndTypes & left, const Rig
}
catch (DB::Exception & ex)
{
throw DB::Exception(ErrorCodes::TYPE_MISMATCH,
"Can't infer common type for joined columns: {}: {} at left, {}: {} at right. {}",
key_names_left[d][i], ltype->second->getName(),
key_names_right[d][i], rtype->second->getName(),
ex.message());
throw Exception(
"Type mismatch of columns to JOIN by: " +
key_names_left[d][i] + ": " + ltype->second->getName() + " at left, " +
key_names_right[d][i] + ": " + rtype->second->getName() + " at right. " +
"Can't get supertype: " + ex.message(),
ErrorCodes::TYPE_MISMATCH);
}
left_type_map[key_names_left[d][i]] = right_type_map[key_names_right[d][i]] = supertype;
left_type_map[key_names_left[i]] = right_type_map[key_names_right[i]] = supertype;
}
}
@ -518,7 +508,7 @@ bool TableJoin::inferJoinKeyCommonType(const LeftNamesAndTypes & left, const Rig
}
ActionsDAGPtr TableJoin::applyKeyConvertToTable(
const ColumnsWithTypeAndName & cols_src, const NameToTypeMap & type_mapping, NamesVector & names_vector_to_rename) const
const ColumnsWithTypeAndName & cols_src, const NameToTypeMap & type_mapping, std::vector<JoinOnClause> & join_clause) const
{
bool has_some_to_do = false;
@ -540,9 +530,9 @@ ActionsDAGPtr TableJoin::applyKeyConvertToTable(
auto dag = ActionsDAG::makeConvertingActions(
cols_src, cols_dst, ActionsDAG::MatchColumnsMode::Name, true, !hasUsing(), &key_column_rename);
for (auto & disjunct_names : names_vector_to_rename)
for (auto & clause : join_clause)
{
for (auto & name : disjunct_names)
for (auto & name : clause.key_names)
{
const auto it = key_column_rename.find(name);
if (it != key_column_rename.end())
@ -577,9 +567,9 @@ String TableJoin::renamedRightColumnName(const String & name) const
void TableJoin::addJoinCondition(const ASTPtr & ast, bool is_left)
{
if (is_left)
on_filter_condition_asts_left[disjunct_num].push_back(ast);
left_clauses.back().on_filter_conditions.push_back(ast);
else
on_filter_condition_asts_right[disjunct_num].push_back(ast);
right_clauses.back().on_filter_conditions.push_back(ast);
}
void TableJoin::leftToRightKeyRemap(
@ -607,33 +597,34 @@ std::unordered_map<String, String> TableJoin::leftToRightKeyRemap() const
if (hasUsing())
{
const auto & required_right_keys = requiredRightKeys();
for (size_t i = 0; i < key_names_left.size(); ++i)
TableJoin::leftToRightKeyRemap(key_names_left[i], key_names_right[i], required_right_keys, left_to_right_key_remap);
for (size_t i = 0; i < left_clauses.size(); ++i)
TableJoin::leftToRightKeyRemap(left_clauses[i].key_names, right_clauses[i].key_names, required_right_keys, left_to_right_key_remap);
}
return left_to_right_key_remap;
}
/// Returns all conditions related to one table joined with 'and' function
static ASTPtr buildJoinConditionColumn(const ASTsVector & on_filter_condition_asts, size_t disjunct)
static ASTPtr buildJoinConditionColumn(const ASTs & on_filter_condition_asts)
{
if (on_filter_condition_asts[disjunct].empty())
if (on_filter_condition_asts.empty())
return nullptr;
if (on_filter_condition_asts[disjunct].size() == 1)
return on_filter_condition_asts[disjunct][0];
if (on_filter_condition_asts.size() == 1)
return on_filter_condition_asts[0];
auto function = std::make_shared<ASTFunction>();
function->name = "and";
function->arguments = std::make_shared<ASTExpressionList>();
function->children.push_back(function->arguments);
function->arguments->children = on_filter_condition_asts[disjunct];
function->arguments->children = on_filter_condition_asts;
return function;
}
ASTPtr TableJoin::joinConditionColumn(JoinTableSide side, size_t disjunct) const
{
if (side == JoinTableSide::Left)
return buildJoinConditionColumn(on_filter_condition_asts_left, disjunct);
return buildJoinConditionColumn(on_filter_condition_asts_right, disjunct);
return buildJoinConditionColumn(left_clauses[disjunct].on_filter_conditions);
return buildJoinConditionColumn(right_clauses[disjunct].on_filter_conditions);
}
std::pair<String, String> TableJoin::joinConditionColumnNames(size_t disjunct) const

View File

@ -74,16 +74,31 @@ private:
const size_t max_files_to_merge = 0;
const String temporary_files_codec = "LZ4";
NamesVector key_names_left;
NamesVector key_names_right; /// Duplicating names are qualified.
ASTsVector on_filter_condition_asts_left;
ASTsVector on_filter_condition_asts_right;
private:
size_t disjunct_num = 0;
/// Corresponds to one disjunct
struct JoinOnClause
{
Names key_names;
ASTs key_asts;
ASTs on_filter_conditions;
JoinOnClause() = default;
explicit JoinOnClause(const Names & names)
: key_names(names)
{}
void addKey(const String & name, const ASTPtr & ast)
{
key_names.emplace_back(name);
key_asts.emplace_back(ast);
}
};
Disjuncts disjuncts;
ASTs key_asts_left;
ASTs key_asts_right;
std::vector<JoinOnClause> left_clauses;
std::vector<JoinOnClause> right_clauses; /// Duplicating key_names are qualified.
ASTTableJoin table_join;
@ -116,7 +131,7 @@ private:
/// Create converting actions and change key column names if required
ActionsDAGPtr applyKeyConvertToTable(
const ColumnsWithTypeAndName & cols_src, const NameToTypeMap & type_mapping, NamesVector & names_vector_to_rename) const;
const ColumnsWithTypeAndName & cols_src, const NameToTypeMap & type_mapping, std::vector<JoinOnClause> & join_clause) const;
/// Calculates common supertypes for corresponding join key columns.
template <typename LeftNamesAndTypes, typename RightNamesAndTypes>
@ -131,10 +146,8 @@ private:
public:
TableJoin()
: key_names_left(1)
, key_names_right(1)
, on_filter_condition_asts_left(1)
, on_filter_condition_asts_right(1)
: left_clauses(1)
, right_clauses(1)
{
}
@ -142,16 +155,14 @@ public:
/// for StorageJoin
TableJoin(SizeLimits limits, bool use_nulls, ASTTableJoin::Kind kind, ASTTableJoin::Strictness strictness,
const NamesVector & key_names_right_)
const Names & key_names_right)
: size_limits(limits)
, default_max_bytes(0)
, join_use_nulls(use_nulls)
, join_algorithm(JoinAlgorithm::HASH)
, key_names_left(1)
, key_names_right(key_names_right_)
, on_filter_condition_asts_left(1)
, on_filter_condition_asts_right(1)
, left_clauses(1)
{
right_clauses.emplace_back(key_names_right);
table_join.kind = kind;
table_join.strictness = strictness;
}
@ -232,8 +243,26 @@ public:
ASTPtr leftKeysList() const;
ASTPtr rightKeysList() const; /// For ON syntax only
const NamesVector & keyNamesLeft() const { return key_names_left; }
const NamesVector & keyNamesRight() const { return key_names_right; }
NamesVector keyNamesLeft() const
{
NamesVector key_names;
for (const auto & clause : left_clauses)
{
key_names.push_back(clause.key_names);
}
return key_names;
}
NamesVector keyNamesRight() const
{
NamesVector key_names;
for (const auto & clause : right_clauses)
{
key_names.push_back(clause.key_names);
}
return key_names;
}
const NamesAndTypesList & columnsFromJoinedTable() const { return columns_from_joined_table; }
Names columnsAddedByJoin() const
@ -245,7 +274,12 @@ public:
}
/// StorageJoin overrides key names (cause of different names qualification)
void setRightKeys(const Names & keys) { key_names_right.clear(); key_names_right.push_back(keys); }
void setRightKeys(const Names & keys)
{
// assert(right_clauses.size() <= 1);
right_clauses.clear();
right_clauses.emplace_back(keys);
}
Block getRequiredRightKeys(const Block & right_table_keys, std::vector<String> & keys_sources) const;

View File

@ -202,7 +202,8 @@ size_t tryPushDownFilter(QueryPlan::Node * parent_node, QueryPlan::Nodes & nodes
const auto & left_header = join->getInputStreams().front().header;
const auto & res_header = join->getOutputStream().header;
Names allowed_keys;
for (const auto & name : table_join.keyNamesLeft().front())
const auto & key_names_left = table_join.keyNamesLeft();
for (const auto & name : key_names_left.front())
{
/// Skip key if it is renamed.
/// I don't know if it is possible. Just in case.

View File

@ -62,7 +62,7 @@ StorageJoin::StorageJoin(
if (!metadata_snapshot->getColumns().hasPhysical(key))
throw Exception{"Key column (" + key + ") does not exist in table declaration.", ErrorCodes::NO_SUCH_COLUMN_IN_TABLE};
table_join = std::make_shared<TableJoin>(limits, use_nulls, kind, strictness, NamesVector{key_names});
table_join = std::make_shared<TableJoin>(limits, use_nulls, kind, strictness, key_names);
join = std::make_shared<HashJoin>(table_join, metadata_snapshot->getSampleBlock().sortColumns(), overwrite);
restore();
}