Support NULLs in ROLLUP

Dmitry Novik 2022-06-27 18:42:26 +00:00
parent b5a977ad54
commit 1d15d72211
12 changed files with 152 additions and 23 deletions
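The patch makes ROLLUP/CUBE key columns Nullable so that subtotal rows report NULL instead of the key type's default value. A sketch of the intended behavior, illustrative and not part of the commit (same query shape as the new test at the bottom of the diff):

SELECT number, number % 2, sum(number) AS val
FROM numbers(3)
GROUP BY ROLLUP(number, number % 2)
ORDER BY (number, number % 2, val);
-- with this patch the grand-total row is (NULL, NULL, 3)
-- rather than (0, 0, 3)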

src/Core/Settings.h

@@ -129,6 +129,8 @@ static constexpr UInt64 operator""_GiB(unsigned long long value)
M(UInt64, aggregation_memory_efficient_merge_threads, 0, "Number of threads to use for merge intermediate aggregation results in memory efficient mode. When bigger, then more memory is consumed. 0 means - same as 'max_threads'.", 0) \
M(Bool, enable_positional_arguments, false, "Enable positional arguments in ORDER BY, GROUP BY and LIMIT BY", 0) \
\
+ M(Bool, group_by_use_nulls, false, "Treat columns mentioned in ROLLUP, CUBE or GROUPING SETS as Nullable", 0) \
+ \
M(UInt64, max_parallel_replicas, 1, "The maximum number of replicas of each shard used when the query is executed. For consistency (to get different parts of the same partition), this option only works for the specified sampling key. The lag of the replicas is not controlled.", 0) \
M(UInt64, parallel_replicas_count, 0, "", 0) \
M(UInt64, parallel_replica_offset, 0, "", 0) \
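The new setting is only declared here; the analyzer change further down applies the Nullable conversion whenever ROLLUP or CUBE is present. Hypothetical usage once the setting gates the behavior end-to-end:

SET group_by_use_nulls = 1;
SELECT number % 2 AS parity, count()
FROM numbers(4)
GROUP BY CUBE(parity);
-- parity becomes Nullable; the totals row shows NULL, not 0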

src/DataTypes/IDataType.h

@@ -532,6 +532,12 @@ inline bool isBool(const DataTypePtr & data_type)
return data_type->getName() == "Bool";
}
+ inline bool isAggregateFunction(const DataTypePtr & data_type)
+ {
+ WhichDataType which(data_type);
+ return which.isAggregateFunction();
+ }
template <typename DataType> constexpr bool IsDataTypeDecimal = false;
template <typename DataType> constexpr bool IsDataTypeNumber = false;
template <typename DataType> constexpr bool IsDataTypeDateOrDateTime = false;

src/Interpreters/ExpressionAnalyzer.cpp

@@ -41,8 +41,12 @@
#include <Dictionaries/DictionaryStructure.h>
+ #include "Common/logger_useful.h"
#include <Common/typeid_cast.h>
#include <Common/StringUtils/StringUtils.h>
+ #include "Columns/ColumnNullable.h"
+ #include "Core/ColumnsWithTypeAndName.h"
+ #include "DataTypes/IDataType.h"
#include <Core/ColumnNumbers.h>
#include <Core/Names.h>
#include <Core/NamesAndTypes.h>
@@ -64,6 +68,7 @@
#include <Processors/Executors/PullingAsyncPipelineExecutor.h>
#include <Processors/QueryPlan/QueryPlan.h>
#include <Parsers/formatAST.h>
+ #include <Poco/Logger.h>
namespace DB
{
@@ -393,7 +398,7 @@ void ExpressionAnalyzer::analyzeAggregation(ActionsDAGPtr & temp_actions)
}
}
- NameAndTypePair key{column_name, node->result_type};
+ NameAndTypePair key{column_name, makeNullable(node->result_type)};
grouping_set_list.push_back(key);
@@ -447,7 +452,7 @@ void ExpressionAnalyzer::analyzeAggregation(ActionsDAGPtr & temp_actions)
}
}
- NameAndTypePair key{column_name, node->result_type};
+ NameAndTypePair key = select_query->group_by_with_rollup || select_query->group_by_with_cube ? NameAndTypePair{ column_name, makeNullable(node->result_type) } : NameAndTypePair{column_name, node->result_type};
/// Aggregation keys are uniqued.
if (!unique_keys.contains(key.name))
@@ -1418,6 +1423,28 @@ void SelectQueryExpressionAnalyzer::appendExpressionsAfterWindowFunctions(Expres
}
}
+ void SelectQueryExpressionAnalyzer::appendGroupByModifiers(ActionsDAGPtr & before_aggregation, ExpressionActionsChain & chain, bool /* only_types */)
+ {
+ const auto * select_query = getAggregatingQuery();
+ if (!select_query->groupBy() || !(select_query->group_by_with_rollup || select_query->group_by_with_cube))
+ return;
+ auto source_columns = before_aggregation->getResultColumns();
+ ColumnsWithTypeAndName result_columns;
+ for (const auto & source_column : source_columns)
+ {
+ if (isAggregateFunction(source_column.type))
+ result_columns.push_back(source_column);
+ else
+ result_columns.emplace_back(makeNullable(source_column.type), source_column.name);
+ }
+ ExpressionActionsChain::Step & step = chain.lastStep(before_aggregation->getNamesAndTypesList());
+ step.actions() = ActionsDAG::makeConvertingActions(source_columns, result_columns, ActionsDAG::MatchColumnsMode::Position);
+ }
bool SelectQueryExpressionAnalyzer::appendHaving(ExpressionActionsChain & chain, bool only_types)
{
const auto * select_query = getAggregatingQuery();
@@ -1597,6 +1624,8 @@ ActionsDAGPtr SelectQueryExpressionAnalyzer::appendProjectResult(ExpressionActio
ExpressionActionsChain::Step & step = chain.lastStep(aggregated_columns);
+ LOG_DEBUG(&Poco::Logger::get("SelectQueryExpressionAnalyzer"), "Before output: {}", step.actions()->getNamesAndTypesList().toString());
NamesWithAliases result_columns;
ASTs asts = select_query->select()->children;
@@ -1638,7 +1667,11 @@ ActionsDAGPtr SelectQueryExpressionAnalyzer::appendProjectResult(ExpressionActio
}
auto actions = chain.getLastActions();
+ LOG_DEBUG(&Poco::Logger::get("SelectQueryExpressionAnalyzer"), "Before projection: {}", actions->getNamesAndTypesList().toString());
actions->project(result_columns);
+ LOG_DEBUG(&Poco::Logger::get("SelectQueryExpressionAnalyzer"), "After projection: {}", actions->getNamesAndTypesList().toString());
return actions;
}
@@ -1862,6 +1895,9 @@ ExpressionAnalysisResult::ExpressionAnalysisResult(
query_analyzer.appendAggregateFunctionsArguments(chain, only_types || !first_stage);
before_aggregation = chain.getLastActions();
+ before_aggregation_with_nullable = chain.getLastActions();
+ query_analyzer.appendGroupByModifiers(before_aggregation, chain, only_types);
finalize_chain(chain);
if (query_analyzer.appendHaving(chain, only_types || !second_stage))
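appendGroupByModifiers adds a converting step (ActionsDAG::makeConvertingActions, matched by position) that wraps every non-aggregate column in Nullable before aggregation. The effect should surface in the result types; a sketch, assuming a server built with this patch:

SELECT number, toTypeName(number)
FROM numbers(3)
GROUP BY ROLLUP(number)
ORDER BY number;
-- number is reported as Nullable(UInt64) instead of UInt64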

src/Interpreters/ExpressionAnalyzer.h

@@ -245,6 +245,7 @@ struct ExpressionAnalysisResult
JoinPtr join;
ActionsDAGPtr before_where;
ActionsDAGPtr before_aggregation;
+ ActionsDAGPtr before_aggregation_with_nullable;
ActionsDAGPtr before_having;
String having_column_name;
bool remove_having_filter = false;
@@ -410,6 +411,8 @@ private:
void appendExpressionsAfterWindowFunctions(ExpressionActionsChain & chain, bool only_types);
+ void appendGroupByModifiers(ActionsDAGPtr & before_aggregation, ExpressionActionsChain & chain, bool only_types);
/// After aggregation:
bool appendHaving(ExpressionActionsChain & chain, bool only_types);
/// appendSelect

src/Interpreters/InterpreterSelectQuery.cpp

@@ -582,6 +582,7 @@ InterpreterSelectQuery::InterpreterSelectQuery(
/// Calculate structure of the result.
result_header = getSampleBlockImpl();
+ LOG_DEBUG(&Poco::Logger::get("InterpreterSelectQuery"), "Result header: {}", result_header.dumpStructure());
};
analyze(shouldMoveToPrewhere());

src/Processors/QueryPlan/AggregatingStep.cpp

@@ -11,10 +11,13 @@
#include <Processors/Merges/AggregatingSortedTransform.h>
#include <Processors/Merges/FinishAggregatingInOrderTransform.h>
#include <Interpreters/Aggregator.h>
+ #include <Functions/FunctionFactory.h>
#include <Processors/QueryPlan/IQueryPlanStep.h>
#include <Columns/ColumnFixedString.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypeFixedString.h>
+ #include "Core/ColumnNumbers.h"
+ #include "DataTypes/IDataType.h"
namespace DB
{
@@ -46,22 +49,44 @@ Block appendGroupingSetColumn(Block header)
return res;
}
+ Block generateOutputHeader(const Block & input_header)
+ {
+ auto header = appendGroupingSetColumn(input_header);
+ for (size_t i = 1; i < header.columns(); ++i)
+ {
+ auto & column = header.getByPosition(i);
+ if (!isAggregateFunction(column.type))
+ {
+ column.type = makeNullable(column.type);
+ column.column = makeNullable(column.column);
+ }
+ }
+ return header;
+ }
+ Block generateOutputHeader(const Block & input_header, const ColumnNumbers & keys)
+ {
+ auto header = appendGroupingSetColumn(input_header);
+ for (auto key : keys)
+ {
+ auto & column = header.getByPosition(key + 1);
+ if (!isAggregateFunction(column.type))
+ {
+ column.type = makeNullable(column.type);
+ column.column = makeNullable(column.column);
+ }
+ }
+ return header;
+ }
static Block appendGroupingColumn(Block block, const GroupingSetsParamsList & params)
{
if (params.empty())
return block;
- Block res;
- size_t rows = block.rows();
- auto column = ColumnUInt64::create(rows);
- res.insert({ColumnPtr(std::move(column)), std::make_shared<DataTypeUInt64>(), "__grouping_set"});
- for (auto & col : block)
- res.insert(std::move(col));
- return res;
+ return generateOutputHeader(block);
}
AggregatingStep::AggregatingStep(
@@ -249,7 +274,13 @@ void AggregatingStep::transformPipeline(QueryPipelineBuilder & pipeline, const B
index.push_back(node);
}
else
- index.push_back(dag->getIndex()[header.getPositionByName(col.name)]);
+ {
+ const auto * column_node = dag->getIndex()[header.getPositionByName(col.name)];
+ // index.push_back(dag->getIndex()[header.getPositionByName(col.name)]);
+ const auto * node = &dag->addFunction(FunctionFactory::instance().get("toNullable", nullptr), { column_node }, col.name);
+ index.push_back(node);
+ }
}
dag->getIndex().swap(index);
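generateOutputHeader now also serves GROUPING SETS: appendGroupingColumn prepends the __grouping_set column and makes the key columns Nullable, and transformPipeline wraps the remaining keys in toNullable in the output DAG. A sketch of the resulting behavior, assuming the patch:

SELECT number % 2 AS p, number % 3 AS q, count()
FROM numbers(6)
GROUP BY GROUPING SETS ((p), (q));
-- rows produced from the (p) set should report q = NULL,
-- and rows from the (q) set should report p = NULL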

src/Processors/QueryPlan/AggregatingStep.h

@@ -3,6 +3,7 @@
#include <QueryPipeline/SizeLimits.h>
#include <Storages/SelectQueryInfo.h>
#include <Interpreters/Aggregator.h>
+ #include "Core/ColumnNumbers.h"
namespace DB
{
@@ -26,6 +27,8 @@ struct GroupingSetsParams
using GroupingSetsParamsList = std::vector<GroupingSetsParams>;
Block appendGroupingSetColumn(Block header);
+ Block generateOutputHeader(const Block & input_header);
+ Block generateOutputHeader(const Block & input_header, const ColumnNumbers & keys);
/// Aggregation. See AggregatingTransform.
class AggregatingStep : public ITransformingStep

src/Processors/QueryPlan/RollupStep.cpp

@@ -23,7 +23,7 @@ static ITransformingStep::Traits getTraits()
}
RollupStep::RollupStep(const DataStream & input_stream_, AggregatingTransformParamsPtr params_)
- : ITransformingStep(input_stream_, appendGroupingSetColumn(params_->getHeader()), getTraits())
+ : ITransformingStep(input_stream_, generateOutputHeader(params_->getHeader(), params_->params.keys), getTraits())
, params(std::move(params_))
, keys_size(params->params.keys_size)
{
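The step now declares the Nullable keys in its output header, keeping the plan-level header consistent with the chunks the transform emits. A sketch of how to observe it, assuming the patch:

EXPLAIN header = 1
SELECT number, sum(number)
FROM numbers(3)
GROUP BY ROLLUP(number);
-- the Rollup step's header should list number as Nullable(UInt64)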

src/Processors/Transforms/RollupTransform.cpp

@@ -1,16 +1,24 @@
#include <Processors/Transforms/RollupTransform.h>
#include <Processors/Transforms/TotalsHavingTransform.h>
#include <Processors/QueryPlan/AggregatingStep.h>
+ #include <Poco/Logger.h>
+ #include "Common/logger_useful.h"
+ #include "Columns/ColumnNullable.h"
namespace DB
{
RollupTransform::RollupTransform(Block header, AggregatingTransformParamsPtr params_)
- : IAccumulatingTransform(std::move(header), appendGroupingSetColumn(params_->getHeader()))
+ : IAccumulatingTransform(std::move(header), generateOutputHeader(params_->getHeader(), params_->params.keys))
, params(std::move(params_))
, keys(params->params.keys)
, aggregates_mask(getAggregatesMask(params->getHeader(), params->params.aggregates))
{
+ auto output_aggregator_params = params->params;
+ intermediate_header = getOutputPort().getHeader();
+ intermediate_header.erase(0);
+ output_aggregator_params.src_header = intermediate_header;
+ output_aggregator = std::make_unique<Aggregator>(output_aggregator_params);
}
void RollupTransform::consume(Chunk chunk)
@@ -18,13 +26,14 @@ void RollupTransform::consume(Chunk chunk)
consumed_chunks.emplace_back(std::move(chunk));
}
- Chunk RollupTransform::merge(Chunks && chunks, bool final)
+ Chunk RollupTransform::merge(Chunks && chunks, bool is_input, bool final)
{
BlocksList rollup_blocks;
+ auto header = is_input ? getInputPort().getHeader() : intermediate_header;
for (auto & chunk : chunks)
- rollup_blocks.emplace_back(getInputPort().getHeader().cloneWithColumns(chunk.detachColumns()));
+ rollup_blocks.emplace_back(header.cloneWithColumns(chunk.detachColumns()));
- auto rollup_block = params->aggregator.mergeBlocks(rollup_blocks, final);
+ auto rollup_block = is_input ? params->aggregator.mergeBlocks(rollup_blocks, final) : output_aggregator->mergeBlocks(rollup_blocks, final);
auto num_rows = rollup_block.rows();
return Chunk(rollup_block.getColumns(), num_rows);
}
@@ -42,9 +51,16 @@ Chunk RollupTransform::generate()
if (!consumed_chunks.empty())
{
if (consumed_chunks.size() > 1)
- rollup_chunk = merge(std::move(consumed_chunks), false);
+ rollup_chunk = merge(std::move(consumed_chunks), true, false);
else
rollup_chunk = std::move(consumed_chunks.front());
+ size_t rows = rollup_chunk.getNumRows();
+ auto columns = rollup_chunk.getColumns();
+ for (auto key : keys)
+ columns[key] = makeNullable(columns[key]);
+ rollup_chunk = Chunk{ columns, rows };
+ LOG_DEBUG(&Poco::Logger::get("RollupTransform"), "Chunk source: {}", rollup_chunk.dumpStructure());
consumed_chunks.clear();
last_removed_key = keys.size();
@@ -59,11 +75,12 @@ Chunk RollupTransform::generate()
auto num_rows = gen_chunk.getNumRows();
auto columns = gen_chunk.getColumns();
- columns[key] = getColumnWithDefaults(getInputPort().getHeader(), key, num_rows);
+ columns[key] = getColumnWithDefaults(intermediate_header, key, num_rows);
Chunks chunks;
chunks.emplace_back(std::move(columns), num_rows);
- rollup_chunk = merge(std::move(chunks), false);
+ rollup_chunk = merge(std::move(chunks), false, false);
+ LOG_DEBUG(&Poco::Logger::get("RollupTransform"), "Chunk generated: {}", rollup_chunk.dumpStructure());
}
finalizeChunk(gen_chunk, aggregates_mask);
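The transform now merges in two phases: consumed input chunks go through the original aggregator over the input header and then have their key columns converted to Nullable, while generated subtotal chunks are merged by the new output_aggregator over the Nullable intermediate_header, so getColumnWithDefaults fills removed keys with NULL. A sketch of the resulting rollup levels:

SELECT intDiv(number, 2) AS a, number % 2 AS b, count()
FROM numbers(4)
GROUP BY ROLLUP(a, b)
ORDER BY a, b;
-- detail rows (a, b); per-a subtotals with b = NULL;
-- one grand-total row with a = NULL and b = NULL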

src/Processors/Transforms/RollupTransform.h

@@ -1,4 +1,5 @@
#pragma once
+ #include <memory>
#include <Processors/IAccumulatingTransform.h>
#include <Processors/Transforms/AggregatingTransform.h>
#include <Processors/Transforms/finalizeChunk.h>
@@ -23,12 +24,16 @@
const ColumnNumbers keys;
const ColumnsMask aggregates_mask;
+ std::unique_ptr<Aggregator> output_aggregator;
+ Block intermediate_header;
Chunks consumed_chunks;
Chunk rollup_chunk;
size_t last_removed_key = 0;
size_t set_counter = 0;
- Chunk merge(Chunks && chunks, bool final);
+ Chunk merge(Chunks && chunks, bool is_input, bool final);
};
}

View File

@@ -0,0 +1,21 @@
+ 0 0 0
+ 0 \N 0
+ 1 1 1
+ 1 \N 1
+ 2 0 2
+ 2 \N 2
+ 3 1 3
+ 3 \N 3
+ 4 0 4
+ 4 \N 4
+ 5 1 5
+ 5 \N 5
+ 6 0 6
+ 6 \N 6
+ 7 1 7
+ 7 \N 7
+ 8 0 8
+ 8 \N 8
+ 9 1 9
+ 9 \N 9
+ \N \N 45

View File

@@ -0,0 +1,4 @@
+ SELECT number, number % 2, sum(number) AS val
+ FROM numbers(10)
+ GROUP BY ROLLUP(number, number % 2)
+ ORDER BY (number, number % 2, val);
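A hypothetical companion check, not part of this commit, to assert that the key columns really come out Nullable:

SELECT toTypeName(number), toTypeName(number % 2)
FROM numbers(10)
GROUP BY ROLLUP(number, number % 2)
LIMIT 1;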