This commit is contained in:
yariks5s 2023-12-18 15:02:51 +00:00
parent 85be7cf6b8
commit 6740316a88
12 changed files with 525 additions and 9 deletions

View File

@ -13,6 +13,7 @@ const char * toString(JoinKind kind)
case JoinKind::Full: return "FULL";
case JoinKind::Cross: return "CROSS";
case JoinKind::Comma: return "COMMA";
case JoinKind::Paste: return "PASTE";
}
};

View File

@ -13,7 +13,8 @@ enum class JoinKind
Right,
Full,
Cross, /// Direct product. Strictness and condition doesn't matter.
Comma /// Same as direct product. Intended to be converted to INNER JOIN with conditions from WHERE.
Comma, /// Same as direct product. Intended to be converted to INNER JOIN with conditions from WHERE.
Paste, /// Used to join parts without `ON` clause.
};
const char * toString(JoinKind kind);
@ -27,6 +28,7 @@ inline constexpr bool isRightOrFull(JoinKind kind) { return kind == JoinKind::R
inline constexpr bool isLeftOrFull(JoinKind kind) { return kind == JoinKind::Left || kind == JoinKind::Full; }
inline constexpr bool isInnerOrRight(JoinKind kind) { return kind == JoinKind::Inner || kind == JoinKind::Right; }
inline constexpr bool isInnerOrLeft(JoinKind kind) { return kind == JoinKind::Inner || kind == JoinKind::Left; }
inline constexpr bool isPaste(JoinKind kind) { return kind == JoinKind::Paste; }
/// Allows more optimal JOIN for typical cases.
enum class JoinStrictness

View File

@ -56,6 +56,7 @@
#include <Core/Names.h>
#include <Core/NamesAndTypes.h>
#include <Common/logger_useful.h>
#include <Interpreters/PasteJoin.h>
#include <QueryPipeline/SizeLimits.h>
@ -952,6 +953,9 @@ static std::shared_ptr<IJoin> tryCreateJoin(
std::unique_ptr<QueryPlan> & joined_plan,
ContextPtr context)
{
if (analyzed_join->kind() == JoinKind::Paste)
return std::make_shared<PasteJoin>(analyzed_join, right_sample_block);
if (algorithm == JoinAlgorithm::DIRECT || algorithm == JoinAlgorithm::DEFAULT)
{
JoinPtr direct_join = tryKeyValueJoin(analyzed_join, right_sample_block);

View File

@ -1699,7 +1699,7 @@ void InterpreterSelectQuery::executeImpl(QueryPlan & query_plan, std::optional<P
return step_raw_ptr;
};
if (expressions.join->pipelineType() == JoinPipelineType::YShaped)
if (expressions.join->pipelineType() == JoinPipelineType::YShaped && expressions.join->getTableJoin().kind() != JoinKind::Paste)
{
const auto & table_join = expressions.join->getTableJoin();
const auto & join_clause = table_join.getOnlyClause();

View File

@ -0,0 +1,92 @@
#pragma once
#include <Interpreters/IJoin.h>
#include <Interpreters/TableJoin.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypeLowCardinality.h>
#include <Common/logger_useful.h>
#include <Poco/Logger.h>
namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
extern const int NOT_IMPLEMENTED;
}
/// Dummy class, actual joining is done by MergeTransform
class PasteJoin : public IJoin
{
public:
explicit PasteJoin(std::shared_ptr<TableJoin> table_join_, const Block & right_sample_block_)
: table_join(table_join_)
, right_sample_block(right_sample_block_)
{
LOG_TRACE(&Poco::Logger::get("PasteJoin"), "Will use paste join");
}
std::string getName() const override { return "PasteJoin"; }
const TableJoin & getTableJoin() const override { return *table_join; }
bool addBlockToJoin(const Block & /* block */, bool /* check_limits */) override
{
throw Exception(ErrorCodes::LOGICAL_ERROR, "PasteJoin::addBlockToJoin should not be called");
}
static bool isSupported(const std::shared_ptr<TableJoin> & table_join)
{
bool support_storage = !table_join->isSpecialStorage();
/// Key column can change nullability and it's not handled on type conversion stage, so algorithm should be aware of it
bool support_using_and_nulls = !table_join->hasUsing() || !table_join->joinUseNulls();
return support_using_and_nulls && support_storage;
}
void checkTypesOfKeys(const Block & /*left_block*/) const override
{
if (!isSupported(table_join))
throw DB::Exception(ErrorCodes::NOT_IMPLEMENTED, "PasteJoin doesn't support specified query");
}
/// Used just to get result header
void joinBlock(Block & block, std::shared_ptr<ExtraBlock> & /* not_processed */) override
{
for (const auto & col : right_sample_block)
block.insert(col);
block = materializeBlock(block).cloneEmpty();
}
void setTotals(const Block & block) override { totals = block; }
const Block & getTotals() const override { return totals; }
size_t getTotalRowCount() const override
{
throw Exception(ErrorCodes::LOGICAL_ERROR, "PasteJoin::getTotalRowCount should not be called");
}
size_t getTotalByteCount() const override
{
throw Exception(ErrorCodes::LOGICAL_ERROR, "PasteJoin::getTotalByteCount should not be called");
}
bool alwaysReturnsEmptySet() const override { return false; }
IBlocksStreamPtr
getNonJoinedBlocks(const Block & /* left_sample_block */, const Block & /* result_sample_block */, UInt64 /* max_block_size */) const override
{
throw Exception(ErrorCodes::LOGICAL_ERROR, "PasteJoin::getNonJoinedBlocks should not be called");
}
/// Left and right streams have the same priority and are processed simultaneously
JoinPipelineType pipelineType() const override { return JoinPipelineType::YShaped; }
private:
std::shared_ptr<TableJoin> table_join;
Block right_sample_block;
Block totals;
};
}

View File

@ -211,6 +211,9 @@ void ASTTableJoin::formatImplBeforeTable(const FormatSettings & settings, Format
case JoinKind::Comma:
settings.ostr << ",";
break;
case JoinKind::Paste:
settings.ostr << "PASTE JOIN";
break;
}
settings.ostr << (settings.hilite ? hilite_none : "");

View File

@ -6,6 +6,7 @@
#include <Parsers/ParserSelectQuery.h>
#include <Parsers/ParserSampleRatio.h>
#include <Parsers/ParserTablesInSelectQuery.h>
#include <Core/Joins.h>
namespace DB
@ -166,6 +167,8 @@ bool ParserTablesInSelectQueryElement::parseImpl(Pos & pos, ASTPtr & node, Expec
table_join->kind = JoinKind::Full;
else if (ParserKeyword("CROSS").ignore(pos))
table_join->kind = JoinKind::Cross;
else if (ParserKeyword("PASTE").ignore(pos))
table_join->kind = JoinKind::Paste;
else
no_kind = true;
@ -191,8 +194,8 @@ bool ParserTablesInSelectQueryElement::parseImpl(Pos & pos, ASTPtr & node, Expec
}
if (table_join->strictness != JoinStrictness::Unspecified
&& table_join->kind == JoinKind::Cross)
throw Exception(ErrorCodes::SYNTAX_ERROR, "You must not specify ANY or ALL for CROSS JOIN.");
&& (table_join->kind == JoinKind::Cross || table_join->kind == JoinKind::Paste))
throw Exception(ErrorCodes::SYNTAX_ERROR, "You must not specify ANY or ALL for {} JOIN.", toString(table_join->kind));
if ((table_join->strictness == JoinStrictness::Semi || table_join->strictness == JoinStrictness::Anti) &&
(table_join->kind != JoinKind::Left && table_join->kind != JoinKind::Right))
@ -206,7 +209,7 @@ bool ParserTablesInSelectQueryElement::parseImpl(Pos & pos, ASTPtr & node, Expec
return false;
if (table_join->kind != JoinKind::Comma
&& table_join->kind != JoinKind::Cross)
&& table_join->kind != JoinKind::Cross && table_join->kind != JoinKind::Paste)
{
if (ParserKeyword("USING").ignore(pos, expected))
{

View File

@ -0,0 +1,276 @@
#include <cassert>
#include <cstddef>
#include <limits>
#include <memory>
#include <type_traits>
#include <base/defines.h>
#include <base/types.h>
#include <Common/logger_useful.h>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnsNumber.h>
#include <Columns/IColumn.h>
#include <IO/WriteHelpers.h>
#include <Interpreters/TableJoin.h>
#include <Parsers/ASTTablesInSelectQuery.h>
#include <Processors/Transforms/PasteJoinTransform.h>
namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
extern const int NOT_IMPLEMENTED;
extern const int LOGICAL_ERROR;
}
namespace
{
template <bool has_left_nulls, bool has_right_nulls>
int nullableCompareAt(const IColumn & left_column, const IColumn & right_column, size_t lhs_pos, size_t rhs_pos, int null_direction_hint = 1)
{
if constexpr (has_left_nulls && has_right_nulls)
{
const auto * left_nullable = checkAndGetColumn<ColumnNullable>(left_column);
const auto * right_nullable = checkAndGetColumn<ColumnNullable>(right_column);
if (left_nullable && right_nullable)
{
int res = left_nullable->compareAt(lhs_pos, rhs_pos, right_column, null_direction_hint);
if (res)
return res;
/// NULL != NULL case
if (left_nullable->isNullAt(lhs_pos))
return null_direction_hint;
return 0;
}
}
if constexpr (has_left_nulls)
{
if (const auto * left_nullable = checkAndGetColumn<ColumnNullable>(left_column))
{
if (left_nullable->isNullAt(lhs_pos))
return null_direction_hint;
return left_nullable->getNestedColumn().compareAt(lhs_pos, rhs_pos, right_column, null_direction_hint);
}
}
if constexpr (has_right_nulls)
{
if (const auto * right_nullable = checkAndGetColumn<ColumnNullable>(right_column))
{
if (right_nullable->isNullAt(rhs_pos))
return -null_direction_hint;
return left_column.compareAt(lhs_pos, rhs_pos, right_nullable->getNestedColumn(), null_direction_hint);
}
}
return left_column.compareAt(lhs_pos, rhs_pos, right_column, null_direction_hint);
}
ColumnPtr replicateRow(const IColumn & column, size_t num)
{
MutableColumnPtr res = column.cloneEmpty();
res->insertManyFrom(column, 0, num);
return res;
}
template <typename TColumns>
void copyColumnsResized(const TColumns & cols, size_t start, size_t size, Chunk & result_chunk)
{
for (const auto & col : cols)
{
if (col->empty())
{
/// add defaults
result_chunk.addColumn(col->cloneResized(size));
}
else if (col->size() == 1)
{
/// copy same row n times
result_chunk.addColumn(replicateRow(*col, size));
}
else
{
/// cut column
assert(start + size <= col->size());
result_chunk.addColumn(col->cut(start, size));
}
}
}
}
PasteJoinAlgorithm::PasteJoinAlgorithm(
JoinPtr table_join_,
const Blocks & input_headers,
size_t max_block_size_)
: table_join(table_join_)
, max_block_size(max_block_size_)
, log(&Poco::Logger::get("PasteJoinAlgorithm"))
{
if (input_headers.size() != 2)
throw Exception(ErrorCodes::LOGICAL_ERROR, "PasteJoinAlgorithm requires exactly two inputs");
auto strictness = table_join->getTableJoin().strictness();
if (strictness != JoinStrictness::Any && strictness != JoinStrictness::All)
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "PasteJoinAlgorithm is not implemented for strictness {}", strictness);
auto kind = table_join->getTableJoin().kind();
if (!isPaste(kind))
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "PasteJoinAlgorithm is not implemented for kind {}", kind);
for (const auto & [left_key, right_key] : table_join->getTableJoin().leftToRightKeyRemap())
{
size_t left_idx = input_headers[0].getPositionByName(left_key);
size_t right_idx = input_headers[1].getPositionByName(right_key);
left_to_right_key_remap[left_idx] = right_idx;
}
}
void PasteJoinAlgorithm::logElapsed(double seconds)
{
LOG_TRACE(log,
"Finished pocessing in {} seconds"
", left: {} blocks, {} rows; right: {} blocks, {} rows"
", max blocks loaded to memory: {}",
seconds, stat.num_blocks[0], stat.num_rows[0], stat.num_blocks[1], stat.num_rows[1],
stat.max_blocks_loaded);
}
static void prepareChunk(Chunk & chunk)
{
if (!chunk)
return;
auto num_rows = chunk.getNumRows();
auto columns = chunk.detachColumns();
for (auto & column : columns)
column = column->convertToFullColumnIfConst();
chunk.setColumns(std::move(columns), num_rows);
}
void PasteJoinAlgorithm::initialize(Inputs inputs)
{
if (inputs.size() != 2)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Two inputs are required, got {}", inputs.size());
for (size_t i = 0; i < inputs.size(); ++i)
{
consume(inputs[i], i);
}
}
void PasteJoinAlgorithm::consume(Input & input, size_t source_num)
{
if (input.skip_last_row)
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "skip_last_row is not supported");
if (input.permutation)
throw DB::Exception(ErrorCodes::NOT_IMPLEMENTED, "permutation is not supported");
if (input.chunk)
{
stat.num_blocks[source_num] += 1;
stat.num_rows[source_num] += input.chunk.getNumRows();
}
prepareChunk(input.chunk);
chunks[source_num] = std::move(input.chunk);
}
/// if `source_num == 0` get data from left cursor and fill defaults at right
/// otherwise - vice versa
Chunk PasteJoinAlgorithm::createBlockWithDefaults(size_t source_num, size_t start, size_t num_rows) const
{
ColumnRawPtrs cols;
{
const auto & columns_left = chunks[0].getColumns();
const auto & columns_right = chunks[1].getColumns();
for (size_t i = 0; i < columns_left.size(); ++i)
{
if (auto it = left_to_right_key_remap.find(i); source_num == 0 || it == left_to_right_key_remap.end())
{
cols.push_back(columns_left[i].get());
}
else
{
cols.push_back(columns_right[it->second].get());
}
}
for (const auto & col : columns_right)
{
cols.push_back(col.get());
}
}
Chunk result_chunk;
copyColumnsResized(cols, start, num_rows, result_chunk);
return result_chunk;
}
enum ChunkToCut
{
First,
Second,
None,
};
IMergingAlgorithm::Status PasteJoinAlgorithm::merge()
{
PaddedPODArray<UInt64> indices[2];
Chunk result;
for (size_t source_num = 0; source_num < 2; ++source_num)
{
ChunkToCut to_cut = None;
if (chunks[0].getNumRows() != chunks[1].getNumRows())
to_cut = chunks[0].getNumRows() > chunks[1].getNumRows() ? ChunkToCut::First : ChunkToCut::Second;
for (const auto & col : chunks[source_num].getColumns())
{
if (to_cut == ChunkToCut::First)
result.addColumn(col->cut(0, chunks[1].getNumRows()));
else if (to_cut == ChunkToCut::Second)
result.addColumn(col->cut(0, chunks[0].getNumRows()));
else
result.addColumn(col);
}
}
return Status(std::move(result), true);
}
PasteJoinTransform::PasteJoinTransform(
JoinPtr table_join,
const Blocks & input_headers,
const Block & output_header,
size_t max_block_size,
UInt64 limit_hint_)
: IMergingTransform<PasteJoinAlgorithm>(
input_headers,
output_header,
/* have_all_inputs_= */ true,
limit_hint_,
/* always_read_till_end_= */ false,
/* empty_chunk_on_finish_= */ true,
table_join, input_headers, max_block_size)
, log(&Poco::Logger::get("PasteJoinTransform"))
{
LOG_TRACE(log, "Use PasteJoinTransform");
}
void PasteJoinTransform::onFinish()
{
algorithm.logElapsed(total_stopwatch.elapsedSeconds());
}
}

View File

@ -0,0 +1,87 @@
#pragma once
#include <cassert>
#include <cstddef>
#include <memory>
#include <mutex>
#include <utility>
#include <boost/core/noncopyable.hpp>
#include <Common/PODArray.h>
#include <IO/ReadBuffer.h>
#include <Parsers/ASTTablesInSelectQuery.h>
#include <Processors/Chunk.h>
#include <Processors/Merges/Algorithms/IMergingAlgorithm.h>
#include <Processors/Merges/IMergingTransform.h>
namespace Poco { class Logger; }
namespace DB
{
class IJoin;
using JoinPtr = std::shared_ptr<IJoin>;
/*
* This class is used to join chunks from two sorted streams.
* It is used in MergeJoinTransform.
*/
class PasteJoinAlgorithm final : public IMergingAlgorithm
{
public:
explicit PasteJoinAlgorithm(JoinPtr table_join, const Blocks & input_headers, size_t max_block_size_);
const char * getName() const override { return "PasteJoinAlgorithm"; }
virtual void initialize(Inputs inputs) override;
virtual void consume(Input & input, size_t source_num) override;
virtual Status merge() override;
void logElapsed(double seconds);
private:
Chunk createBlockWithDefaults(size_t source_num);
Chunk createBlockWithDefaults(size_t source_num, size_t start, size_t num_rows) const;
/// For `USING` join key columns should have values from right side instead of defaults
std::unordered_map<size_t, size_t> left_to_right_key_remap;
std::array<Chunk, 2> chunks;
JoinPtr table_join;
size_t max_block_size;
struct Statistic
{
size_t num_blocks[2] = {0, 0};
size_t num_rows[2] = {0, 0};
size_t max_blocks_loaded = 0;
};
Statistic stat;
Poco::Logger * log;
};
class PasteJoinTransform final : public IMergingTransform<PasteJoinAlgorithm>
{
using Base = IMergingTransform<PasteJoinAlgorithm>;
public:
PasteJoinTransform(
JoinPtr table_join,
const Blocks & input_headers,
const Block & output_header,
size_t max_block_size,
UInt64 limit_hint = 0);
String getName() const override { return "PasteJoinTransform"; }
protected:
void onFinish() override;
Poco::Logger * log;
};
}

View File

@ -25,6 +25,7 @@
#include <Processors/Transforms/ExtremesTransform.h>
#include <Processors/Transforms/JoiningTransform.h>
#include <Processors/Transforms/MergeJoinTransform.h>
#include <Processors/Transforms/PasteJoinTransform.h>
#include <Processors/Transforms/MergingAggregatedMemoryEfficientTransform.h>
#include <Processors/Transforms/PartialSortingTransform.h>
#include <Processors/Transforms/TotalsHavingTransform.h>
@ -354,7 +355,9 @@ std::unique_ptr<QueryPipelineBuilder> QueryPipelineBuilder::joinPipelinesYShaped
left->pipe.dropExtremes();
right->pipe.dropExtremes();
if (left->getNumStreams() != 1 || right->getNumStreams() != 1)
if ((left->getNumStreams() != 1 || right->getNumStreams() != 1) && join->getTableJoin().kind() == JoinKind::Paste)
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Paste JOIN requires sorted tables only");
else if (left->getNumStreams() != 1 || right->getNumStreams() != 1)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Join is supported only for pipelines with one output port");
if (left->hasTotals() || right->hasTotals())
@ -362,9 +365,16 @@ std::unique_ptr<QueryPipelineBuilder> QueryPipelineBuilder::joinPipelinesYShaped
Blocks inputs = {left->getHeader(), right->getHeader()};
auto joining = std::make_shared<MergeJoinTransform>(join, inputs, out_header, max_block_size);
return mergePipelines(std::move(left), std::move(right), std::move(joining), collected_processors);
if (join->getTableJoin().kind() == JoinKind::Paste)
{
auto joining = std::make_shared<PasteJoinTransform>(join, inputs, out_header, max_block_size);
return mergePipelines(std::move(left), std::move(right), std::move(joining), collected_processors);
}
else
{
auto joining = std::make_shared<MergeJoinTransform>(join, inputs, out_header, max_block_size);
return mergePipelines(std::move(left), std::move(right), std::move(joining), collected_processors);
}
}
std::unique_ptr<QueryPipelineBuilder> QueryPipelineBuilder::joinPipelinesRightLeft(

View File

@ -0,0 +1,28 @@
0 0
1 1
2 2
3 3
4 4
5 5
6 6
7 7
8 8
9 9
0 9
1 8
2 7
3 6
4 5
5 4
6 3
7 2
8 1
9 0
0 0
1 1
2 2
3 3
4 4
5 5
0 0
1 1

View File

@ -0,0 +1,10 @@
select * from (SELECT number as a FROM numbers(10)) t1 PASTE JOIN (select number as a from numbers(10)) t2;
select * from (SELECT number as a FROM numbers(10)) t1 PASTE JOIN (select number as a from numbers(10) order by a desc) t2;
create table if not exists test (num UInt64) engine=Memory;
insert into test select number from numbers(6);
insert into test select number from numbers(5);
select * from (SELECT number as a FROM numbers(11)) t1 PASTE JOIN test t2 SETTINGS max_threads=1;
select * from (SELECT number as a FROM numbers(11)) t1 PASTE JOIN (select * from test limit 2) t2 SETTINGs max_threads=1;
select * from (SELECT number as a FROM numbers(10)) t1 ANY PASTE JOIN (select number as a from numbers(10)) t2; -- { clientError SYNTAX_ERROR }
select * from (SELECT number as a FROM numbers(10)) t1 ALL PASTE JOIN (select number as a from numbers(10)) t2; -- { clientError SYNTAX_ERROR }
select * from (SELECT number as a FROM numbers_mt(10)) t1 PASTE JOIN (select number as a from numbers(10) ORDER BY a DESC) t2 SETTINGS max_block_size=3; -- { serverError BAD_ARGUMENTS }