Merge pull request #47370 from ClickHouse/fix-grouping-for-grouping-sets

Fix GROUPING function initialization for grouping sets
This commit is contained in:
Maksim Kita 2023-03-15 16:06:49 +03:00 committed by GitHub
commit 4337a3161a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 115 additions and 9 deletions

View File

@ -32,17 +32,17 @@ enum class GroupByKind
GROUPING_SETS GROUPING_SETS
}; };
class GroupingFunctionResolveVisitor : public InDepthQueryTreeVisitor<GroupingFunctionResolveVisitor> class GroupingFunctionResolveVisitor : public InDepthQueryTreeVisitorWithContext<GroupingFunctionResolveVisitor>
{ {
public: public:
GroupingFunctionResolveVisitor(GroupByKind group_by_kind_, GroupingFunctionResolveVisitor(GroupByKind group_by_kind_,
QueryTreeNodePtrWithHashMap<size_t> aggregation_key_to_index_, QueryTreeNodePtrWithHashMap<size_t> aggregation_key_to_index_,
ColumnNumbersList grouping_sets_keys_indices_, ColumnNumbersList grouping_sets_keys_indices_,
ContextPtr context_) ContextPtr context_)
: group_by_kind(group_by_kind_) : InDepthQueryTreeVisitorWithContext(std::move(context_))
, group_by_kind(group_by_kind_)
, aggregation_key_to_index(std::move(aggregation_key_to_index_)) , aggregation_key_to_index(std::move(aggregation_key_to_index_))
, grouping_sets_keys_indexes(std::move(grouping_sets_keys_indices_)) , grouping_sets_keys_indexes(std::move(grouping_sets_keys_indices_))
, context(std::move(context_))
{ {
} }
@ -71,7 +71,7 @@ public:
FunctionOverloadResolverPtr grouping_function_resolver; FunctionOverloadResolverPtr grouping_function_resolver;
bool add_grouping_set_column = false; bool add_grouping_set_column = false;
bool force_grouping_standard_compatibility = context->getSettingsRef().force_grouping_standard_compatibility; bool force_grouping_standard_compatibility = getSettings().force_grouping_standard_compatibility;
size_t aggregation_keys_size = aggregation_key_to_index.size(); size_t aggregation_keys_size = aggregation_key_to_index.size();
switch (group_by_kind) switch (group_by_kind)
@ -132,7 +132,6 @@ private:
GroupByKind group_by_kind; GroupByKind group_by_kind;
QueryTreeNodePtrWithHashMap<size_t> aggregation_key_to_index; QueryTreeNodePtrWithHashMap<size_t> aggregation_key_to_index;
ColumnNumbersList grouping_sets_keys_indexes; ColumnNumbersList grouping_sets_keys_indexes;
ContextPtr context;
}; };
void resolveGroupingFunctions(QueryTreeNodePtr & query_node, ContextPtr context) void resolveGroupingFunctions(QueryTreeNodePtr & query_node, ContextPtr context)
@ -164,12 +163,17 @@ void resolveGroupingFunctions(QueryTreeNodePtr & query_node, ContextPtr context)
grouping_sets_used_aggregation_keys_list.emplace_back(); grouping_sets_used_aggregation_keys_list.emplace_back();
auto & grouping_sets_used_aggregation_keys = grouping_sets_used_aggregation_keys_list.back(); auto & grouping_sets_used_aggregation_keys = grouping_sets_used_aggregation_keys_list.back();
QueryTreeNodePtrWithHashSet used_keys_in_set;
for (auto & grouping_set_key_node : grouping_set_keys_list_node_typed.getNodes()) for (auto & grouping_set_key_node : grouping_set_keys_list_node_typed.getNodes())
{ {
if (used_keys_in_set.contains(grouping_set_key_node))
continue;
used_keys_in_set.insert(grouping_set_key_node);
grouping_sets_used_aggregation_keys.push_back(grouping_set_key_node);
if (aggregation_key_to_index.contains(grouping_set_key_node)) if (aggregation_key_to_index.contains(grouping_set_key_node))
continue; continue;
grouping_sets_used_aggregation_keys.push_back(grouping_set_key_node);
aggregation_key_to_index.emplace(grouping_set_key_node, aggregation_node_index); aggregation_key_to_index.emplace(grouping_set_key_node, aggregation_node_index);
++aggregation_node_index; ++aggregation_node_index;
} }

View File

@ -56,7 +56,7 @@ public:
} }
if (!found_argument_in_group_by_keys) if (!found_argument_in_group_by_keys)
throw Exception(ErrorCodes::NOT_AN_AGGREGATE, throw Exception(ErrorCodes::BAD_ARGUMENTS,
"GROUPING function argument {} is not in GROUP BY keys. In query {}", "GROUPING function argument {} is not in GROUP BY keys. In query {}",
grouping_function_arguments_node->formatASTForErrorMessage(), grouping_function_arguments_node->formatASTForErrorMessage(),
query_node->formatASTForErrorMessage()); query_node->formatASTForErrorMessage());

View File

@ -395,7 +395,11 @@ void addMergingAggregatedStep(QueryPlan & query_plan,
* but it can work more slowly. * but it can work more slowly.
*/ */
Aggregator::Params params(aggregation_analysis_result.aggregation_keys, auto keys = aggregation_analysis_result.aggregation_keys;
if (!aggregation_analysis_result.grouping_sets_parameters_list.empty())
keys.insert(keys.begin(), "__grouping_set");
Aggregator::Params params(keys,
aggregation_analysis_result.aggregate_descriptions, aggregation_analysis_result.aggregate_descriptions,
query_analysis_result.aggregate_overflow_row, query_analysis_result.aggregate_overflow_row,
settings.max_threads, settings.max_threads,

View File

@ -18,6 +18,7 @@
#include <Storages/getStructureOfRemoteTable.h> #include <Storages/getStructureOfRemoteTable.h>
#include <Storages/checkAndGetLiteralArgument.h> #include <Storages/checkAndGetLiteralArgument.h>
#include <Storages/StorageDummy.h> #include <Storages/StorageDummy.h>
#include <Storages/removeGroupingFunctionSpecializations.h>
#include <Columns/ColumnConst.h> #include <Columns/ColumnConst.h>
@ -1020,6 +1021,8 @@ QueryTreeNodePtr buildQueryTreeDistributed(SelectQueryInfo & query_info,
if (!replacement_map.empty()) if (!replacement_map.empty())
query_tree_to_modify = query_tree_to_modify->cloneAndReplace(replacement_map); query_tree_to_modify = query_tree_to_modify->cloneAndReplace(replacement_map);
removeGroupingFunctionSpecializations(query_tree_to_modify);
return query_tree_to_modify; return query_tree_to_modify;
} }

View File

@ -0,0 +1,65 @@
#include <Storages/removeGroupingFunctionSpecializations.h>
#include <Analyzer/InDepthQueryTreeVisitor.h>
#include <Analyzer/ColumnNode.h>
#include <Analyzer/FunctionNode.h>
#include <Common/Exception.h>
#include <Functions/grouping.h>
namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
}
class GeneralizeGroupingFunctionForDistributedVisitor : public InDepthQueryTreeVisitor<GeneralizeGroupingFunctionForDistributedVisitor>
{
public:
static void visitImpl(QueryTreeNodePtr & node)
{
auto * function = node->as<FunctionNode>();
if (!function)
return;
const auto & function_name = function->getFunctionName();
bool ordinary_grouping = function_name == "groupingOrdinary";
if (!ordinary_grouping
&& function_name != "groupingForRollup"
&& function_name != "groupingForCube"
&& function_name != "groupingForGroupingSets")
return;
if (!ordinary_grouping)
{
auto & arguments = function->getArguments().getNodes();
if (arguments.empty())
throw Exception(ErrorCodes::LOGICAL_ERROR, "Grouping function specialization must have arguments");
auto * grouping_set_arg = arguments[0]->as<ColumnNode>();
if (!grouping_set_arg || grouping_set_arg->getColumnName() != "__grouping_set")
throw Exception(ErrorCodes::LOGICAL_ERROR,
"The first argument of Grouping function specialization must be '__grouping_set' column but {} found",
arguments[0]->dumpTree());
arguments.erase(arguments.begin());
}
// This node will be only converted to AST, so we don't need
// to pass the correct force_compatibility flag to FunctionGrouping.
auto function_adaptor = std::make_shared<FunctionToOverloadResolverAdaptor>(
std::make_shared<FunctionGrouping>(false)
);
function->resolveAsFunction(function_adaptor);
}
};
void removeGroupingFunctionSpecializations(QueryTreeNodePtr & node)
{
GeneralizeGroupingFunctionForDistributedVisitor visitor;
visitor.visit(node);
}
}

View File

@ -0,0 +1,10 @@
#pragma once
#include <Analyzer/IQueryTreeNode.h>
namespace DB
{
void removeGroupingFunctionSpecializations(QueryTreeNodePtr & node);
}

View File

@ -1,3 +1,5 @@
set optimize_group_by_function_keys=0;
SELECT SELECT
number, number,
grouping(number, number % 2, number % 3) AS gr grouping(number, number % 2, number % 3) AS gr

View File

@ -1,3 +1,5 @@
set optimize_group_by_function_keys=0;
SELECT SELECT
number, number,
grouping(number, number % 2, number % 3) = 6 grouping(number, number % 2, number % 3) = 6

View File

@ -27,3 +27,17 @@ SELECT count() AS amount, a, b, GROUPING(a, b) FROM test02315 GROUP BY ROLLUP(a,
5 0 0 2 5 0 0 2
5 1 0 2 5 1 0 2
10 0 0 0 10 0 0 0
SELECT count() AS amount, a, b, GROUPING(a, b) FROM test02315 GROUP BY GROUPING SETS ((a, b), (a, a), ()) ORDER BY (amount, a, b) SETTINGS force_grouping_standard_compatibility=0, allow_experimental_analyzer=1;
1 0 0 3
1 0 2 3
1 0 4 3
1 0 6 3
1 0 8 3
1 1 1 3
1 1 3 3
1 1 5 3
1 1 7 3
1 1 9 3
5 0 0 2
5 1 0 2
10 0 0 0

View File

@ -9,5 +9,7 @@ SELECT count() AS amount, a, b, GROUPING(a, b) FROM test02315 GROUP BY GROUPING
SELECT count() AS amount, a, b, GROUPING(a, b) FROM test02315 GROUP BY ROLLUP(a, b) ORDER BY (amount, a, b) SETTINGS force_grouping_standard_compatibility=0; SELECT count() AS amount, a, b, GROUPING(a, b) FROM test02315 GROUP BY ROLLUP(a, b) ORDER BY (amount, a, b) SETTINGS force_grouping_standard_compatibility=0;
SELECT count() AS amount, a, b, GROUPING(a, b) FROM test02315 GROUP BY GROUPING SETS ((a, b), (a, a), ()) ORDER BY (amount, a, b) SETTINGS force_grouping_standard_compatibility=0, allow_experimental_analyzer=1;
-- { echoOff } -- { echoOff }
DROP TABLE test02315; DROP TABLE test02315;