mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-27 01:51:59 +00:00
wip full sorting asof join
This commit is contained in:
parent
48d47d26a4
commit
2412f85219
@ -34,13 +34,15 @@ namespace ErrorCodes
|
||||
namespace
|
||||
{
|
||||
|
||||
FullMergeJoinCursorPtr createCursor(const Block & block, const Names & columns)
|
||||
constexpr UInt64 DEFAULT_VALUE_INDEX = std::numeric_limits<UInt64>::max();
|
||||
|
||||
FullMergeJoinCursorPtr createCursor(const Block & block, const Names & columns, JoinStrictness strictness)
|
||||
{
|
||||
SortDescription desc;
|
||||
desc.reserve(columns.size());
|
||||
for (const auto & name : columns)
|
||||
desc.emplace_back(name);
|
||||
return std::make_unique<FullMergeJoinCursor>(materializeBlock(block), desc);
|
||||
return std::make_unique<FullMergeJoinCursor>(materializeBlock(block), desc, strictness == JoinStrictness::Asof);
|
||||
}
|
||||
|
||||
template <bool has_left_nulls, bool has_right_nulls>
|
||||
@ -90,9 +92,10 @@ int nullableCompareAt(const IColumn & left_column, const IColumn & right_column,
|
||||
|
||||
int ALWAYS_INLINE compareCursors(const SortCursorImpl & lhs, size_t lpos,
|
||||
const SortCursorImpl & rhs, size_t rpos,
|
||||
size_t key_length,
|
||||
int null_direction_hint)
|
||||
{
|
||||
for (size_t i = 0; i < lhs.sort_columns_size; ++i)
|
||||
for (size_t i = 0; i < key_length; ++i)
|
||||
{
|
||||
/// TODO(@vdimir): use nullableCompareAt only if there's nullable columns
|
||||
int cmp = nullableCompareAt<true, true>(*lhs.sort_columns[i], *rhs.sort_columns[i], lpos, rpos, null_direction_hint);
|
||||
@ -104,13 +107,18 @@ int ALWAYS_INLINE compareCursors(const SortCursorImpl & lhs, size_t lpos,
|
||||
|
||||
/// Compares the current rows of two cursors.
/// The number of compared key columns is taken from the left cursor (`lhs.sort_columns_size`)
/// and forwarded to the position-based overload as `key_length`.
int ALWAYS_INLINE compareCursors(const SortCursorImpl & lhs, const SortCursorImpl & rhs, int null_direction_hint)
{
    return compareCursors(lhs, lhs.getRow(), rhs, rhs.getRow(), lhs.sort_columns_size, null_direction_hint);
}
|
||||
|
||||
int compareAsofCursors(const FullMergeJoinCursor & lhs, const FullMergeJoinCursor & rhs)
|
||||
{
|
||||
return nullableCompareAt<false, false>(lhs.getAsofColumn(), rhs.getAsofColumn(), lhs->getRow(), rhs->getRow());
|
||||
}
|
||||
|
||||
/// Returns true when the last row of the left cursor compares less than
/// the current row of the right cursor over all key columns of the left cursor.
bool ALWAYS_INLINE totallyLess(SortCursorImpl & lhs, SortCursorImpl & rhs, int null_direction_hint)
{
    /// The last row of left cursor is less than the current row of the right cursor.
    int cmp = compareCursors(lhs, lhs.rows - 1, rhs, rhs.getRow(), lhs.sort_columns_size, null_direction_hint);
    return cmp < 0;
}
|
||||
|
||||
@ -222,11 +230,11 @@ Chunk getRowFromChunk(const Chunk & chunk, size_t pos)
|
||||
return result;
|
||||
}
|
||||
|
||||
void inline addRange(PaddedPODArray<UInt64> & left_map, size_t start, size_t end)
|
||||
void inline addRange(PaddedPODArray<UInt64> & values, UInt64 start, UInt64 end)
|
||||
{
|
||||
assert(end > start);
|
||||
for (size_t i = start; i < end; ++i)
|
||||
left_map.push_back(i);
|
||||
for (UInt64 i = start; i < end; ++i)
|
||||
values.push_back(i);
|
||||
}
|
||||
|
||||
void inline addMany(PaddedPODArray<UInt64> & left_or_right_map, size_t idx, size_t num)
|
||||
@ -235,6 +243,11 @@ void inline addMany(PaddedPODArray<UInt64> & left_or_right_map, size_t idx, size
|
||||
left_or_right_map.push_back(idx);
|
||||
}
|
||||
|
||||
/// Appends `num` copies of `value` to the end of `values`.
void inline addMany(PaddedPODArray<UInt64> & values, UInt64 value, size_t num)
{
    const size_t grown_size = values.size() + num;
    values.resize_fill(grown_size, value);
}
|
||||
|
||||
}
|
||||
|
||||
const Chunk & FullMergeJoinCursor::getCurrent() const
|
||||
@ -283,9 +296,15 @@ MergeJoinAlgorithm::MergeJoinAlgorithm(
|
||||
if (input_headers.size() != 2)
|
||||
throw Exception(ErrorCodes::LOGICAL_ERROR, "MergeJoinAlgorithm requires exactly two inputs");
|
||||
|
||||
if (strictness != JoinStrictness::Any && strictness != JoinStrictness::All)
|
||||
if (strictness != JoinStrictness::Any && strictness != JoinStrictness::All && strictness != JoinStrictness::Asof)
|
||||
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "MergeJoinAlgorithm is not implemented for strictness {}", strictness);
|
||||
|
||||
if (strictness == JoinStrictness::Asof)
|
||||
{
|
||||
if (kind != JoinKind::Left && kind != JoinKind::Inner)
|
||||
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "MergeJoinAlgorithm does not implement ASOF {} join", kind);
|
||||
}
|
||||
|
||||
if (!isInner(kind) && !isLeft(kind) && !isRight(kind) && !isFull(kind))
|
||||
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "MergeJoinAlgorithm is not implemented for kind {}", kind);
|
||||
|
||||
@ -293,8 +312,8 @@ MergeJoinAlgorithm::MergeJoinAlgorithm(
|
||||
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "MergeJoinAlgorithm does not support ON filter conditions");
|
||||
|
||||
cursors = {
|
||||
createCursor(input_headers[0], on_clause_.key_names_left),
|
||||
createCursor(input_headers[1], on_clause_.key_names_right)
|
||||
createCursor(input_headers[0], on_clause_.key_names_left, strictness),
|
||||
createCursor(input_headers[1], on_clause_.key_names_right, strictness),
|
||||
};
|
||||
|
||||
MergeJoinAlgorithm::MergeJoinAlgorithm(
|
||||
@ -313,6 +332,8 @@ MergeJoinAlgorithm::MergeJoinAlgorithm(
|
||||
size_t left_idx = input_headers[0].getPositionByName(left_key);
|
||||
size_t right_idx = input_headers[1].getPositionByName(right_key);
|
||||
left_to_right_key_remap[left_idx] = right_idx;
|
||||
if (strictness == JoinStrictness::Asof)
|
||||
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "MergeJoinAlgorithm does not support ASOF joins USING");
|
||||
}
|
||||
|
||||
const auto *smjPtr = typeid_cast<const FullSortingMergeJoin *>(table_join.get());
|
||||
@ -321,6 +342,19 @@ MergeJoinAlgorithm::MergeJoinAlgorithm(
|
||||
null_direction_hint = smjPtr->getNullDirection();
|
||||
}
|
||||
|
||||
if (strictness == JoinStrictness::Asof)
|
||||
setAsofInequality(join_ptr->getTableJoin().getAsofInequality());
|
||||
}
|
||||
|
||||
/// Stores the inequality used by the ASOF join.
/// Throws LOGICAL_ERROR when the algorithm was not created with ASOF strictness
/// or when the passed inequality is ASOFJoinInequality::None.
void MergeJoinAlgorithm::setAsofInequality(ASOFJoinInequality asof_inequality_)
{
    const bool is_asof_join = (strictness == JoinStrictness::Asof);
    if (!is_asof_join)
        throw Exception(ErrorCodes::LOGICAL_ERROR, "setAsofInequality is only supported for ASOF joins");

    const bool inequality_is_missing = (asof_inequality_ == ASOFJoinInequality::None);
    if (inequality_is_missing)
        throw Exception(ErrorCodes::LOGICAL_ERROR, "ASOF inequality cannot be None");

    asof_inequality = asof_inequality_;
}
|
||||
|
||||
void MergeJoinAlgorithm::logElapsed(double seconds)
|
||||
@ -770,6 +804,81 @@ MergeJoinAlgorithm::Status MergeJoinAlgorithm::anyJoin()
|
||||
return Status(std::move(result));
|
||||
}
|
||||
|
||||
|
||||
/// One step of the ASOF join over the chunks currently held by both cursors.
///
/// Both cursors are advanced in lockstep:
///  - equal keys: the ASOF columns are compared. If the left value satisfies the inequality
///    against the current right row, the pair is emitted and the left cursor advances;
///    otherwise the right row is skipped.
///  - left keys are less: the left rows have no match. For a LEFT join they are emitted with
///    DEFAULT_VALUE_INDEX on the right side, otherwise they are dropped.
///  - right keys are less: the right rows are skipped.
///
/// Only the Less and LessOrEquals inequalities are implemented; any other one throws NOT_IMPLEMENTED.
/// Returns Status(0) / Status(1) when the left / right cursor has no valid row
/// (NOTE(review): presumably a request for more data from that input - confirm),
/// otherwise a Status holding the joined chunk.
MergeJoinAlgorithm::Status MergeJoinAlgorithm::asofJoin()
{
    auto & left_cursor = *cursors[0];
    if (!left_cursor->isValid())
        return Status(0);

    auto & right_cursor = *cursors[1];
    if (!right_cursor->isValid())
        return Status(1);

    PaddedPODArray<UInt64> left_map;
    PaddedPODArray<UInt64> right_map;

    while (left_cursor->isValid() && right_cursor->isValid())
    {
        auto lpos = left_cursor->getRow();
        auto rpos = right_cursor->getRow();
        /// The cursor comparison needs the null direction, same as every other caller of compareCursors.
        auto cmp = compareCursors(*left_cursor, *right_cursor, null_direction_hint);
        if (cmp == 0)
        {
            if (asof_inequality != ASOFJoinInequality::Less && asof_inequality != ASOFJoinInequality::LessOrEquals)
                throw Exception(ErrorCodes::NOT_IMPLEMENTED,
                    "MergeJoinAlgorithm implements ASOF join only for Less and LessOrEquals inequalities");

            auto asof_cmp = compareAsofCursors(left_cursor, right_cursor);

            const bool is_match = (asof_inequality == ASOFJoinInequality::Less && asof_cmp < 0)
                || (asof_inequality == ASOFJoinInequality::LessOrEquals && asof_cmp <= 0);

            if (is_match)
            {
                /// First row in right table that is greater (or equal) than current row in left table
                /// matches asof join condition the best
                left_map.push_back(lpos);
                right_map.push_back(rpos);
                left_cursor->next();
            }
            else
            {
                /// Asof condition is not (yet) satisfied, skip row in right table
                right_cursor->next();
            }
        }
        else if (cmp < 0)
        {
            /// no matches for rows in left table, just pass them through
            size_t num = nextDistinct(*left_cursor);
            if (isLeft(kind))
            {
                /// return them with default values at right side
                addRange(left_map, lpos, lpos + num);
                addMany(right_map, DEFAULT_VALUE_INDEX, num);
            }
        }
        else
        {
            /// skip rows in right table until we find match for current row in left table
            nextDistinct(*right_cursor);
        }
    }

    chassert(left_map.size() == right_map.size());
    Chunk result;
    {
        Columns lcols = indexColumns(left_cursor.getCurrent().getColumns(), left_map);
        for (auto & col : lcols)
            result.addColumn(std::move(col));

        Columns rcols = indexColumns(right_cursor.getCurrent().getColumns(), right_map);
        for (auto & col : rcols)
            result.addColumn(std::move(col));
    }
    return Status(std::move(result));
}
|
||||
|
||||
|
||||
/// if `source_num == 0` get data from left cursor and fill defaults at right
|
||||
/// otherwise - vice versa
|
||||
Chunk MergeJoinAlgorithm::createBlockWithDefaults(size_t source_num, size_t start, size_t num_rows) const
|
||||
@ -861,6 +970,9 @@ IMergingAlgorithm::Status MergeJoinAlgorithm::merge()
|
||||
if (strictness == JoinStrictness::All)
|
||||
return allJoin();
|
||||
|
||||
if (strictness == JoinStrictness::Asof)
|
||||
return asofJoin();
|
||||
|
||||
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Unsupported strictness '{}'", strictness);
|
||||
}
|
||||
|
||||
@ -878,10 +990,6 @@ MergeJoinTransform::MergeJoinTransform(
|
||||
/* always_read_till_end_= */ false,
|
||||
/* empty_chunk_on_finish_= */ true,
|
||||
table_join, input_headers, max_block_size)
|
||||
<<<<<<< HEAD
|
||||
, log(getLogger("MergeJoinTransform"))
|
||||
=======
|
||||
>>>>>>> b4a16f38320 (Add simple unit test for full sorting join)
|
||||
{
|
||||
}
|
||||
|
||||
|
@ -195,10 +195,27 @@ private:
|
||||
class FullMergeJoinCursor : boost::noncopyable
|
||||
{
|
||||
public:
|
||||
explicit FullMergeJoinCursor(const Block & sample_block_, const SortDescription & description_)
|
||||
FullMergeJoinCursor(
|
||||
const Block & sample_block_,
|
||||
const SortDescription & description_,
|
||||
bool is_asof = false)
|
||||
: sample_block(sample_block_.cloneEmpty())
|
||||
, desc(description_)
|
||||
{
|
||||
if (desc.size() == 0)
|
||||
throw Exception(ErrorCodes::LOGICAL_ERROR, "Got empty sort description for FullMergeJoinCursor");
|
||||
|
||||
if (is_asof)
|
||||
{
|
||||
/// For ASOF join prefix of sort description is used for equality comparison
|
||||
/// and the last column is used for inequality comparison and is handled separately
|
||||
|
||||
auto asof_column_description = desc.back();
|
||||
desc.pop_back();
|
||||
|
||||
chassert(asof_column_description.direction == 1 && asof_column_description.nulls_direction == 1);
|
||||
asof_column_position = sample_block.getPositionByName(asof_column_description.column_name);
|
||||
}
|
||||
}
|
||||
|
||||
bool fullyCompleted() const;
|
||||
@ -209,17 +226,27 @@ public:
|
||||
SortCursorImpl * operator-> () { return &cursor; }
|
||||
const SortCursorImpl * operator-> () const { return &cursor; }
|
||||
|
||||
SortCursorImpl & operator* () { return cursor; }
|
||||
const SortCursorImpl & operator* () const { return cursor; }
|
||||
|
||||
SortCursorImpl cursor;
|
||||
|
||||
const Block & sampleBlock() const { return sample_block; }
|
||||
Columns sampleColumns() const { return sample_block.getColumns(); }
|
||||
|
||||
/// Returns the ASOF column of the chunk the cursor currently points to,
/// looked up by the position remembered in the constructor.
const IColumn & getAsofColumn() const
{
    const auto & asof_column_ptr = cursor.all_columns[asof_column_position];
    return *asof_column_ptr;
}
|
||||
|
||||
private:
|
||||
Block sample_block;
|
||||
SortDescription desc;
|
||||
|
||||
Chunk current_chunk;
|
||||
bool recieved_all_blocks = false;
|
||||
|
||||
size_t asof_column_position;
|
||||
};
|
||||
|
||||
/*
|
||||
@ -242,8 +269,9 @@ public:
|
||||
void consume(Input & input, size_t source_num) override;
|
||||
Status merge() override;
|
||||
|
||||
void logElapsed(double seconds);
|
||||
void setAsofInequality(ASOFJoinInequality asof_inequality_);
|
||||
|
||||
void logElapsed(double seconds);
|
||||
private:
|
||||
std::optional<Status> handleAnyJoinState();
|
||||
Status anyJoin();
|
||||
@ -251,13 +279,17 @@ private:
|
||||
std::optional<Status> handleAllJoinState();
|
||||
Status allJoin();
|
||||
|
||||
Status asofJoin();
|
||||
|
||||
Chunk createBlockWithDefaults(size_t source_num);
|
||||
Chunk createBlockWithDefaults(size_t source_num, size_t start, size_t num_rows) const;
|
||||
|
||||
|
||||
/// For `USING` join key columns should have values from right side instead of defaults
|
||||
std::unordered_map<size_t, size_t> left_to_right_key_remap;
|
||||
|
||||
std::array<FullMergeJoinCursorPtr, 2> cursors;
|
||||
ASOFJoinInequality asof_inequality = ASOFJoinInequality::None;
|
||||
|
||||
/// Keep some state to handle data that spans different blocks
|
||||
AnyJoinState any_join_state;
|
||||
@ -305,6 +337,8 @@ public:
|
||||
|
||||
String getName() const override { return "MergeJoinTransform"; }
|
||||
|
||||
void setAsofInequality(ASOFJoinInequality asof_inequality_) { algorithm.setAsofInequality(asof_inequality_); }
|
||||
|
||||
protected:
|
||||
void onFinish() override;
|
||||
};
|
||||
|
287
src/Processors/tests/gtest_full_sorting_join.cpp
Normal file
287
src/Processors/tests/gtest_full_sorting_join.cpp
Normal file
@ -0,0 +1,287 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
|
||||
#include <Common/randomSeed.h>
|
||||
#include <pcg_random.hpp>
|
||||
#include <random>
|
||||
|
||||
#include <Processors/Sources/SourceFromSingleChunk.h>
|
||||
#include <Processors/Sources/SourceFromChunks.h>
|
||||
#include <Processors/Sinks/NullSink.h>
|
||||
#include <Processors/Executors/PipelineExecutor.h>
|
||||
#include <Processors/Executors/PullingPipelineExecutor.h>
|
||||
#include <QueryPipeline/QueryPipeline.h>
|
||||
|
||||
#include <Columns/ColumnsNumber.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <Interpreters/TableJoin.h>
|
||||
|
||||
#include <Processors/Transforms/MergeJoinTransform.h>
|
||||
|
||||
|
||||
using namespace DB;
|
||||
|
||||
/// Obtains a seed from randomSeed(), prints it to stderr so that a failing run
/// can be reproduced, and returns it.
UInt64 getAndPrintRandomSeed()
{
    const UInt64 generated_seed = randomSeed();
    std::cerr << "TEST_RANDOM_SEED: " << generated_seed << std::endl;
    return generated_seed;
}
|
||||
|
||||
static UInt64 TEST_RANDOM_SEED = getAndPrintRandomSeed();
|
||||
|
||||
static pcg64 rng(TEST_RANDOM_SEED);
|
||||
|
||||
|
||||
/// Builds a pipeline that joins two sources with MergeJoinTransform.
///
/// The first `key_length` columns of each source header (by position) are used as join keys.
/// The output header consists of all left columns prefixed with "t1." followed by
/// all right columns prefixed with "t2.".
/// `asof_inequality` is passed to the transform only when it is not ASOFJoinInequality::None.
QueryPipeline buildJoinPipeline(
    std::shared_ptr<ISource> left_source,
    std::shared_ptr<ISource> right_source,
    size_t key_length = 1,
    JoinKind kind = JoinKind::Inner,
    JoinStrictness strictness = JoinStrictness::All,
    ASOFJoinInequality asof_inequality = ASOFJoinInequality::None)
{
    Blocks inputs;
    inputs.emplace_back(left_source->getPort().getHeader());
    inputs.emplace_back(right_source->getPort().getHeader());

    Block out_header;
    for (size_t input_idx = 0; input_idx < inputs.size(); ++input_idx)
    {
        /// The first input is the left table, everything else is the right one
        const char * table_prefix = (input_idx == 0) ? "t1." : "t2.";
        for (ColumnWithTypeAndName column : inputs[input_idx])
        {
            column.name = table_prefix + column.name;
            out_header.insert(column);
        }
    }

    TableJoin::JoinOnClause on_clause;
    for (size_t key_idx = 0; key_idx < key_length; ++key_idx)
    {
        const auto & left_key = inputs[0].getByPosition(key_idx);
        const auto & right_key = inputs[1].getByPosition(key_idx);
        on_clause.key_names_left.emplace_back(left_key.name);
        on_clause.key_names_right.emplace_back(right_key.name);
    }

    auto join_transform = std::make_shared<MergeJoinTransform>(
        kind,
        strictness,
        on_clause,
        inputs, out_header, /* max_block_size = */ 0);

    const bool has_asof_inequality = (asof_inequality != ASOFJoinInequality::None);
    if (has_asof_inequality)
        join_transform->setAsofInequality(asof_inequality);

    chassert(join_transform->getInputs().size() == 2);

    connect(left_source->getPort(), join_transform->getInputs().front());
    connect(right_source->getPort(), join_transform->getInputs().back());

    auto * output_port = &join_transform->getOutputPort();

    auto all_processors = std::make_shared<Processors>();
    all_processors->emplace_back(std::move(left_source));
    all_processors->emplace_back(std::move(right_source));
    all_processors->emplace_back(std::move(join_transform));

    QueryPipeline pipeline(QueryPlanResourceHolder{}, all_processors, output_port);
    return pipeline;
}
|
||||
|
||||
|
||||
std::shared_ptr<ISource> oneColumnSource(const std::vector<std::vector<UInt64>> & values)
|
||||
{
|
||||
Block header = { ColumnWithTypeAndName(ColumnUInt64::create(), std::make_shared<DataTypeUInt64>(), "x") };
|
||||
Chunks chunks;
|
||||
for (const auto & chunk_values : values)
|
||||
{
|
||||
auto column = ColumnUInt64::create();
|
||||
for (auto n : chunk_values)
|
||||
column->insertValue(n);
|
||||
chunks.emplace_back(Chunk(Columns{std::move(column)}, chunk_values.size()));
|
||||
}
|
||||
return std::make_shared<SourceFromChunks>(header, std::move(chunks));
|
||||
}
|
||||
|
||||
|
||||
TEST(FullSortingJoin, Simple)
|
||||
try
|
||||
{
|
||||
auto left_source = oneColumnSource({ {1, 2, 3, 4, 5} });
|
||||
auto right_source = oneColumnSource({ {1}, {2}, {3}, {4}, {5} });
|
||||
|
||||
auto pipeline = buildJoinPipeline(left_source, right_source);
|
||||
PullingPipelineExecutor executor(pipeline);
|
||||
|
||||
Block block;
|
||||
|
||||
size_t total_result_rows = 0;
|
||||
while (executor.pull(block))
|
||||
total_result_rows += block.rows();
|
||||
|
||||
ASSERT_EQ(total_result_rows, 5);
|
||||
}
|
||||
catch (Exception & e)
|
||||
{
|
||||
std::cout << e.getStackTraceString() << std::endl;
|
||||
throw;
|
||||
}
|
||||
|
||||
std::shared_ptr<ISource> sourceFromRows(
|
||||
const Block & header, const std::vector<std::vector<Field>> & values, double break_prob = 0.0)
|
||||
{
|
||||
Chunks chunks;
|
||||
auto columns = header.cloneEmptyColumns();
|
||||
|
||||
std::uniform_real_distribution<> prob_dis(0.0, 1.0);
|
||||
|
||||
|
||||
for (auto row : values)
|
||||
{
|
||||
if (!columns.empty() && (row.empty() || prob_dis(rng) < break_prob))
|
||||
{
|
||||
size_t rows = columns.front()->size();
|
||||
chunks.emplace_back(std::move(columns), rows);
|
||||
columns = header.cloneEmptyColumns();
|
||||
continue;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < columns.size(); ++i)
|
||||
columns[i]->insert(row[i]);
|
||||
}
|
||||
|
||||
if (!columns.empty())
|
||||
chunks.emplace_back(std::move(columns), columns.front()->size());
|
||||
|
||||
return std::make_shared<SourceFromChunks>(header, std::move(chunks));
|
||||
}
|
||||
|
||||
|
||||
std::vector<std::vector<Field>> getValuesFromBlock(const Block & block, const Names & names)
|
||||
{
|
||||
std::vector<std::vector<Field>> result;
|
||||
for (size_t i = 0; i < block.rows(); ++i)
|
||||
{
|
||||
auto & row = result.emplace_back();
|
||||
for (const auto & name : names)
|
||||
block.getByName(name).column->get(i, row.emplace_back());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
/// Pulls every block out of `pipeline` and returns them concatenated into a single block.
Block executePipeline(QueryPipeline & pipeline)
{
    PullingPipelineExecutor executor(pipeline);

    Blocks pulled_blocks;
    for (;;)
    {
        Block next_block;
        if (!executor.pull(next_block))
            break;
        pulled_blocks.emplace_back(std::move(next_block));
    }

    return concatenateBlocks(pulled_blocks);
}
|
||||
|
||||
TEST(FullSortingJoin, Asof)
|
||||
try
|
||||
{
|
||||
const std::vector<Field> chunk_break = {};
|
||||
|
||||
auto left_source = sourceFromRows({
|
||||
ColumnWithTypeAndName(ColumnString::create(), std::make_shared<DataTypeString>(), "key"),
|
||||
ColumnWithTypeAndName(ColumnUInt64::create(), std::make_shared<DataTypeUInt64>(), "t"),
|
||||
}, {
|
||||
{"AMZN", 3},
|
||||
{"AMZN", 4},
|
||||
{"AMZN", 6},
|
||||
});
|
||||
|
||||
auto right_source = sourceFromRows({
|
||||
ColumnWithTypeAndName(ColumnString::create(), std::make_shared<DataTypeString>(), "key"),
|
||||
ColumnWithTypeAndName(ColumnUInt64::create(), std::make_shared<DataTypeUInt64>(), "t"),
|
||||
ColumnWithTypeAndName(ColumnUInt64::create(), std::make_shared<DataTypeUInt64>(), "value"),
|
||||
}, {
|
||||
{"AAPL", 1, 97},
|
||||
chunk_break,
|
||||
{"AAPL", 2, 98},
|
||||
{"AAPL", 3, 99},
|
||||
{"AMZN", 1, 100},
|
||||
{"AMZN", 2, 110},
|
||||
chunk_break,
|
||||
{"AMZN", 4, 130},
|
||||
{"AMZN", 5, 140},
|
||||
});
|
||||
|
||||
auto pipeline = buildJoinPipeline(
|
||||
left_source, right_source, /* key_length = */ 2,
|
||||
JoinKind::Inner, JoinStrictness::Asof, ASOFJoinInequality::LessOrEquals);
|
||||
|
||||
Block result_block = executePipeline(pipeline);
|
||||
auto values = getValuesFromBlock(result_block, {"t1.key", "t1.t", "t2.t", "t2.value"});
|
||||
ASSERT_EQ(values.size(), 2);
|
||||
ASSERT_EQ(values[0], (std::vector<Field>{"AMZN", 3u, 4u, 130u}));
|
||||
ASSERT_EQ(values[1], (std::vector<Field>{"AMZN", 4u, 4u, 130u}));
|
||||
}
|
||||
catch (Exception & e)
|
||||
{
|
||||
std::cout << e.getStackTraceString() << std::endl;
|
||||
throw;
|
||||
}
|
||||
|
||||
|
||||
TEST(FullSortingJoin, AsofOnlyColumn)
|
||||
try
|
||||
{
|
||||
const std::vector<Field> chunk_break = {};
|
||||
|
||||
auto left_source = oneColumnSource({ {3}, {3, 3, 3}, {3, 5, 5, 6}, {9, 9}, {10, 20} });
|
||||
|
||||
UInt64 p = std::uniform_int_distribution<>(0, 2)(rng);
|
||||
double break_prob = p == 0 ? 0.0 : (p == 1 ? 0.5 : 1.0);
|
||||
|
||||
auto right_source = sourceFromRows({
|
||||
ColumnWithTypeAndName(ColumnUInt64::create(), std::make_shared<DataTypeUInt64>(), "t"),
|
||||
ColumnWithTypeAndName(ColumnUInt64::create(), std::make_shared<DataTypeUInt64>(), "value"),
|
||||
}, {
|
||||
{1, 101},
|
||||
{2, 102},
|
||||
{4, 104},
|
||||
{5, 105},
|
||||
{11, 111},
|
||||
{15, 115},
|
||||
},
|
||||
break_prob);
|
||||
|
||||
auto pipeline = buildJoinPipeline(
|
||||
left_source, right_source, /* key_length = */ 1,
|
||||
JoinKind::Inner, JoinStrictness::Asof, ASOFJoinInequality::LessOrEquals);
|
||||
|
||||
Block result_block = executePipeline(pipeline);
|
||||
|
||||
ASSERT_EQ(
|
||||
assert_cast<const ColumnUInt64 *>(result_block.getByName("t1.x").column.get())->getData(),
|
||||
(ColumnUInt64::Container{3, 3, 3, 3, 3, 5, 5, 6, 9, 9, 10})
|
||||
);
|
||||
|
||||
ASSERT_EQ(
|
||||
assert_cast<const ColumnUInt64 *>(result_block.getByName("t2.t").column.get())->getData(),
|
||||
(ColumnUInt64::Container{4, 4, 4, 4, 4, 5, 5, 11, 11, 11, 15})
|
||||
);
|
||||
|
||||
ASSERT_EQ(
|
||||
assert_cast<const ColumnUInt64 *>(result_block.getByName("t2.value").column.get())->getData(),
|
||||
(ColumnUInt64::Container{104, 104, 104, 104, 104, 105, 105, 111, 111, 111, 115})
|
||||
);
|
||||
}
|
||||
catch (Exception & e)
|
||||
{
|
||||
std::cout << e.getStackTraceString() << std::endl;
|
||||
throw;
|
||||
}
|
@ -1,95 +0,0 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <Processors/Sources/SourceFromSingleChunk.h>
|
||||
#include <Processors/Sources/SourceFromChunks.h>
|
||||
#include <Processors/Sinks/NullSink.h>
|
||||
#include <Processors/Executors/PipelineExecutor.h>
|
||||
#include <Processors/Executors/PullingPipelineExecutor.h>
|
||||
#include <QueryPipeline/QueryPipeline.h>
|
||||
|
||||
#include <Columns/ColumnsNumber.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <Interpreters/TableJoin.h>
|
||||
|
||||
#include <Processors/Transforms/MergeJoinTransform.h>
|
||||
|
||||
|
||||
using namespace DB;
|
||||
|
||||
|
||||
/// Builds a pipeline that performs an inner ALL join of two sources on their column "x"
/// using MergeJoinTransform. The output header has the columns "t1.x" and "t2.x".
QueryPipeline buildJoinPipeline(std::shared_ptr<ISource> left_source, std::shared_ptr<ISource> right_source)
{
    Blocks inputs;
    inputs.emplace_back(left_source->getPort().getHeader());
    inputs.emplace_back(right_source->getPort().getHeader());
    /// The output columns carry the joined key values, so they are declared with the same
    /// UInt64 type as the "x" column produced by the sources of this test.
    Block out_header = {
        ColumnWithTypeAndName(ColumnUInt64::create(), std::make_shared<DataTypeUInt64>(), "t1.x"),
        ColumnWithTypeAndName(ColumnUInt64::create(), std::make_shared<DataTypeUInt64>(), "t2.x"),
    };

    TableJoin::JoinOnClause on_clause;
    on_clause.key_names_left = {"x"};
    on_clause.key_names_right = {"x"};
    auto joining = std::make_shared<MergeJoinTransform>(
        JoinKind::Inner,
        JoinStrictness::All,
        on_clause,
        inputs, out_header, /* max_block_size = */ 0);

    chassert(joining->getInputs().size() == 2);

    connect(left_source->getPort(), joining->getInputs().front());
    connect(right_source->getPort(), joining->getInputs().back());

    auto * output_port = &joining->getOutputPort();

    auto processors = std::make_shared<Processors>();
    processors->emplace_back(std::move(left_source));
    processors->emplace_back(std::move(right_source));
    processors->emplace_back(std::move(joining));

    QueryPipeline pipeline(QueryPlanResourceHolder{}, processors, output_port);
    return pipeline;
}
|
||||
|
||||
|
||||
std::shared_ptr<ISource> createSourceWithSingleValue(size_t rows_per_chunk, size_t total_chunks)
|
||||
{
|
||||
Block header = {
|
||||
ColumnWithTypeAndName(ColumnUInt64::create(), std::make_shared<DataTypeUInt64>(), "x")
|
||||
};
|
||||
|
||||
Chunks chunks;
|
||||
|
||||
for (size_t i = 0; i < total_chunks; ++i)
|
||||
{
|
||||
auto col = ColumnUInt64::create(rows_per_chunk, 1);
|
||||
chunks.emplace_back(Columns{std::move(col)}, rows_per_chunk);
|
||||
}
|
||||
|
||||
return std::make_shared<SourceFromChunks>(std::move(header), std::move(chunks));
|
||||
}
|
||||
|
||||
TEST(FullSortingJoin, Simple)
|
||||
try
|
||||
{
|
||||
auto left_source = createSourceWithSingleValue(3, 10);
|
||||
auto right_source = createSourceWithSingleValue(2, 15);
|
||||
|
||||
auto pipeline = buildJoinPipeline(left_source, right_source);
|
||||
PullingPipelineExecutor executor(pipeline);
|
||||
|
||||
Block block;
|
||||
|
||||
size_t total_result_rows = 0;
|
||||
while (executor.pull(block))
|
||||
{
|
||||
total_result_rows += block.rows();
|
||||
}
|
||||
ASSERT_EQ(total_result_rows, 3 * 10 * 2 * 15);
|
||||
}
|
||||
catch (Exception & e)
|
||||
{
|
||||
std::cout << e.getStackTraceString() << std::endl;
|
||||
throw;
|
||||
}
|
Loading…
Reference in New Issue
Block a user