Minor changes related to JOIN ON ORs

This commit is contained in:
vdimir 2021-09-10 17:52:44 +03:00 committed by Ilya Golshtein
parent 300eb5098d
commit 71b6c9414c
13 changed files with 59 additions and 96 deletions

View File

@ -13,7 +13,6 @@ namespace DB
using NullMap = ColumnUInt8::Container;
using ConstNullMapPtr = const NullMap *;
using ConstNullMapPtrVector = std::vector<ConstNullMapPtr>;
/// Class that specifies nullable columns. A nullable column represents
/// a column, which may have any type, provided with the possibility of

View File

@ -260,21 +260,32 @@ HashJoin::HashJoin(std::shared_ptr<TableJoin> table_join_, const Block & right_s
, log(&Poco::Logger::get("HashJoin"))
{
LOG_DEBUG(log, "Right sample block: {}", right_sample_block.dumpStructure());
if (!table_join->oneDisjunct())
if (isComma(kind) || isCross(kind))
{
/// required right keys concept does not work well if multiple disjuncts,
/// we need all keys
sample_block_with_columns_to_add = right_table_keys = materializeBlock(right_sample_block);
data->type = Type::CROSS;
sample_block_with_columns_to_add = right_sample_block;
}
else
else if (table_join->oneDisjunct())
{
const auto & key_names_right = table_join->getOnlyClause().key_names_right;
JoinCommon::splitAdditionalColumns(key_names_right, right_sample_block, right_table_keys, sample_block_with_columns_to_add);
required_right_keys = table_join->getRequiredRightKeys(right_table_keys, required_right_keys_sources);
}
else
{
/// required right keys concept does not work well if multiple disjuncts, we need all keys
sample_block_with_columns_to_add = right_table_keys = materializeBlock(right_sample_block);
}
// LOG_DEBUG(log, "Join keys: [{}], required right: [{}]", formatKeysDebug(table_join->getClauses()), fmt::join(required_right_keys.getNames(), ", "));
LOG_DEBUG(log, "Columns to add: [{}]", sample_block_with_columns_to_add.dumpStructure());
LOG_TRACE(log, "Columns to add: [{}], required right [{}]",
sample_block_with_columns_to_add.dumpStructure(), fmt::join(required_right_keys.getNames(), ", "));
{
std::vector<String> log_text;
for (const auto & clause : table_join->getClauses())
log_text.push_back(clause.formatDebug());
LOG_TRACE(log, "Joining on: {}", fmt::join(log_text, " | "));
}
JoinCommon::removeLowCardinalityInplace(right_table_keys);

View File

@ -45,7 +45,8 @@ public:
/// Different query plan is used for such joins.
virtual bool isFilled() const { return false; }
virtual std::shared_ptr<NotJoinedBlocks> getNonJoinedBlocks(const Block &, const Block &, UInt64) const = 0;
virtual std::shared_ptr<NotJoinedBlocks>
getNonJoinedBlocks(const Block & left_sample_block, const Block & result_sample_block, UInt64 max_block_size) const = 0;
};
using JoinPtr = std::shared_ptr<IJoin>;

View File

@ -61,10 +61,10 @@ public:
return join->alwaysReturnsEmptySet();
}
std::shared_ptr<NotJoinedBlocks> getNonJoinedBlocks(
const Block & left_sample_block, const Block & result_block, UInt64 max_block_size) const override
std::shared_ptr<NotJoinedBlocks>
getNonJoinedBlocks(const Block & left_sample_block, const Block & result_sample_block, UInt64 max_block_size) const override
{
return join->getNonJoinedBlocks(left_sample_block, result_block, max_block_size);
return join->getNonJoinedBlocks(left_sample_block, result_sample_block, max_block_size);
}
private:

View File

@ -1,5 +1,4 @@
#include <Common/assert_cast.h>
#include <Columns/IColumn.h>
#include <Interpreters/NullableUtils.h>

View File

@ -101,7 +101,6 @@ TableJoin::TableJoin(const Settings & settings, VolumePtr tmp_volume_)
, partial_merge_join_left_table_buffer_bytes(settings.partial_merge_join_left_table_buffer_bytes)
, max_files_to_merge(settings.join_on_disk_max_files_to_merge)
, temporary_files_codec(settings.temporary_files_codec)
, clauses(1)
, tmp_volume(tmp_volume_)
{
}
@ -234,10 +233,8 @@ ASTPtr TableJoin::rightKeysList() const
Names TableJoin::requiredJoinedNames() const
{
NameSet required_columns_set;
for (const auto & clause : clauses)
required_columns_set.insert(clause.key_names_right.begin(), clause.key_names_right.end());
Names key_names_right = getAllNames(JoinTableSide::Right);
NameSet required_columns_set(key_names_right.begin(), key_names_right.end());
for (const auto & joined_column : columns_added_by_join)
required_columns_set.insert(joined_column.name);
@ -258,16 +255,13 @@ NameSet TableJoin::requiredRightKeys() const
return required;
}
/// Collects the columns that have to be requested for the join:
/// everything listed in `action_required_columns`, plus every name returned by
/// requiredJoinedNames() that is not already present in `sample`.
/// The resulting set is converted with getNamesWithAliases().
NamesWithAliases TableJoin::getRequiredColumns(const Block & sample, const Names & action_required_columns) const
{
    NameSet names_to_request(action_required_columns.begin(), action_required_columns.end());

    const Names joined_names = requiredJoinedNames();
    for (const auto & joined_name : joined_names)
    {
        /// Columns already available in the sample block need no extra request.
        if (sample.has(joined_name))
            continue;
        names_to_request.insert(joined_name);
    }

    return getNamesWithAliases(names_to_request);
}
@ -372,8 +366,9 @@ bool TableJoin::sameStrictnessAndKind(ASTTableJoin::Strictness strictness_, ASTT
bool TableJoin::oneDisjunct() const
{
if (!isComma(kind()) && !isCross(kind()))
assert(!clauses.empty());
return clauses.size() == 1;
return clauses.size() <= 1;
}
bool TableJoin::allowMergeJoin() const
@ -460,8 +455,7 @@ bool TableJoin::tryInitDictJoin(const Block & sample_block, ContextPtr context)
return true;
}
static void tryRename(String & name, const NameToNameMap & renames)
static void renameIfNeeded(String & name, const NameToNameMap & renames)
{
if (const auto it = renames.find(name); it != renames.end())
name = it->second;
@ -479,8 +473,8 @@ TableJoin::createConvertingActions(const ColumnsWithTypeAndName & left_sample_co
forAllKeys(clauses, [&](auto & left_key, auto & right_key)
{
tryRename(left_key, left_key_column_rename);
tryRename(right_key, right_key_column_rename);
renameIfNeeded(left_key, left_key_column_rename);
renameIfNeeded(right_key, right_key_column_rename);
return true;
});
@ -510,11 +504,11 @@ bool TableJoin::inferJoinKeyCommonType(const LeftNamesAndTypes & left, const Rig
/// Name mismatch, give up
left_type_map.clear();
right_type_map.clear();
return false; /// break;
return false;
}
if (JoinCommon::typesEqualUpToNullability(ltype->second, rtype->second))
return true; /// continue;
return true;
DataTypePtr common_type;
try
@ -600,48 +594,34 @@ String TableJoin::renamedRightColumnName(const String & name) const
return name;
}
/// Registers one pair of join keys.
/// The key names are appended to the last element of `clauses`; the key ASTs are
/// appended to `key_asts_left` / `key_asts_right`.
/// When `right_ast` is null, `left_ast` is stored for the right side as well.
/// Uses `clauses.back()`, so `clauses` must not be empty when this is called.
void TableJoin::addKey(const String & left_name, const String & right_name, const ASTPtr & left_ast, const ASTPtr & right_ast)
{
    auto & current_clause = clauses.back();

    current_clause.key_names_left.emplace_back(left_name);
    key_asts_left.emplace_back(left_ast);

    current_clause.key_names_right.emplace_back(right_name);
    if (right_ast)
        key_asts_right.emplace_back(right_ast);
    else
        key_asts_right.emplace_back(left_ast);
}
/// Merges `new_cond` into `current_cond` so that the result is the conjunction of both.
///  - no condition yet: `current_cond` becomes `new_cond`;
///  - `current_cond` is already an `and` function: `new_cond` is appended to its arguments;
///  - otherwise: `current_cond` is replaced by `and(current_cond, new_cond)`.
static void addJoinConditionWithAnd(ASTPtr & current_cond, const ASTPtr & new_cond)
{
    /// Nothing stored so far, the new condition is taken as is.
    if (!current_cond)
    {
        current_cond = new_cond;
        return;
    }

    /// An existing `and` is extended in place instead of being nested.
    const auto * and_function = current_cond->as<ASTFunction>();
    if (and_function && and_function->name == "and")
    {
        and_function->arguments->children.push_back(new_cond);
        return;
    }

    /// Some other condition is present, combine both under a new `and`.
    current_cond = makeASTFunction("and", current_cond, new_cond);
}
void TableJoin::addJoinCondition(const ASTPtr & ast, bool is_left)
{
addJoinConditionWithAnd(is_left ? clauses.back().on_filter_condition_left : clauses.back().on_filter_condition_right, ast);
}
void TableJoin::leftToRightKeyRemap(
const Names & left_keys,
const Names & right_keys,
const NameSet & required_right_keys,
std::unordered_map<String, String> & key_map) const
{
if (hasUsing())
{
for (size_t i = 0; i < left_keys.size(); ++i)
{
const String & left_key_name = left_keys[i];
const String & right_key_name = right_keys[i];
if (!required_right_keys.contains(right_key_name))
key_map[left_key_name] = right_key_name;
}
}
auto & cond_ast = is_left ? clauses.back().on_filter_condition_left : clauses.back().on_filter_condition_right;
LOG_TRACE(&Poco::Logger::get("TableJoin"), "Adding join condition for {} table: {} -> {}",
(is_left ? "left" : "right"), ast ? queryToString(ast) : "NULL", cond_ast ? queryToString(cond_ast) : "NULL");
addJoinConditionWithAnd(cond_ast, ast);
}
std::unordered_map<String, String> TableJoin::leftToRightKeyRemap() const
@ -663,11 +643,11 @@ std::unordered_map<String, String> TableJoin::leftToRightKeyRemap() const
Names TableJoin::getAllNames(JoinTableSide side) const
{
Names res;
forAllKeys(clauses, [&res, side](const auto & left, const auto & right)
{
res.emplace_back(side == JoinTableSide::Left ? left : right);
return true;
});
auto func = [&res](const auto & name) { res.emplace_back(name); return true; };
if (side == JoinTableSide::Left)
forAllKeys<LeftSideTag>(clauses, func);
else
forAllKeys<RightSideTag>(clauses, func);
return res;
}

View File

@ -12,10 +12,7 @@
#include <DataTypes/getLeastSupertype.h>
#include <Storages/IStorage_fwd.h>
#include <Common/Exception.h>
#include <Parsers/IAST_fwd.h>
#include <cstddef>
#include <unordered_map>
#include <utility>
#include <memory>
#include <common/types.h>
@ -51,7 +48,6 @@ class TableJoin
public:
using NameToTypeMap = std::unordered_map<String, DataTypePtr>;
using Disjuncts = ASTs;
/// Corresponds to one disjunct
struct JoinOnClause
@ -114,13 +110,11 @@ private:
const size_t max_files_to_merge = 0;
const String temporary_files_codec = "LZ4";
std::vector<JoinOnClause> clauses;
ASTs key_asts_left;
ASTs key_asts_right;
Disjuncts disjuncts;
std::vector<JoinOnClause> clauses;
ASTTableJoin table_join;
ASOF::Inequality asof_inequality = ASOF::Inequality::GreaterOrEquals;
@ -154,26 +148,7 @@ private:
ActionsDAGPtr applyKeyConvertToTable(
const ColumnsWithTypeAndName & cols_src, const NameToTypeMap & type_mapping, NameToNameMap & key_column_rename) const;
/// Calculates common supertypes for corresponding join key columns.
template <typename LeftNamesAndTypes, typename RightNamesAndTypes>
bool inferJoinKeyCommonType(const LeftNamesAndTypes & left, const RightNamesAndTypes & right, bool allow_right);
NamesAndTypesList correctedColumnsAddedByJoin() const;
void leftToRightKeyRemap(
const Names & left_keys,
const Names & right_keys,
const NameSet & required_right_keys,
std::unordered_map<String, String> & key_map) const;
void addKey(const String & left_name, const String & right_name,
const ASTPtr & left_ast, const ASTPtr & right_ast = nullptr)
{
clauses.back().key_names_left.emplace_back(left_name);
key_asts_left.emplace_back(left_ast);
clauses.back().key_names_right.emplace_back(right_name);
key_asts_right.emplace_back(right_ast ? right_ast : left_ast);
}
void addKey(const String & left_name, const String & right_name, const ASTPtr & left_ast, const ASTPtr & right_ast = nullptr);
void assertHasOneOnExpr() const;
@ -190,9 +165,8 @@ public:
, default_max_bytes(0)
, join_use_nulls(use_nulls)
, join_algorithm(JoinAlgorithm::HASH)
, clauses(1)
{
getOnlyClause().key_names_right = key_names_right;
clauses.emplace_back().key_names_right = key_names_right;
table_join.kind = kind;
table_join.strictness = strictness;
}

View File

@ -35,7 +35,6 @@
#include <Storages/IStorage.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <Poco/Logger.h>
namespace DB
{

View File

@ -1,4 +1,3 @@
#include <unordered_map>
#include <Interpreters/join_common.h>
#include <Columns/ColumnLowCardinality.h>

View File

@ -16,7 +16,6 @@ class IColumn;
using ColumnRawPtrs = std::vector<const IColumn *>;
using ColumnRawPtrMap = std::unordered_map<String, const IColumn *>;
using UInt8ColumnDataPtr = const ColumnUInt8::Container *;
using UInt8ColumnDataPtrVector = std::vector<UInt8ColumnDataPtr>;
namespace JoinCommon
{

View File

@ -9,6 +9,5 @@ namespace DB
class IAST;
using ASTPtr = std::shared_ptr<IAST>;
using ASTs = std::vector<ASTPtr>;
using ASTsVector = std::vector<ASTs>;
}

View File

@ -377,6 +377,9 @@ public:
, max_block_size(max_block_size_)
, sample_block(std::move(sample_block_))
{
if (!join->getTableJoin().oneDisjunct())
throw DB::Exception(ErrorCodes::NOT_IMPLEMENTED, "StorageJoin does not support OR for keys in JOIN ON section");
column_indices.resize(sample_block.columns());
auto & saved_block = join->getJoinedData()->sample_block;