Support NULLs in ROLLUP

This commit is contained in:
Dmitry Novik 2022-06-27 18:42:26 +00:00
parent b5a977ad54
commit 1d15d72211
12 changed files with 152 additions and 23 deletions

View File

@ -129,6 +129,8 @@ static constexpr UInt64 operator""_GiB(unsigned long long value)
M(UInt64, aggregation_memory_efficient_merge_threads, 0, "Number of threads to use for merge intermediate aggregation results in memory efficient mode. When bigger, then more memory is consumed. 0 means - same as 'max_threads'.", 0) \ M(UInt64, aggregation_memory_efficient_merge_threads, 0, "Number of threads to use for merge intermediate aggregation results in memory efficient mode. When bigger, then more memory is consumed. 0 means - same as 'max_threads'.", 0) \
M(Bool, enable_positional_arguments, false, "Enable positional arguments in ORDER BY, GROUP BY and LIMIT BY", 0) \ M(Bool, enable_positional_arguments, false, "Enable positional arguments in ORDER BY, GROUP BY and LIMIT BY", 0) \
\ \
M(Bool, group_by_use_nulls, false, "Treat columns mentioned in ROLLUP, CUBE or GROUPING SETS as Nullable", 0) \
\
M(UInt64, max_parallel_replicas, 1, "The maximum number of replicas of each shard used when the query is executed. For consistency (to get different parts of the same partition), this option only works for the specified sampling key. The lag of the replicas is not controlled.", 0) \ M(UInt64, max_parallel_replicas, 1, "The maximum number of replicas of each shard used when the query is executed. For consistency (to get different parts of the same partition), this option only works for the specified sampling key. The lag of the replicas is not controlled.", 0) \
M(UInt64, parallel_replicas_count, 0, "", 0) \ M(UInt64, parallel_replicas_count, 0, "", 0) \
M(UInt64, parallel_replica_offset, 0, "", 0) \ M(UInt64, parallel_replica_offset, 0, "", 0) \

View File

@ -532,6 +532,12 @@ inline bool isBool(const DataTypePtr & data_type)
return data_type->getName() == "Bool"; return data_type->getName() == "Bool";
} }
inline bool isAggregateFunction(const DataTypePtr & data_type)
{
WhichDataType which(data_type);
return which.isAggregateFunction();
}
template <typename DataType> constexpr bool IsDataTypeDecimal = false; template <typename DataType> constexpr bool IsDataTypeDecimal = false;
template <typename DataType> constexpr bool IsDataTypeNumber = false; template <typename DataType> constexpr bool IsDataTypeNumber = false;
template <typename DataType> constexpr bool IsDataTypeDateOrDateTime = false; template <typename DataType> constexpr bool IsDataTypeDateOrDateTime = false;

View File

@ -41,8 +41,12 @@
#include <Dictionaries/DictionaryStructure.h> #include <Dictionaries/DictionaryStructure.h>
#include "Common/logger_useful.h"
#include <Common/typeid_cast.h> #include <Common/typeid_cast.h>
#include <Common/StringUtils/StringUtils.h> #include <Common/StringUtils/StringUtils.h>
#include "Columns/ColumnNullable.h"
#include "Core/ColumnsWithTypeAndName.h"
#include "DataTypes/IDataType.h"
#include <Core/ColumnNumbers.h> #include <Core/ColumnNumbers.h>
#include <Core/Names.h> #include <Core/Names.h>
#include <Core/NamesAndTypes.h> #include <Core/NamesAndTypes.h>
@ -64,6 +68,7 @@
#include <Processors/Executors/PullingAsyncPipelineExecutor.h> #include <Processors/Executors/PullingAsyncPipelineExecutor.h>
#include <Processors/QueryPlan/QueryPlan.h> #include <Processors/QueryPlan/QueryPlan.h>
#include <Parsers/formatAST.h> #include <Parsers/formatAST.h>
#include <Poco/Logger.h>
namespace DB namespace DB
{ {
@ -393,7 +398,7 @@ void ExpressionAnalyzer::analyzeAggregation(ActionsDAGPtr & temp_actions)
} }
} }
NameAndTypePair key{column_name, node->result_type}; NameAndTypePair key{column_name, makeNullable(node->result_type)};
grouping_set_list.push_back(key); grouping_set_list.push_back(key);
@ -447,7 +452,7 @@ void ExpressionAnalyzer::analyzeAggregation(ActionsDAGPtr & temp_actions)
} }
} }
NameAndTypePair key{column_name, node->result_type}; NameAndTypePair key = select_query->group_by_with_rollup || select_query->group_by_with_cube ? NameAndTypePair{ column_name, makeNullable(node->result_type) } : NameAndTypePair{column_name, node->result_type};
/// Aggregation keys are uniqued. /// Aggregation keys are uniqued.
if (!unique_keys.contains(key.name)) if (!unique_keys.contains(key.name))
@ -1418,6 +1423,28 @@ void SelectQueryExpressionAnalyzer::appendExpressionsAfterWindowFunctions(Expres
} }
} }
void SelectQueryExpressionAnalyzer::appendGroupByModifiers(ActionsDAGPtr & before_aggregation, ExpressionActionsChain & chain, bool /* only_types */)
{
const auto * select_query = getAggregatingQuery();
if (!select_query->groupBy() || !(select_query->group_by_with_rollup || select_query->group_by_with_cube))
return;
auto source_columns = before_aggregation->getResultColumns();
ColumnsWithTypeAndName result_columns;
for (const auto & source_column : source_columns)
{
if (isAggregateFunction(source_column.type))
result_columns.push_back(source_column);
else
result_columns.emplace_back(makeNullable(source_column.type), source_column.name);
}
ExpressionActionsChain::Step & step = chain.lastStep(before_aggregation->getNamesAndTypesList());
step.actions() = ActionsDAG::makeConvertingActions(source_columns, result_columns, ActionsDAG::MatchColumnsMode::Position);
}
bool SelectQueryExpressionAnalyzer::appendHaving(ExpressionActionsChain & chain, bool only_types) bool SelectQueryExpressionAnalyzer::appendHaving(ExpressionActionsChain & chain, bool only_types)
{ {
const auto * select_query = getAggregatingQuery(); const auto * select_query = getAggregatingQuery();
@ -1597,6 +1624,8 @@ ActionsDAGPtr SelectQueryExpressionAnalyzer::appendProjectResult(ExpressionActio
ExpressionActionsChain::Step & step = chain.lastStep(aggregated_columns); ExpressionActionsChain::Step & step = chain.lastStep(aggregated_columns);
LOG_DEBUG(&Poco::Logger::get("SelectQueryExpressionAnalyzer"), "Before output: {}", step.actions()->getNamesAndTypesList().toString());
NamesWithAliases result_columns; NamesWithAliases result_columns;
ASTs asts = select_query->select()->children; ASTs asts = select_query->select()->children;
@ -1638,7 +1667,11 @@ ActionsDAGPtr SelectQueryExpressionAnalyzer::appendProjectResult(ExpressionActio
} }
auto actions = chain.getLastActions(); auto actions = chain.getLastActions();
LOG_DEBUG(&Poco::Logger::get("SelectQueryExpressionAnalyzer"), "Before projection: {}", actions->getNamesAndTypesList().toString());
actions->project(result_columns); actions->project(result_columns);
LOG_DEBUG(&Poco::Logger::get("SelectQueryExpressionAnalyzer"), "After projection: {}", actions->getNamesAndTypesList().toString());
return actions; return actions;
} }
@ -1862,6 +1895,9 @@ ExpressionAnalysisResult::ExpressionAnalysisResult(
query_analyzer.appendAggregateFunctionsArguments(chain, only_types || !first_stage); query_analyzer.appendAggregateFunctionsArguments(chain, only_types || !first_stage);
before_aggregation = chain.getLastActions(); before_aggregation = chain.getLastActions();
before_aggregation_with_nullable = chain.getLastActions();
query_analyzer.appendGroupByModifiers(before_aggregation, chain, only_types);
finalize_chain(chain); finalize_chain(chain);
if (query_analyzer.appendHaving(chain, only_types || !second_stage)) if (query_analyzer.appendHaving(chain, only_types || !second_stage))

View File

@ -245,6 +245,7 @@ struct ExpressionAnalysisResult
JoinPtr join; JoinPtr join;
ActionsDAGPtr before_where; ActionsDAGPtr before_where;
ActionsDAGPtr before_aggregation; ActionsDAGPtr before_aggregation;
ActionsDAGPtr before_aggregation_with_nullable;
ActionsDAGPtr before_having; ActionsDAGPtr before_having;
String having_column_name; String having_column_name;
bool remove_having_filter = false; bool remove_having_filter = false;
@ -410,6 +411,8 @@ private:
void appendExpressionsAfterWindowFunctions(ExpressionActionsChain & chain, bool only_types); void appendExpressionsAfterWindowFunctions(ExpressionActionsChain & chain, bool only_types);
void appendGroupByModifiers(ActionsDAGPtr & before_aggregation, ExpressionActionsChain & chain, bool only_types);
/// After aggregation: /// After aggregation:
bool appendHaving(ExpressionActionsChain & chain, bool only_types); bool appendHaving(ExpressionActionsChain & chain, bool only_types);
/// appendSelect /// appendSelect

View File

@ -582,6 +582,7 @@ InterpreterSelectQuery::InterpreterSelectQuery(
/// Calculate structure of the result. /// Calculate structure of the result.
result_header = getSampleBlockImpl(); result_header = getSampleBlockImpl();
LOG_DEBUG(&Poco::Logger::get("InterpreterSelectQuery"), "Result header: {}", result_header.dumpStructure());
}; };
analyze(shouldMoveToPrewhere()); analyze(shouldMoveToPrewhere());

View File

@ -11,10 +11,13 @@
#include <Processors/Merges/AggregatingSortedTransform.h> #include <Processors/Merges/AggregatingSortedTransform.h>
#include <Processors/Merges/FinishAggregatingInOrderTransform.h> #include <Processors/Merges/FinishAggregatingInOrderTransform.h>
#include <Interpreters/Aggregator.h> #include <Interpreters/Aggregator.h>
#include <Functions/FunctionFactory.h>
#include <Processors/QueryPlan/IQueryPlanStep.h> #include <Processors/QueryPlan/IQueryPlanStep.h>
#include <Columns/ColumnFixedString.h> #include <Columns/ColumnFixedString.h>
#include <DataTypes/DataTypesNumber.h> #include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypeFixedString.h> #include <DataTypes/DataTypeFixedString.h>
#include "Core/ColumnNumbers.h"
#include "DataTypes/IDataType.h"
namespace DB namespace DB
{ {
@ -46,22 +49,44 @@ Block appendGroupingSetColumn(Block header)
return res; return res;
} }
Block generateOutputHeader(const Block & input_header)
{
auto header = appendGroupingSetColumn(input_header);
for (size_t i = 1; i < header.columns(); ++i)
{
auto & column = header.getByPosition(i);
if (!isAggregateFunction(column.type))
{
column.type = makeNullable(column.type);
column.column = makeNullable(column.column);
}
}
return header;
}
Block generateOutputHeader(const Block & input_header, const ColumnNumbers & keys)
{
auto header = appendGroupingSetColumn(input_header);
for (auto key : keys)
{
auto & column = header.getByPosition(key + 1);
if (!isAggregateFunction(column.type))
{
column.type = makeNullable(column.type);
column.column = makeNullable(column.column);
}
}
return header;
}
static Block appendGroupingColumn(Block block, const GroupingSetsParamsList & params) static Block appendGroupingColumn(Block block, const GroupingSetsParamsList & params)
{ {
if (params.empty()) if (params.empty())
return block; return block;
Block res; return generateOutputHeader(block);
size_t rows = block.rows();
auto column = ColumnUInt64::create(rows);
res.insert({ColumnPtr(std::move(column)), std::make_shared<DataTypeUInt64>(), "__grouping_set"});
for (auto & col : block)
res.insert(std::move(col));
return res;
} }
AggregatingStep::AggregatingStep( AggregatingStep::AggregatingStep(
@ -249,7 +274,13 @@ void AggregatingStep::transformPipeline(QueryPipelineBuilder & pipeline, const B
index.push_back(node); index.push_back(node);
} }
else else
index.push_back(dag->getIndex()[header.getPositionByName(col.name)]); {
const auto * column_node = dag->getIndex()[header.getPositionByName(col.name)];
// index.push_back(dag->getIndex()[header.getPositionByName(col.name)]);
const auto * node = &dag->addFunction(FunctionFactory::instance().get("toNullable", nullptr), { column_node }, col.name);
index.push_back(node);
}
} }
dag->getIndex().swap(index); dag->getIndex().swap(index);

View File

@ -3,6 +3,7 @@
#include <QueryPipeline/SizeLimits.h> #include <QueryPipeline/SizeLimits.h>
#include <Storages/SelectQueryInfo.h> #include <Storages/SelectQueryInfo.h>
#include <Interpreters/Aggregator.h> #include <Interpreters/Aggregator.h>
#include "Core/ColumnNumbers.h"
namespace DB namespace DB
{ {
@ -26,6 +27,8 @@ struct GroupingSetsParams
using GroupingSetsParamsList = std::vector<GroupingSetsParams>; using GroupingSetsParamsList = std::vector<GroupingSetsParams>;
Block appendGroupingSetColumn(Block header); Block appendGroupingSetColumn(Block header);
Block generateOutputHeader(const Block & input_header);
Block generateOutputHeader(const Block & input_header, const ColumnNumbers & keys);
/// Aggregation. See AggregatingTransform. /// Aggregation. See AggregatingTransform.
class AggregatingStep : public ITransformingStep class AggregatingStep : public ITransformingStep

View File

@ -23,7 +23,7 @@ static ITransformingStep::Traits getTraits()
} }
RollupStep::RollupStep(const DataStream & input_stream_, AggregatingTransformParamsPtr params_) RollupStep::RollupStep(const DataStream & input_stream_, AggregatingTransformParamsPtr params_)
: ITransformingStep(input_stream_, appendGroupingSetColumn(params_->getHeader()), getTraits()) : ITransformingStep(input_stream_, generateOutputHeader(params_->getHeader(), params_->params.keys), getTraits())
, params(std::move(params_)) , params(std::move(params_))
, keys_size(params->params.keys_size) , keys_size(params->params.keys_size)
{ {

View File

@ -1,16 +1,24 @@
#include <Processors/Transforms/RollupTransform.h> #include <Processors/Transforms/RollupTransform.h>
#include <Processors/Transforms/TotalsHavingTransform.h> #include <Processors/Transforms/TotalsHavingTransform.h>
#include <Processors/QueryPlan/AggregatingStep.h> #include <Processors/QueryPlan/AggregatingStep.h>
#include <Poco/Logger.h>
#include "Common/logger_useful.h"
#include "Columns/ColumnNullable.h"
namespace DB namespace DB
{ {
RollupTransform::RollupTransform(Block header, AggregatingTransformParamsPtr params_) RollupTransform::RollupTransform(Block header, AggregatingTransformParamsPtr params_)
: IAccumulatingTransform(std::move(header), appendGroupingSetColumn(params_->getHeader())) : IAccumulatingTransform(std::move(header), generateOutputHeader(params_->getHeader(), params_->params.keys))
, params(std::move(params_)) , params(std::move(params_))
, keys(params->params.keys) , keys(params->params.keys)
, aggregates_mask(getAggregatesMask(params->getHeader(), params->params.aggregates)) , aggregates_mask(getAggregatesMask(params->getHeader(), params->params.aggregates))
{ {
auto output_aggregator_params = params->params;
intermediate_header = getOutputPort().getHeader();
intermediate_header.erase(0);
output_aggregator_params.src_header = intermediate_header;
output_aggregator = std::make_unique<Aggregator>(output_aggregator_params);
} }
void RollupTransform::consume(Chunk chunk) void RollupTransform::consume(Chunk chunk)
@ -18,13 +26,14 @@ void RollupTransform::consume(Chunk chunk)
consumed_chunks.emplace_back(std::move(chunk)); consumed_chunks.emplace_back(std::move(chunk));
} }
Chunk RollupTransform::merge(Chunks && chunks, bool final) Chunk RollupTransform::merge(Chunks && chunks, bool is_input, bool final)
{ {
BlocksList rollup_blocks; BlocksList rollup_blocks;
auto header = is_input ? getInputPort().getHeader() : intermediate_header;
for (auto & chunk : chunks) for (auto & chunk : chunks)
rollup_blocks.emplace_back(getInputPort().getHeader().cloneWithColumns(chunk.detachColumns())); rollup_blocks.emplace_back(header.cloneWithColumns(chunk.detachColumns()));
auto rollup_block = params->aggregator.mergeBlocks(rollup_blocks, final); auto rollup_block = is_input ? params->aggregator.mergeBlocks(rollup_blocks, final) : output_aggregator->mergeBlocks(rollup_blocks, final);
auto num_rows = rollup_block.rows(); auto num_rows = rollup_block.rows();
return Chunk(rollup_block.getColumns(), num_rows); return Chunk(rollup_block.getColumns(), num_rows);
} }
@ -42,10 +51,17 @@ Chunk RollupTransform::generate()
if (!consumed_chunks.empty()) if (!consumed_chunks.empty())
{ {
if (consumed_chunks.size() > 1) if (consumed_chunks.size() > 1)
rollup_chunk = merge(std::move(consumed_chunks), false); rollup_chunk = merge(std::move(consumed_chunks), true, false);
else else
rollup_chunk = std::move(consumed_chunks.front()); rollup_chunk = std::move(consumed_chunks.front());
size_t rows = rollup_chunk.getNumRows();
auto columns = rollup_chunk.getColumns();
for (auto key : keys)
columns[key] = makeNullable(columns[key]);
rollup_chunk = Chunk{ columns, rows };
LOG_DEBUG(&Poco::Logger::get("RollupTransform"), "Chunk source: {}", rollup_chunk.dumpStructure());
consumed_chunks.clear(); consumed_chunks.clear();
last_removed_key = keys.size(); last_removed_key = keys.size();
} }
@ -59,11 +75,12 @@ Chunk RollupTransform::generate()
auto num_rows = gen_chunk.getNumRows(); auto num_rows = gen_chunk.getNumRows();
auto columns = gen_chunk.getColumns(); auto columns = gen_chunk.getColumns();
columns[key] = getColumnWithDefaults(getInputPort().getHeader(), key, num_rows); columns[key] = getColumnWithDefaults(intermediate_header, key, num_rows);
Chunks chunks; Chunks chunks;
chunks.emplace_back(std::move(columns), num_rows); chunks.emplace_back(std::move(columns), num_rows);
rollup_chunk = merge(std::move(chunks), false); rollup_chunk = merge(std::move(chunks), false, false);
LOG_DEBUG(&Poco::Logger::get("RollupTransform"), "Chunk generated: {}", rollup_chunk.dumpStructure());
} }
finalizeChunk(gen_chunk, aggregates_mask); finalizeChunk(gen_chunk, aggregates_mask);

View File

@ -1,4 +1,5 @@
#pragma once #pragma once
#include <memory>
#include <Processors/IAccumulatingTransform.h> #include <Processors/IAccumulatingTransform.h>
#include <Processors/Transforms/AggregatingTransform.h> #include <Processors/Transforms/AggregatingTransform.h>
#include <Processors/Transforms/finalizeChunk.h> #include <Processors/Transforms/finalizeChunk.h>
@ -23,12 +24,16 @@ private:
const ColumnNumbers keys; const ColumnNumbers keys;
const ColumnsMask aggregates_mask; const ColumnsMask aggregates_mask;
std::unique_ptr<Aggregator> output_aggregator;
Block intermediate_header;
Chunks consumed_chunks; Chunks consumed_chunks;
Chunk rollup_chunk; Chunk rollup_chunk;
size_t last_removed_key = 0; size_t last_removed_key = 0;
size_t set_counter = 0; size_t set_counter = 0;
Chunk merge(Chunks && chunks, bool final); Chunk merge(Chunks && chunks, bool is_input, bool final);
}; };
} }

View File

@ -0,0 +1,21 @@
0 0 0
0 \N 0
1 1 1
1 \N 1
2 0 2
2 \N 2
3 1 3
3 \N 3
4 0 4
4 \N 4
5 1 5
5 \N 5
6 0 6
6 \N 6
7 1 7
7 \N 7
8 0 8
8 \N 8
9 1 9
9 \N 9
\N \N 45

View File

@ -0,0 +1,4 @@
SELECT number, number % 2, sum(number) AS val
FROM numbers(10)
GROUP BY ROLLUP(number, number % 2)
ORDER BY (number, number % 2, val);