From 1d15d72211f19273793dea91d1c750333f7bf366 Mon Sep 17 00:00:00 2001 From: Dmitry Novik Date: Mon, 27 Jun 2022 18:42:26 +0000 Subject: [PATCH] Support NULLs in ROLLUP --- src/Core/Settings.h | 2 + src/DataTypes/IDataType.h | 6 ++ src/Interpreters/ExpressionAnalyzer.cpp | 40 +++++++++++++- src/Interpreters/ExpressionAnalyzer.h | 3 + src/Interpreters/InterpreterSelectQuery.cpp | 1 + src/Processors/QueryPlan/AggregatingStep.cpp | 55 +++++++++++++++---- src/Processors/QueryPlan/AggregatingStep.h | 3 + src/Processors/QueryPlan/RollupStep.cpp | 2 +- src/Processors/Transforms/RollupTransform.cpp | 31 ++++++++--- src/Processors/Transforms/RollupTransform.h | 7 ++- .../02343_group_by_use_nulls.reference | 21 +++++++ .../0_stateless/02343_group_by_use_nulls.sql | 4 ++ 12 files changed, 152 insertions(+), 23 deletions(-) create mode 100644 tests/queries/0_stateless/02343_group_by_use_nulls.reference create mode 100644 tests/queries/0_stateless/02343_group_by_use_nulls.sql diff --git a/src/Core/Settings.h b/src/Core/Settings.h index 756e6eb651e..5972ca4a9a3 100644 --- a/src/Core/Settings.h +++ b/src/Core/Settings.h @@ -129,6 +129,8 @@ static constexpr UInt64 operator""_GiB(unsigned long long value) M(UInt64, aggregation_memory_efficient_merge_threads, 0, "Number of threads to use for merge intermediate aggregation results in memory efficient mode. When bigger, then more memory is consumed. 0 means - same as 'max_threads'.", 0) \ M(Bool, enable_positional_arguments, false, "Enable positional arguments in ORDER BY, GROUP BY and LIMIT BY", 0) \ \ + M(Bool, group_by_use_nulls, false, "Treat columns mentioned in ROLLUP, CUBE or GROUPING SETS as Nullable", 0) \ + \ M(UInt64, max_parallel_replicas, 1, "The maximum number of replicas of each shard used when the query is executed. For consistency (to get different parts of the same partition), this option only works for the specified sampling key. The lag of the replicas is not controlled.", 0) \ M(UInt64, parallel_replicas_count, 0, "", 0) \ M(UInt64, parallel_replica_offset, 0, "", 0) \ diff --git a/src/DataTypes/IDataType.h b/src/DataTypes/IDataType.h index 420ef61a13f..08c8fd74f3e 100644 --- a/src/DataTypes/IDataType.h +++ b/src/DataTypes/IDataType.h @@ -532,6 +532,12 @@ inline bool isBool(const DataTypePtr & data_type) return data_type->getName() == "Bool"; } +inline bool isAggregateFunction(const DataTypePtr & data_type) +{ + WhichDataType which(data_type); + return which.isAggregateFunction(); +} + template constexpr bool IsDataTypeDecimal = false; template constexpr bool IsDataTypeNumber = false; template constexpr bool IsDataTypeDateOrDateTime = false; diff --git a/src/Interpreters/ExpressionAnalyzer.cpp b/src/Interpreters/ExpressionAnalyzer.cpp index cfe1167c36c..07b2a1ce1f9 100644 --- a/src/Interpreters/ExpressionAnalyzer.cpp +++ b/src/Interpreters/ExpressionAnalyzer.cpp @@ -41,8 +41,12 @@ #include +#include "Common/logger_useful.h" #include #include +#include "Columns/ColumnNullable.h" +#include "Core/ColumnsWithTypeAndName.h" +#include "DataTypes/IDataType.h" #include #include #include @@ -64,6 +68,7 @@ #include #include #include +#include namespace DB { @@ -393,7 +398,7 @@ void ExpressionAnalyzer::analyzeAggregation(ActionsDAGPtr & temp_actions) } } - NameAndTypePair key{column_name, node->result_type}; + NameAndTypePair key{column_name, makeNullable(node->result_type)}; grouping_set_list.push_back(key); @@ -447,7 +452,7 @@ void ExpressionAnalyzer::analyzeAggregation(ActionsDAGPtr & temp_actions) } } - NameAndTypePair key{column_name, node->result_type}; + NameAndTypePair key = select_query->group_by_with_rollup || select_query->group_by_with_cube ? NameAndTypePair{ column_name, makeNullable(node->result_type) } : NameAndTypePair{column_name, node->result_type}; /// Aggregation keys are uniqued. if (!unique_keys.contains(key.name)) @@ -1418,6 +1423,28 @@ void SelectQueryExpressionAnalyzer::appendExpressionsAfterWindowFunctions(Expres } } +void SelectQueryExpressionAnalyzer::appendGroupByModifiers(ActionsDAGPtr & before_aggregation, ExpressionActionsChain & chain, bool /* only_types */) +{ + const auto * select_query = getAggregatingQuery(); + + if (!select_query->groupBy() || !(select_query->group_by_with_rollup || select_query->group_by_with_cube)) + return; + + auto source_columns = before_aggregation->getResultColumns(); + ColumnsWithTypeAndName result_columns; + + for (const auto & source_column : source_columns) + { + if (isAggregateFunction(source_column.type)) + result_columns.push_back(source_column); + else + result_columns.emplace_back(makeNullable(source_column.type), source_column.name); + } + ExpressionActionsChain::Step & step = chain.lastStep(before_aggregation->getNamesAndTypesList()); + + step.actions() = ActionsDAG::makeConvertingActions(source_columns, result_columns, ActionsDAG::MatchColumnsMode::Position); +} + bool SelectQueryExpressionAnalyzer::appendHaving(ExpressionActionsChain & chain, bool only_types) { const auto * select_query = getAggregatingQuery(); @@ -1597,6 +1624,8 @@ ActionsDAGPtr SelectQueryExpressionAnalyzer::appendProjectResult(ExpressionActio ExpressionActionsChain::Step & step = chain.lastStep(aggregated_columns); + LOG_DEBUG(&Poco::Logger::get("SelectQueryExpressionAnalyzer"), "Before output: {}", step.actions()->getNamesAndTypesList().toString()); + NamesWithAliases result_columns; ASTs asts = select_query->select()->children; @@ -1638,7 +1667,11 @@ ActionsDAGPtr SelectQueryExpressionAnalyzer::appendProjectResult(ExpressionActio } auto actions = chain.getLastActions(); + LOG_DEBUG(&Poco::Logger::get("SelectQueryExpressionAnalyzer"), "Before projection: {}", actions->getNamesAndTypesList().toString()); + actions->project(result_columns); + LOG_DEBUG(&Poco::Logger::get("SelectQueryExpressionAnalyzer"), "After projection: {}", actions->getNamesAndTypesList().toString()); + return actions; } @@ -1862,6 +1895,9 @@ ExpressionAnalysisResult::ExpressionAnalysisResult( query_analyzer.appendAggregateFunctionsArguments(chain, only_types || !first_stage); before_aggregation = chain.getLastActions(); + before_aggregation_with_nullable = chain.getLastActions(); + query_analyzer.appendGroupByModifiers(before_aggregation, chain, only_types); + finalize_chain(chain); if (query_analyzer.appendHaving(chain, only_types || !second_stage)) diff --git a/src/Interpreters/ExpressionAnalyzer.h b/src/Interpreters/ExpressionAnalyzer.h index 6c27d8c6760..7a183e865c0 100644 --- a/src/Interpreters/ExpressionAnalyzer.h +++ b/src/Interpreters/ExpressionAnalyzer.h @@ -245,6 +245,7 @@ struct ExpressionAnalysisResult JoinPtr join; ActionsDAGPtr before_where; ActionsDAGPtr before_aggregation; + ActionsDAGPtr before_aggregation_with_nullable; ActionsDAGPtr before_having; String having_column_name; bool remove_having_filter = false; @@ -410,6 +411,8 @@ private: void appendExpressionsAfterWindowFunctions(ExpressionActionsChain & chain, bool only_types); + void appendGroupByModifiers(ActionsDAGPtr & before_aggregation, ExpressionActionsChain & chain, bool only_types); + /// After aggregation: bool appendHaving(ExpressionActionsChain & chain, bool only_types); /// appendSelect diff --git a/src/Interpreters/InterpreterSelectQuery.cpp b/src/Interpreters/InterpreterSelectQuery.cpp index ec7c3878b06..feae7ac6a21 100644 --- a/src/Interpreters/InterpreterSelectQuery.cpp +++ b/src/Interpreters/InterpreterSelectQuery.cpp @@ -582,6 +582,7 @@ InterpreterSelectQuery::InterpreterSelectQuery( /// Calculate structure of the result. result_header = getSampleBlockImpl(); + LOG_DEBUG(&Poco::Logger::get("InterpreterSelectQuery"), "Result header: {}", result_header.dumpStructure()); }; analyze(shouldMoveToPrewhere()); diff --git a/src/Processors/QueryPlan/AggregatingStep.cpp b/src/Processors/QueryPlan/AggregatingStep.cpp index 17a0498fb7e..a0f5fce908b 100644 --- a/src/Processors/QueryPlan/AggregatingStep.cpp +++ b/src/Processors/QueryPlan/AggregatingStep.cpp @@ -11,10 +11,13 @@ #include #include #include +#include #include #include #include #include +#include "Core/ColumnNumbers.h" +#include "DataTypes/IDataType.h" namespace DB { @@ -46,22 +49,44 @@ Block appendGroupingSetColumn(Block header) return res; } +Block generateOutputHeader(const Block & input_header) +{ + auto header = appendGroupingSetColumn(input_header); + for (size_t i = 1; i < header.columns(); ++i) + { + auto & column = header.getByPosition(i); + + if (!isAggregateFunction(column.type)) + { + column.type = makeNullable(column.type); + column.column = makeNullable(column.column); + } + } + return header; +} + +Block generateOutputHeader(const Block & input_header, const ColumnNumbers & keys) +{ + auto header = appendGroupingSetColumn(input_header); + for (auto key : keys) + { + auto & column = header.getByPosition(key + 1); + + if (!isAggregateFunction(column.type)) + { + column.type = makeNullable(column.type); + column.column = makeNullable(column.column); + } + } + return header; +} + static Block appendGroupingColumn(Block block, const GroupingSetsParamsList & params) { if (params.empty()) return block; - Block res; - - size_t rows = block.rows(); - auto column = ColumnUInt64::create(rows); - - res.insert({ColumnPtr(std::move(column)), std::make_shared(), "__grouping_set"}); - - for (auto & col : block) - res.insert(std::move(col)); - - return res; + return generateOutputHeader(block); } AggregatingStep::AggregatingStep( @@ -249,7 +274,13 @@ void AggregatingStep::transformPipeline(QueryPipelineBuilder & pipeline, const B index.push_back(node); } else - index.push_back(dag->getIndex()[header.getPositionByName(col.name)]); + { + const auto * column_node = dag->getIndex()[header.getPositionByName(col.name)]; + // index.push_back(dag->getIndex()[header.getPositionByName(col.name)]); + + const auto * node = &dag->addFunction(FunctionFactory::instance().get("toNullable", nullptr), { column_node }, col.name); + index.push_back(node); + } } dag->getIndex().swap(index); diff --git a/src/Processors/QueryPlan/AggregatingStep.h b/src/Processors/QueryPlan/AggregatingStep.h index 4dd3d956350..3d024a99063 100644 --- a/src/Processors/QueryPlan/AggregatingStep.h +++ b/src/Processors/QueryPlan/AggregatingStep.h @@ -3,6 +3,7 @@ #include #include #include +#include "Core/ColumnNumbers.h" namespace DB { @@ -26,6 +27,8 @@ struct GroupingSetsParams using GroupingSetsParamsList = std::vector; Block appendGroupingSetColumn(Block header); +Block generateOutputHeader(const Block & input_header); +Block generateOutputHeader(const Block & input_header, const ColumnNumbers & keys); /// Aggregation. See AggregatingTransform. class AggregatingStep : public ITransformingStep diff --git a/src/Processors/QueryPlan/RollupStep.cpp b/src/Processors/QueryPlan/RollupStep.cpp index 3b061f9c246..5109a5ce169 100644 --- a/src/Processors/QueryPlan/RollupStep.cpp +++ b/src/Processors/QueryPlan/RollupStep.cpp @@ -23,7 +23,7 @@ static ITransformingStep::Traits getTraits() } RollupStep::RollupStep(const DataStream & input_stream_, AggregatingTransformParamsPtr params_) - : ITransformingStep(input_stream_, appendGroupingSetColumn(params_->getHeader()), getTraits()) + : ITransformingStep(input_stream_, generateOutputHeader(params_->getHeader(), params_->params.keys), getTraits()) , params(std::move(params_)) , keys_size(params->params.keys_size) { diff --git a/src/Processors/Transforms/RollupTransform.cpp b/src/Processors/Transforms/RollupTransform.cpp index b69a691323c..6ac5ae35fa2 100644 --- a/src/Processors/Transforms/RollupTransform.cpp +++ b/src/Processors/Transforms/RollupTransform.cpp @@ -1,16 +1,24 @@ #include #include #include +#include +#include "Common/logger_useful.h" +#include "Columns/ColumnNullable.h" namespace DB { RollupTransform::RollupTransform(Block header, AggregatingTransformParamsPtr params_) - : IAccumulatingTransform(std::move(header), appendGroupingSetColumn(params_->getHeader())) + : IAccumulatingTransform(std::move(header), generateOutputHeader(params_->getHeader(), params_->params.keys)) , params(std::move(params_)) , keys(params->params.keys) , aggregates_mask(getAggregatesMask(params->getHeader(), params->params.aggregates)) { + auto output_aggregator_params = params->params; + intermediate_header = getOutputPort().getHeader(); + intermediate_header.erase(0); + output_aggregator_params.src_header = intermediate_header; + output_aggregator = std::make_unique(output_aggregator_params); } void RollupTransform::consume(Chunk chunk) @@ -18,13 +26,14 @@ void RollupTransform::consume(Chunk chunk) consumed_chunks.emplace_back(std::move(chunk)); } -Chunk RollupTransform::merge(Chunks && chunks, bool final) +Chunk RollupTransform::merge(Chunks && chunks, bool is_input, bool final) { BlocksList rollup_blocks; + auto header = is_input ? getInputPort().getHeader() : intermediate_header; for (auto & chunk : chunks) - rollup_blocks.emplace_back(getInputPort().getHeader().cloneWithColumns(chunk.detachColumns())); + rollup_blocks.emplace_back(header.cloneWithColumns(chunk.detachColumns())); - auto rollup_block = params->aggregator.mergeBlocks(rollup_blocks, final); + auto rollup_block = is_input ? params->aggregator.mergeBlocks(rollup_blocks, final) : output_aggregator->mergeBlocks(rollup_blocks, final); auto num_rows = rollup_block.rows(); return Chunk(rollup_block.getColumns(), num_rows); } @@ -42,9 +51,16 @@ Chunk RollupTransform::generate() if (!consumed_chunks.empty()) { if (consumed_chunks.size() > 1) - rollup_chunk = merge(std::move(consumed_chunks), false); + rollup_chunk = merge(std::move(consumed_chunks), true, false); else rollup_chunk = std::move(consumed_chunks.front()); + + size_t rows = rollup_chunk.getNumRows(); + auto columns = rollup_chunk.getColumns(); + for (auto key : keys) + columns[key] = makeNullable(columns[key]); + rollup_chunk = Chunk{ columns, rows }; + LOG_DEBUG(&Poco::Logger::get("RollupTransform"), "Chunk source: {}", rollup_chunk.dumpStructure()); consumed_chunks.clear(); last_removed_key = keys.size(); @@ -59,11 +75,12 @@ Chunk RollupTransform::generate() auto num_rows = gen_chunk.getNumRows(); auto columns = gen_chunk.getColumns(); - columns[key] = getColumnWithDefaults(getInputPort().getHeader(), key, num_rows); + columns[key] = getColumnWithDefaults(intermediate_header, key, num_rows); Chunks chunks; chunks.emplace_back(std::move(columns), num_rows); - rollup_chunk = merge(std::move(chunks), false); + rollup_chunk = merge(std::move(chunks), false, false); + LOG_DEBUG(&Poco::Logger::get("RollupTransform"), "Chunk generated: {}", rollup_chunk.dumpStructure()); } finalizeChunk(gen_chunk, aggregates_mask); diff --git a/src/Processors/Transforms/RollupTransform.h b/src/Processors/Transforms/RollupTransform.h index 8fd27e3e6a2..8b66c85e0b5 100644 --- a/src/Processors/Transforms/RollupTransform.h +++ b/src/Processors/Transforms/RollupTransform.h @@ -1,4 +1,5 @@ #pragma once +#include #include #include #include @@ -23,12 +24,16 @@ private: const ColumnNumbers keys; const ColumnsMask aggregates_mask; + std::unique_ptr output_aggregator; + + Block intermediate_header; + Chunks consumed_chunks; Chunk rollup_chunk; size_t last_removed_key = 0; size_t set_counter = 0; - Chunk merge(Chunks && chunks, bool final); + Chunk merge(Chunks && chunks, bool is_input, bool final); }; } diff --git a/tests/queries/0_stateless/02343_group_by_use_nulls.reference b/tests/queries/0_stateless/02343_group_by_use_nulls.reference new file mode 100644 index 00000000000..0d7fa8f3a3b --- /dev/null +++ b/tests/queries/0_stateless/02343_group_by_use_nulls.reference @@ -0,0 +1,21 @@ +0 0 0 +0 \N 0 +1 1 1 +1 \N 1 +2 0 2 +2 \N 2 +3 1 3 +3 \N 3 +4 0 4 +4 \N 4 +5 1 5 +5 \N 5 +6 0 6 +6 \N 6 +7 1 7 +7 \N 7 +8 0 8 +8 \N 8 +9 1 9 +9 \N 9 +\N \N 45 diff --git a/tests/queries/0_stateless/02343_group_by_use_nulls.sql b/tests/queries/0_stateless/02343_group_by_use_nulls.sql new file mode 100644 index 00000000000..1107ae79244 --- /dev/null +++ b/tests/queries/0_stateless/02343_group_by_use_nulls.sql @@ -0,0 +1,4 @@ +SELECT number, number % 2, sum(number) AS val +FROM numbers(10) +GROUP BY ROLLUP(number, number % 2) +ORDER BY (number, number % 2, val);