Rewrite join totals, fix block structure mismatch

This commit is contained in:
vdimir 2021-07-14 13:02:23 +03:00
parent f94b1419f2
commit b49e37aa07
No known key found for this signature in database
GPG Key ID: F57B3E10A21DBB31
14 changed files with 70 additions and 85 deletions

View File

@ -436,7 +436,7 @@ Block Block::sortColumns() const
Block sorted_block; Block sorted_block;
/// std::unordered_map (index_by_name) cannot be used to guarantee the sort order /// std::unordered_map (index_by_name) cannot be used to guarantee the sort order
std::vector<decltype(index_by_name.begin())> sorted_index_by_name(index_by_name.size()); std::vector<IndexByName::const_iterator> sorted_index_by_name(index_by_name.size());
{ {
size_t i = 0; size_t i = 0;
for (auto it = index_by_name.begin(); it != index_by_name.end(); ++it) for (auto it = index_by_name.begin(); it != index_by_name.end(); ++it)

View File

@ -68,7 +68,7 @@ public:
const_cast<const Block *>(this)->findByName(name)); const_cast<const Block *>(this)->findByName(name));
} }
const ColumnWithTypeAndName* findByName(const std::string & name) const; const ColumnWithTypeAndName * findByName(const std::string & name) const;
ColumnWithTypeAndName & getByName(const std::string & name) ColumnWithTypeAndName & getByName(const std::string & name)
{ {

View File

@ -1368,18 +1368,6 @@ void HashJoin::joinBlock(Block & block, ExtraBlockPtr & not_processed)
throw Exception("Logical error: unknown combination of JOIN", ErrorCodes::LOGICAL_ERROR); throw Exception("Logical error: unknown combination of JOIN", ErrorCodes::LOGICAL_ERROR);
} }
void HashJoin::joinTotals(Block & block) const
{
Block sample_right_block = sample_block_with_columns_to_add.cloneEmpty();
/// For StorageJoin column names isn't qualified in sample_block_with_columns_to_add
for (auto & col : sample_right_block)
col.name = getTableJoin().renamedRightColumnName(col.name);
JoinCommon::joinTotals(totals, sample_right_block, *table_join, block);
}
template <typename Mapped> template <typename Mapped>
struct AdderNonJoined struct AdderNonJoined
{ {

View File

@ -155,9 +155,7 @@ public:
/** Keep "totals" (separate part of dataset, see WITH TOTALS) to use later. /** Keep "totals" (separate part of dataset, see WITH TOTALS) to use later.
*/ */
void setTotals(const Block & block) override { totals = block; } void setTotals(const Block & block) override { totals = block; }
bool hasTotals() const override { return totals; } const Block & getTotals() const override { return totals; }
void joinTotals(Block & block) const override;
bool isFilled() const override { return from_storage_join || data->type == Type::DICT; } bool isFilled() const override { return from_storage_join || data->type == Type::DICT; }

View File

@ -31,11 +31,9 @@ public:
/// Could be called from different threads in parallel. /// Could be called from different threads in parallel.
virtual void joinBlock(Block & block, std::shared_ptr<ExtraBlock> & not_processed) = 0; virtual void joinBlock(Block & block, std::shared_ptr<ExtraBlock> & not_processed) = 0;
virtual bool hasTotals() const = 0; /// Set/Get totals for right table
/// Set totals for right table
virtual void setTotals(const Block & block) = 0; virtual void setTotals(const Block & block) = 0;
/// Add totals to block from left table virtual const Block & getTotals() const = 0;
virtual void joinTotals(Block & block) const = 0;
virtual size_t getTotalRowCount() const = 0; virtual size_t getTotalRowCount() const = 0;
virtual size_t getTotalByteCount() const = 0; virtual size_t getTotalByteCount() const = 0;

View File

@ -31,9 +31,9 @@ public:
join->joinBlock(block, not_processed); join->joinBlock(block, not_processed);
} }
bool hasTotals() const override const Block & getTotals() const override
{ {
return join->hasTotals(); return join->getTotals();
} }
void setTotals(const Block & block) override void setTotals(const Block & block) override
@ -41,11 +41,6 @@ public:
join->setTotals(block); join->setTotals(block);
} }
void joinTotals(Block & block) const override
{
join->joinTotals(block);
}
size_t getTotalRowCount() const override size_t getTotalRowCount() const override
{ {
return join->getTotalRowCount(); return join->getTotalRowCount();

View File

@ -503,11 +503,6 @@ void MergeJoin::setTotals(const Block & totals_block)
used_rows_bitmap = std::make_shared<RowBitmaps>(getRightBlocksCount()); used_rows_bitmap = std::make_shared<RowBitmaps>(getRightBlocksCount());
} }
void MergeJoin::joinTotals(Block & block) const
{
JoinCommon::joinTotals(totals, right_columns_to_add, *table_join, block);
}
void MergeJoin::mergeRightBlocks() void MergeJoin::mergeRightBlocks()
{ {
if (is_in_memory) if (is_in_memory)

View File

@ -26,9 +26,10 @@ public:
const TableJoin & getTableJoin() const override { return *table_join; } const TableJoin & getTableJoin() const override { return *table_join; }
bool addJoinedBlock(const Block & block, bool check_limits) override; bool addJoinedBlock(const Block & block, bool check_limits) override;
void joinBlock(Block &, ExtraBlockPtr & not_processed) override; void joinBlock(Block &, ExtraBlockPtr & not_processed) override;
void joinTotals(Block &) const override;
void setTotals(const Block &) override; void setTotals(const Block &) override;
bool hasTotals() const override { return totals; } const Block & getTotals() const override { return totals; }
size_t getTotalRowCount() const override { return right_blocks.row_count; } size_t getTotalRowCount() const override { return right_blocks.row_count; }
size_t getTotalByteCount() const override { return right_blocks.bytes; } size_t getTotalByteCount() const override { return right_blocks.bytes; }

View File

@ -322,50 +322,26 @@ void createMissedColumns(Block & block)
} }
/// Append totals from right to left block, correct types if needed /// Append totals from right to left block, correct types if needed
void joinTotals(const Block & totals, const Block & columns_to_add, const TableJoin & table_join, Block & block) void joinTotals(Block left_totals, Block right_totals, const TableJoin & table_join, Block & out_block)
{ {
if (table_join.forceNullableLeft()) if (table_join.forceNullableLeft())
convertColumnsToNullable(block); JoinCommon::convertColumnsToNullable(left_totals);
if (Block totals_without_keys = totals) if (table_join.forceNullableRight())
JoinCommon::convertColumnsToNullable(right_totals);
for (auto & col : out_block)
{ {
const auto & required_right = table_join.requiredRightKeys(); if (const auto * left_col = left_totals.findByName(col.name))
for (const auto & name : table_join.keyNamesRight()) col = *left_col;
{ else if (const auto * right_col = right_totals.findByName(col.name))
if (!required_right.contains(name)) col = *right_col;
totals_without_keys.erase(totals_without_keys.getPositionByName(name)); else
} col.column = col.type->createColumnConstWithDefaultValue(1)->convertToFullColumnIfConst();
for (auto & col : totals_without_keys) /// In case of using `arrayJoin` we can get more or less rows than one
{ if (col.column->size() != 1)
if (table_join.rightBecomeNullable(col.type)) col.column = col.column->cloneResized(1);
JoinCommon::convertColumnToNullable(col);
/// In case of arrayJoin it can be not one row
if (col.column->size() != 1)
col.column = col.column->cloneResized(1);
}
for (size_t i = 0; i < totals_without_keys.columns(); ++i)
block.insert(totals_without_keys.safeGetByPosition(i));
}
else
{
/// We will join empty `totals` - from one row with the default values.
for (size_t i = 0; i < columns_to_add.columns(); ++i)
{
const auto & col = columns_to_add.getByPosition(i);
if (block.has(col.name))
{
/// For StorageJoin we discarded table qualifiers, so some names may clash
continue;
}
block.insert({
col.type->createColumnConstWithDefaultValue(1)->convertToFullColumnIfConst(),
col.type,
col.name});
}
} }
} }

View File

@ -35,7 +35,7 @@ ColumnRawPtrs extractKeysForJoin(const Block & block_keys, const Names & key_nam
void checkTypesOfKeys(const Block & block_left, const Names & key_names_left, const Block & block_right, const Names & key_names_right); void checkTypesOfKeys(const Block & block_left, const Names & key_names_left, const Block & block_right, const Names & key_names_right);
void createMissedColumns(Block & block); void createMissedColumns(Block & block);
void joinTotals(const Block & totals, const Block & columns_to_add, const TableJoin & table_join, Block & block); void joinTotals(Block left_totals, Block right_totals, const TableJoin & table_join, Block & out_block);
void addDefaultValues(IColumn & column, const DataTypePtr & type, size_t count); void addDefaultValues(IColumn & column, const DataTypePtr & type, size_t count);

View File

@ -70,7 +70,7 @@ FilledJoinStep::FilledJoinStep(const DataStream & input_stream_, JoinPtr join_,
void FilledJoinStep::transformPipeline(QueryPipeline & pipeline, const BuildQueryPipelineSettings &) void FilledJoinStep::transformPipeline(QueryPipeline & pipeline, const BuildQueryPipelineSettings &)
{ {
bool default_totals = false; bool default_totals = false;
if (!pipeline.hasTotals() && join->hasTotals()) if (!pipeline.hasTotals() && join->getTotals())
{ {
pipeline.addDefaultTotals(); pipeline.addDefaultTotals();
default_totals = true; default_totals = true;

View File

@ -1,6 +1,6 @@
#include <Processors/Transforms/JoiningTransform.h> #include <Processors/Transforms/JoiningTransform.h>
#include <Interpreters/ExpressionAnalyzer.h> #include <Interpreters/ExpressionAnalyzer.h>
#include <Interpreters/ExpressionActions.h> #include <Interpreters/join_common.h>
#include <DataStreams/IBlockInputStream.h> #include <DataStreams/IBlockInputStream.h>
#include <DataTypes/DataTypesNumber.h> #include <DataTypes/DataTypesNumber.h>
@ -159,19 +159,16 @@ void JoiningTransform::transform(Chunk & chunk)
Block block; Block block;
if (on_totals) if (on_totals)
{ {
/// We have to make chunk empty before return const auto & left_totals = inputs.front().getHeader().cloneWithColumns(chunk.detachColumns());
/// In case of using `arrayJoin` we can get more or less rows than one const auto & right_totals = join->getTotals();
auto cols = chunk.detachColumns();
for (auto & col : cols)
col = col->cloneResized(1);
block = inputs.front().getHeader().cloneWithColumns(std::move(cols));
/// Drop totals if both out stream and joined stream doesn't have ones. /// Drop totals if both out stream and joined stream doesn't have ones.
/// See comment in ExpressionTransform.h /// See comment in ExpressionTransform.h
if (default_totals && !join->hasTotals()) if (default_totals && !right_totals)
return; return;
join->joinTotals(block); block = outputs.front().getHeader().cloneEmpty();
JoinCommon::joinTotals(left_totals, right_totals, join->getTableJoin(), block);
} }
else else
block = readExecute(chunk); block = readExecute(chunk);

View File

@ -27,5 +27,22 @@
0 0 0 0
1 1 1 1
1 1
0 0 0 0
1 1
1 1
0 0
1 1
0 0
1 foo 1 1 300
0 foo 1 0 300
1 100 1970-01-01 1 100 1970-01-01
1 100 1970-01-01 1 200 1970-01-02
1 200 1970-01-02 1 100 1970-01-01
1 200 1970-01-02 1 200 1970-01-02
0 0 1970-01-01 0 0 1970-01-01

View File

@ -67,14 +67,34 @@ FROM (SELECT item_id FROM t GROUP BY item_id WITH TOTALS) l
LEFT JOIN (SELECT item_id FROM t ) r LEFT JOIN (SELECT item_id FROM t ) r
ON l.item_id = r.item_id; ON l.item_id = r.item_id;
SELECT *
FROM (SELECT item_id FROM t GROUP BY item_id WITH TOTALS) l
RIGHT JOIN (SELECT item_id FROM t ) r
ON l.item_id = r.item_id;
SELECT * SELECT *
FROM (SELECT item_id FROM t) l FROM (SELECT item_id FROM t) l
LEFT JOIN (SELECT item_id FROM t GROUP BY item_id WITH TOTALS ) r LEFT JOIN (SELECT item_id FROM t GROUP BY item_id WITH TOTALS ) r
ON l.item_id = r.item_id; ON l.item_id = r.item_id;
SELECT *
FROM (SELECT item_id FROM t) l
RIGHT JOIN (SELECT item_id FROM t GROUP BY item_id WITH TOTALS ) r
ON l.item_id = r.item_id;
SELECT * SELECT *
FROM (SELECT item_id FROM t GROUP BY item_id WITH TOTALS) l FROM (SELECT item_id FROM t GROUP BY item_id WITH TOTALS) l
LEFT JOIN (SELECT item_id FROM t GROUP BY item_id WITH TOTALS ) r LEFT JOIN (SELECT item_id FROM t GROUP BY item_id WITH TOTALS ) r
ON l.item_id = r.item_id; ON l.item_id = r.item_id;
SELECT *
FROM (SELECT item_id, 'foo' AS key, 1 AS val FROM t GROUP BY item_id WITH TOTALS) l
LEFT JOIN (SELECT item_id, sum(price_sold) AS val FROM t GROUP BY item_id WITH TOTALS ) r
ON l.item_id = r.item_id;
SELECT *
FROM (SELECT * FROM t GROUP BY item_id, price_sold, date WITH TOTALS) l
LEFT JOIN (SELECT * FROM t GROUP BY item_id, price_sold, date WITH TOTALS ) r
ON l.item_id = r.item_id;
DROP TABLE t; DROP TABLE t;