Rewrite join totals, fix block structure mismatch

This commit is contained in:
vdimir 2021-07-14 13:02:23 +03:00
parent f94b1419f2
commit b49e37aa07
No known key found for this signature in database
GPG Key ID: F57B3E10A21DBB31
14 changed files with 70 additions and 85 deletions

View File

@ -436,7 +436,7 @@ Block Block::sortColumns() const
Block sorted_block;
/// std::unordered_map (index_by_name) cannot be used to guarantee the sort order
std::vector<decltype(index_by_name.begin())> sorted_index_by_name(index_by_name.size());
std::vector<IndexByName::const_iterator> sorted_index_by_name(index_by_name.size());
{
size_t i = 0;
for (auto it = index_by_name.begin(); it != index_by_name.end(); ++it)

View File

@ -1368,18 +1368,6 @@ void HashJoin::joinBlock(Block & block, ExtraBlockPtr & not_processed)
throw Exception("Logical error: unknown combination of JOIN", ErrorCodes::LOGICAL_ERROR);
}
/// Add totals to `block` from the left table.
/// The work is delegated to JoinCommon::joinTotals, which receives the stored `totals`,
/// an empty sample of the right-side columns, and the TableJoin.
void HashJoin::joinTotals(Block & block) const
{
    /// For StorageJoin, column names aren't qualified in sample_block_with_columns_to_add,
    /// so every column of the empty clone is renamed through the TableJoin first.
    Block right_sample = sample_block_with_columns_to_add.cloneEmpty();
    for (auto & column : right_sample)
    {
        column.name = getTableJoin().renamedRightColumnName(column.name);
    }

    JoinCommon::joinTotals(totals, right_sample, *table_join, block);
}
template <typename Mapped>
struct AdderNonJoined
{

View File

@ -155,9 +155,7 @@ public:
/** Keep "totals" (separate part of dataset, see WITH TOTALS) to use later.
*/
void setTotals(const Block & block) override { totals = block; }
/// True when the stored `totals` block converts to true (Block's boolean conversion).
bool hasTotals() const override { return totals; }
/// Add totals to block from left table.
void joinTotals(Block & block) const override;
/// Read-only access to the stored `totals` block.
const Block & getTotals() const override { return totals; }
/// Filled when `from_storage_join` is set or `data->type` is `Type::DICT`.
bool isFilled() const override { return from_storage_join || data->type == Type::DICT; }

View File

@ -31,11 +31,9 @@ public:
/// Could be called from different threads in parallel.
virtual void joinBlock(Block & block, std::shared_ptr<ExtraBlock> & not_processed) = 0;
virtual bool hasTotals() const = 0;
/// Set totals for right table
/// Set/Get totals for right table
virtual void setTotals(const Block & block) = 0;
/// Add totals to block from left table
virtual void joinTotals(Block & block) const = 0;
virtual const Block & getTotals() const = 0;
virtual size_t getTotalRowCount() const = 0;
virtual size_t getTotalByteCount() const = 0;

View File

@ -31,9 +31,9 @@ public:
join->joinBlock(block, not_processed);
}
bool hasTotals() const override
const Block & getTotals() const override
{
return join->hasTotals();
return join->getTotals();
}
void setTotals(const Block & block) override
@ -41,11 +41,6 @@ public:
join->setTotals(block);
}
/// Add totals to block from left table: forwards the call unchanged to the wrapped `join`.
void joinTotals(Block & block) const override
{
join->joinTotals(block);
}
size_t getTotalRowCount() const override
{
return join->getTotalRowCount();

View File

@ -503,11 +503,6 @@ void MergeJoin::setTotals(const Block & totals_block)
used_rows_bitmap = std::make_shared<RowBitmaps>(getRightBlocksCount());
}
/// Add totals to `block` by delegating to JoinCommon::joinTotals, passing the stored
/// `totals`, `right_columns_to_add` as the columns to add, and the TableJoin.
void MergeJoin::joinTotals(Block & block) const
{
JoinCommon::joinTotals(totals, right_columns_to_add, *table_join, block);
}
void MergeJoin::mergeRightBlocks()
{
if (is_in_memory)

View File

@ -26,9 +26,10 @@ public:
/// Returns the TableJoin referenced by `table_join`.
const TableJoin & getTableJoin() const override { return *table_join; }
bool addJoinedBlock(const Block & block, bool check_limits) override;
void joinBlock(Block &, ExtraBlockPtr & not_processed) override;
/// Add totals to block from left table.
void joinTotals(Block &) const override;
/// Set totals for right table.
void setTotals(const Block &) override;
/// True when the stored `totals` block converts to true (Block's boolean conversion).
bool hasTotals() const override { return totals; }
/// Read-only access to the stored `totals` block.
const Block & getTotals() const override { return totals; }
/// Row count tracked in `right_blocks`.
size_t getTotalRowCount() const override { return right_blocks.row_count; }
/// Byte count tracked in `right_blocks`.
size_t getTotalByteCount() const override { return right_blocks.bytes; }

View File

@ -322,51 +322,27 @@ void createMissedColumns(Block & block)
}
/// Append totals from right to left block, correct types if needed
void joinTotals(const Block & totals, const Block & columns_to_add, const TableJoin & table_join, Block & block)
void joinTotals(Block left_totals, Block right_totals, const TableJoin & table_join, Block & out_block)
{
if (table_join.forceNullableLeft())
convertColumnsToNullable(block);
JoinCommon::convertColumnsToNullable(left_totals);
if (Block totals_without_keys = totals)
{
const auto & required_right = table_join.requiredRightKeys();
for (const auto & name : table_join.keyNamesRight())
{
if (!required_right.contains(name))
totals_without_keys.erase(totals_without_keys.getPositionByName(name));
}
if (table_join.forceNullableRight())
JoinCommon::convertColumnsToNullable(right_totals);
for (auto & col : totals_without_keys)
for (auto & col : out_block)
{
if (table_join.rightBecomeNullable(col.type))
JoinCommon::convertColumnToNullable(col);
if (const auto * left_col = left_totals.findByName(col.name))
col = *left_col;
else if (const auto * right_col = right_totals.findByName(col.name))
col = *right_col;
else
col.column = col.type->createColumnConstWithDefaultValue(1)->convertToFullColumnIfConst();
/// In case of arrayJoin it can be not one row
/// In case of using `arrayJoin` we can get more or less rows than one
if (col.column->size() != 1)
col.column = col.column->cloneResized(1);
}
for (size_t i = 0; i < totals_without_keys.columns(); ++i)
block.insert(totals_without_keys.safeGetByPosition(i));
}
else
{
/// We will join empty `totals` - from one row with the default values.
for (size_t i = 0; i < columns_to_add.columns(); ++i)
{
const auto & col = columns_to_add.getByPosition(i);
if (block.has(col.name))
{
/// For StorageJoin we discarded table qualifiers, so some names may clash
continue;
}
block.insert({
col.type->createColumnConstWithDefaultValue(1)->convertToFullColumnIfConst(),
col.type,
col.name});
}
}
}
void addDefaultValues(IColumn & column, const DataTypePtr & type, size_t count)

View File

@ -35,7 +35,7 @@ ColumnRawPtrs extractKeysForJoin(const Block & block_keys, const Names & key_nam
void checkTypesOfKeys(const Block & block_left, const Names & key_names_left, const Block & block_right, const Names & key_names_right);
void createMissedColumns(Block & block);
void joinTotals(const Block & totals, const Block & columns_to_add, const TableJoin & table_join, Block & block);
void joinTotals(Block left_totals, Block right_totals, const TableJoin & table_join, Block & out_block);
void addDefaultValues(IColumn & column, const DataTypePtr & type, size_t count);

View File

@ -70,7 +70,7 @@ FilledJoinStep::FilledJoinStep(const DataStream & input_stream_, JoinPtr join_,
void FilledJoinStep::transformPipeline(QueryPipeline & pipeline, const BuildQueryPipelineSettings &)
{
bool default_totals = false;
if (!pipeline.hasTotals() && join->hasTotals())
if (!pipeline.hasTotals() && join->getTotals())
{
pipeline.addDefaultTotals();
default_totals = true;

View File

@ -1,6 +1,6 @@
#include <Processors/Transforms/JoiningTransform.h>
#include <Interpreters/ExpressionAnalyzer.h>
#include <Interpreters/ExpressionActions.h>
#include <Interpreters/join_common.h>
#include <DataStreams/IBlockInputStream.h>
#include <DataTypes/DataTypesNumber.h>
@ -159,19 +159,16 @@ void JoiningTransform::transform(Chunk & chunk)
Block block;
if (on_totals)
{
/// We have to make chunk empty before return
/// In case of using `arrayJoin` we can get more or less rows than one
auto cols = chunk.detachColumns();
for (auto & col : cols)
col = col->cloneResized(1);
block = inputs.front().getHeader().cloneWithColumns(std::move(cols));
const auto & left_totals = inputs.front().getHeader().cloneWithColumns(chunk.detachColumns());
const auto & right_totals = join->getTotals();
/// Drop totals if neither the output stream nor the joined stream has any.
/// See comment in ExpressionTransform.h
if (default_totals && !join->hasTotals())
if (default_totals && !right_totals)
return;
join->joinTotals(block);
block = outputs.front().getHeader().cloneEmpty();
JoinCommon::joinTotals(left_totals, right_totals, join->getTableJoin(), block);
}
else
block = readExecute(chunk);

View File

@ -27,5 +27,22 @@
0 0
1 1
1 1
0 0
1 1
1 1
0 0
1 1
0 0
1 foo 1 1 300
0 foo 1 0 300
1 100 1970-01-01 1 100 1970-01-01
1 100 1970-01-01 1 200 1970-01-02
1 200 1970-01-02 1 100 1970-01-01
1 200 1970-01-02 1 200 1970-01-02
0 0 1970-01-01 0 0 1970-01-01

View File

@ -67,14 +67,34 @@ FROM (SELECT item_id FROM t GROUP BY item_id WITH TOTALS) l
LEFT JOIN (SELECT item_id FROM t ) r
ON l.item_id = r.item_id;
-- Totals only on the left side, RIGHT JOIN.
SELECT *
FROM (SELECT item_id FROM t GROUP BY item_id WITH TOTALS) l
RIGHT JOIN (SELECT item_id FROM t ) r
ON l.item_id = r.item_id;
-- Totals only on the right side, LEFT JOIN.
SELECT *
FROM (SELECT item_id FROM t) l
LEFT JOIN (SELECT item_id FROM t GROUP BY item_id WITH TOTALS ) r
ON l.item_id = r.item_id;
-- Totals only on the right side, RIGHT JOIN.
SELECT *
FROM (SELECT item_id FROM t) l
RIGHT JOIN (SELECT item_id FROM t GROUP BY item_id WITH TOTALS ) r
ON l.item_id = r.item_id;
-- Totals on both sides, LEFT JOIN.
SELECT *
FROM (SELECT item_id FROM t GROUP BY item_id WITH TOTALS) l
LEFT JOIN (SELECT item_id FROM t GROUP BY item_id WITH TOTALS ) r
ON l.item_id = r.item_id;
-- Totals on both sides; each subquery exposes a column named `val` (a constant on the left, an aggregate on the right).
SELECT *
FROM (SELECT item_id, 'foo' AS key, 1 AS val FROM t GROUP BY item_id WITH TOTALS) l
LEFT JOIN (SELECT item_id, sum(price_sold) AS val FROM t GROUP BY item_id WITH TOTALS ) r
ON l.item_id = r.item_id;
-- Totals on both sides with `SELECT *` grouped by item_id, price_sold, date in each subquery.
SELECT *
FROM (SELECT * FROM t GROUP BY item_id, price_sold, date WITH TOTALS) l
LEFT JOIN (SELECT * FROM t GROUP BY item_id, price_sold, date WITH TOTALS ) r
ON l.item_id = r.item_id;
-- Remove the test table.
DROP TABLE t;