Merge pull request #58919 from ClickHouse/analyzer-inj-func-elemination

Analyzer: Support GROUP BY injective function elimination
This commit is contained in:
Dmitry Novik 2024-01-26 16:45:09 +01:00 committed by GitHub
commit 6c5057c4f2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 361 additions and 9 deletions

View File

@ -0,0 +1,124 @@
#include <Analyzer/Passes/OptimizeGroupByInjectiveFunctionsPass.h>
#include <Analyzer/ConstantNode.h>
#include <Analyzer/FunctionNode.h>
#include <Analyzer/InDepthQueryTreeVisitor.h>
#include <Analyzer/IQueryTreeNode.h>
#include <DataTypes/IDataType.h>
#include <Interpreters/ExternalDictionariesLoader.h>
namespace DB
{
namespace
{
/// Names of the `dictGet*` function family, grouped below by the attribute type they return.
/// NOTE(review): judging by the name these are functions whose injectivity depends on their
/// arguments; the visitor in this file does not reference this set and asks the resolved
/// function itself via isInjective() - confirm whether the set is still needed.
const std::unordered_set<String> possibly_injective_function_names
{
    /// Generic and string accessors.
    "dictGet", "dictGetString",
    /// Unsigned integer accessors.
    "dictGetUInt8", "dictGetUInt16", "dictGetUInt32", "dictGetUInt64",
    /// Signed integer accessors.
    "dictGetInt8", "dictGetInt16", "dictGetInt32", "dictGetInt64",
    /// Floating point accessors.
    "dictGetFloat32", "dictGetFloat64",
    /// Date and time accessors.
    "dictGetDate", "dictGetDateTime",
};
/** Query tree visitor that replaces GROUP BY keys of the form `f(args...)`, where the resolved
  * function `f` reports itself as injective for the given arguments, with the non-constant
  * arguments of `f`. The replacement is applied repeatedly, so nested injective calls are
  * stripped layer by layer.
  *
  * Example: `GROUP BY toString(toString(number + 1))` becomes `GROUP BY number + 1`.
  *
  * The rewrite is skipped when:
  *  - setting `optimize_injective_functions_in_group_by` is disabled;
  *  - setting `group_by_use_nulls` is enabled;
  *  - the query uses WITH CUBE or WITH ROLLUP.
  * For GROUPING SETS every grouping set is rewritten independently.
  */
class OptimizeGroupByInjectiveFunctionsVisitor : public InDepthQueryTreeVisitorWithContext<OptimizeGroupByInjectiveFunctionsVisitor>
{
    using Base = InDepthQueryTreeVisitorWithContext<OptimizeGroupByInjectiveFunctionsVisitor>;

public:
    explicit OptimizeGroupByInjectiveFunctionsVisitor(ContextPtr context)
        : Base(std::move(context))
    {}

    /// Called for every node of the tree; only query nodes that have a GROUP BY are processed.
    void enterImpl(QueryTreeNodePtr & node)
    {
        if (!getSettings().optimize_injective_functions_in_group_by)
            return;

        /// With group_by_use_nulls the aggregation keys are converted to Nullable, and a function
        /// re-evaluated over such a key after aggregation is not guaranteed to reproduce the value
        /// of the original key (not every function maps a NULL argument to a NULL result).
        /// CUBE and ROLLUP are skipped below anyway, but GROUPING SETS would be affected.
        if (getSettings().group_by_use_nulls)
            return;

        auto * query = node->as<QueryNode>();
        if (!query)
            return;

        if (!query->hasGroupBy())
            return;

        if (query->isGroupByWithCube() || query->isGroupByWithRollup())
            return;

        auto & group_by = query->getGroupBy().getNodes();
        if (query->isGroupByWithGroupingSets())
        {
            for (auto & set : group_by)
            {
                /// Reference cast: throws on an unexpected node type instead of dereferencing
                /// a null pointer.
                auto & grouping_set = set->as<ListNode &>().getNodes();
                optimizeGroupingSet(grouping_set);
            }
        }
        else
            optimizeGroupingSet(group_by);
    }

private:
    /// Rewrites one list of aggregation keys in place.
    /// Every key is expanded breadth-first: an injective function is replaced by its non-constant
    /// arguments (which are examined in turn), any other node is kept as a key unchanged.
    void optimizeGroupingSet(QueryTreeNodes & grouping_set)
    {
        QueryTreeNodes new_group_by_keys;
        new_group_by_keys.reserve(grouping_set.size());

        for (auto & group_by_elem : grouping_set)
        {
            std::queue<QueryTreeNodePtr> nodes_to_process;
            nodes_to_process.push(group_by_elem);

            while (!nodes_to_process.empty())
            {
                auto node_to_process = std::move(nodes_to_process.front());
                nodes_to_process.pop();

                auto const * function_node = node_to_process->as<FunctionNode>();
                if (!function_node)
                {
                    // Constant aggregation keys are removed in PlannerExpressionAnalysis.cpp
                    new_group_by_keys.push_back(std::move(node_to_process));
                    continue;
                }

                // Aggregate functions are not allowed in GROUP BY clause
                auto function = function_node->getFunctionOrThrow();
                bool can_be_eliminated = function->isInjective(function_node->getArgumentColumns());

                if (can_be_eliminated)
                {
                    for (auto const & argument : function_node->getArguments())
                    {
                        // We can skip constants here because aggregation key is already not a constant.
                        if (argument->getNodeType() != QueryTreeNodeType::CONSTANT)
                            nodes_to_process.push(argument);
                    }
                }
                else
                    new_group_by_keys.push_back(std::move(node_to_process));
            }
        }

        grouping_set = std::move(new_group_by_keys);
    }
};
}
/// Entry point of the pass: walks the whole query tree with a visitor bound to the given context.
void OptimizeGroupByInjectiveFunctionsPass::run(QueryTreeNodePtr query_tree_node, ContextPtr context)
{
    OptimizeGroupByInjectiveFunctionsVisitor group_by_visitor{std::move(context)};
    group_by_visitor.visit(query_tree_node);
}
}

View File

@ -0,0 +1,20 @@
#pragma once
#include <Analyzer/IQueryTreePass.h>
namespace DB
{
/* Eliminates injective functions in GROUP BY section.
 *
 * A GROUP BY key that is a call of an injective function is replaced by the
 * non-constant arguments of that function; nested calls are unwrapped as well.
 *
 * Example: GROUP BY toString(toString(number + 1)) is rewritten to GROUP BY number + 1.
 *
 * Controlled by setting optimize_injective_functions_in_group_by.
 * Queries with WITH CUBE or WITH ROLLUP are left untouched.
*/
class OptimizeGroupByInjectiveFunctionsPass final : public IQueryTreePass
{
public:
/// Name of the pass.
String getName() override { return "OptimizeGroupByInjectiveFunctionsPass"; }
/// One-line description of what the pass does.
String getDescription() override { return "Replaces injective functions by it's arguments in GROUP BY section."; }
/// Applies the rewrite to the query tree rooted at query_tree_node (defined in the .cpp file).
void run(QueryTreeNodePtr query_tree_node, ContextPtr context) override;
};
}

View File

@ -2321,11 +2321,15 @@ std::pair<bool, UInt64> QueryAnalyzer::recursivelyCollectMaxOrdinaryExpressions(
*/
void QueryAnalyzer::expandGroupByAll(QueryNode & query_tree_node_typed)
{
    /// Only queries written as GROUP BY ALL need expansion; everything else is left untouched,
    /// so callers may invoke this unconditionally.
    if (query_tree_node_typed.isGroupByAll())
    {
        auto & group_by_keys = query_tree_node_typed.getGroupBy().getNodes();

        /// Let the collector inspect every projection expression, handing it the GROUP BY key list.
        for (auto & projection_node : query_tree_node_typed.getProjection().getNodes())
            recursivelyCollectMaxOrdinaryExpressions(projection_node, group_by_keys);

        /// The query is no longer treated as GROUP BY ALL after the expansion.
        query_tree_node_typed.setIsGroupByAll(false);
    }
}
void QueryAnalyzer::expandOrderByAll(QueryNode & query_tree_node_typed)
@ -7422,8 +7426,7 @@ void QueryAnalyzer::resolveQuery(const QueryTreeNodePtr & query_node, Identifier
node->removeAlias();
}
if (query_node_typed.isGroupByAll())
expandGroupByAll(query_node_typed);
expandGroupByAll(query_node_typed);
validateFilters(query_node);
validateAggregates(query_node, { .group_by_use_nulls = scope.group_by_use_nulls });

View File

@ -3,6 +3,7 @@
#include <memory>
#include <Common/Exception.h>
#include "Analyzer/Passes/OptimizeGroupByInjectiveFunctionsPass.h"
#include <IO/WriteHelpers.h>
#include <IO/Operators.h>
@ -163,8 +164,6 @@ private:
/** ClickHouse query tree pass manager.
*
* TODO: Support setting optimize_substitute_columns.
* TODO: Support GROUP BY injective function elimination.
* TODO: Support setting optimize_aggregators_of_group_by_keys.
* TODO: Support setting optimize_monotonous_functions_in_order_by.
* TODO: Add optimizations based on function semantics. Example: SELECT * FROM test_table WHERE id != id. (id is not nullable column).
@ -268,6 +267,7 @@ void addQueryTreePasses(QueryTreePassManager & manager)
manager.addPass(std::make_unique<AggregateFunctionsArithmericOperationsPass>());
manager.addPass(std::make_unique<UniqInjectiveFunctionsEliminationPass>());
manager.addPass(std::make_unique<OptimizeGroupByFunctionKeysPass>());
manager.addPass(std::make_unique<OptimizeGroupByInjectiveFunctionsPass>());
manager.addPass(std::make_unique<MultiIfToIfPass>());
manager.addPass(std::make_unique<IfConstantConditionPass>());

View File

@ -699,6 +699,7 @@ class IColumn;
M(SetOperationMode, intersect_default_mode, SetOperationMode::ALL, "Set default mode in INTERSECT query. Possible values: empty string, 'ALL', 'DISTINCT'. If empty, query without mode will throw exception.", 0) \
M(SetOperationMode, except_default_mode, SetOperationMode::ALL, "Set default mode in EXCEPT query. Possible values: empty string, 'ALL', 'DISTINCT'. If empty, query without mode will throw exception.", 0) \
M(Bool, optimize_aggregators_of_group_by_keys, true, "Eliminates min/max/any/anyLast aggregators of GROUP BY keys in SELECT section", 0) \
M(Bool, optimize_injective_functions_in_group_by, true, "Replaces injective functions by it's arguments in GROUP BY section", 0) \
M(Bool, optimize_group_by_function_keys, true, "Eliminates functions of other keys in GROUP BY section", 0) \
M(Bool, optimize_group_by_constant_keys, true, "Optimize GROUP BY when all keys in block are constant", 0) \
M(Bool, legacy_column_name_of_tuple_literal, false, "List all names of element of large tuple literals in their column names instead of hash. This settings exists only for compatibility reasons. It makes sense to set to 'true', while doing rolling update of cluster from version lower than 21.7 to higher.", 0) \

View File

@ -1 +1,24 @@
1.1
SELECT dictGet(\'dictdb_01376.dict_exists\', \'value\', toUInt64(1)) AS val
FROM numbers(2)
GROUP BY toUInt64(1)
QUERY id: 0
PROJECTION COLUMNS
val Float64
PROJECTION
LIST id: 1, nodes: 1
FUNCTION id: 2, function_name: dictGet, function_type: ordinary, result_type: Float64
ARGUMENTS
LIST id: 3, nodes: 3
CONSTANT id: 4, constant_value: \'dictdb_01376.dict_exists\', constant_value_type: String
CONSTANT id: 5, constant_value: \'value\', constant_value_type: String
COLUMN id: 6, column_name: number, result_type: UInt64, source_id: 7
JOIN TREE
TABLE_FUNCTION id: 7, alias: __table1, table_function_name: numbers
ARGUMENTS
LIST id: 8, nodes: 1
CONSTANT id: 9, constant_value: UInt64_2, constant_value_type: UInt8
GROUP BY
LIST id: 10, nodes: 1
COLUMN id: 6, column_name: number, result_type: UInt64, source_id: 7
SETTINGS allow_experimental_analyzer=1

View File

@ -23,7 +23,7 @@ INSERT INTO dictdb_01376.table_for_dict VALUES (1, 1.1);
CREATE DICTIONARY IF NOT EXISTS dictdb_01376.dict_exists
(
key_column UInt64,
value Float64 DEFAULT 77.77
value Float64 DEFAULT 77.77 INJECTIVE
)
PRIMARY KEY key_column
SOURCE(CLICKHOUSE(HOST 'localhost' PORT tcpPort() USER 'default' TABLE 'table_for_dict' DB 'dictdb_01376'))
@ -32,6 +32,14 @@ LAYOUT(FLAT());
SELECT dictGet('dictdb_01376.dict_exists', 'value', toUInt64(1)) as val FROM numbers(2) GROUP BY val;
EXPLAIN SYNTAX SELECT dictGet('dictdb_01376.dict_exists', 'value', toUInt64(1)) as val FROM numbers(2) GROUP BY val;
EXPLAIN QUERY TREE
SELECT dictGet('dictdb_01376.dict_exists', 'value', number) as val
FROM numbers(2)
GROUP BY val
SETTINGS allow_experimental_analyzer = 1;
DROP DICTIONARY dictdb_01376.dict_exists;
DROP TABLE dictdb_01376.table_for_dict;
DROP DATABASE dictdb_01376;

View File

@ -20,17 +20,17 @@ clickhouse-client --allow_experimental_analyzer=1 --query_kind initial_query -q
Expression ((Project names + Projection))
Header: dummy String
Aggregating
Header: toString(__table1.dummy) String
Header: __table1.dummy UInt8
Expression ((Before GROUP BY + Change column names to column identifiers))
Header: toString(__table1.dummy) String
Header: __table1.dummy UInt8
ReadFromStorage (SystemOne)
Header: dummy UInt8
clickhouse-local --allow_experimental_analyzer=1 --query_kind initial_query -q explain plan header=1 select toString(dummy) as dummy from system.one group by dummy
Expression ((Project names + Projection))
Header: dummy String
Aggregating
Header: toString(__table1.dummy) String
Header: __table1.dummy UInt8
Expression ((Before GROUP BY + Change column names to column identifiers))
Header: toString(__table1.dummy) String
Header: __table1.dummy UInt8
ReadFromStorage (SystemOne)
Header: dummy UInt8

View File

@ -0,0 +1,142 @@
QUERY id: 0
PROJECTION COLUMNS
val String
count() UInt64
PROJECTION
LIST id: 1, nodes: 2
FUNCTION id: 2, function_name: toString, function_type: ordinary, result_type: String
ARGUMENTS
LIST id: 3, nodes: 1
FUNCTION id: 4, function_name: toString, function_type: ordinary, result_type: String
ARGUMENTS
LIST id: 5, nodes: 1
FUNCTION id: 6, function_name: plus, function_type: ordinary, result_type: UInt64
ARGUMENTS
LIST id: 7, nodes: 2
COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9
CONSTANT id: 10, constant_value: UInt64_1, constant_value_type: UInt8
FUNCTION id: 11, function_name: count, function_type: aggregate, result_type: UInt64
JOIN TREE
TABLE_FUNCTION id: 9, alias: __table1, table_function_name: numbers
ARGUMENTS
LIST id: 12, nodes: 1
CONSTANT id: 13, constant_value: UInt64_2, constant_value_type: UInt8
GROUP BY
LIST id: 14, nodes: 1
FUNCTION id: 6, function_name: plus, function_type: ordinary, result_type: UInt64
ARGUMENTS
LIST id: 7, nodes: 2
COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9
CONSTANT id: 10, constant_value: UInt64_1, constant_value_type: UInt8
ORDER BY
LIST id: 15, nodes: 1
SORT id: 16, sort_direction: ASCENDING, with_fill: 0
EXPRESSION
FUNCTION id: 2, function_name: toString, function_type: ordinary, result_type: String
ARGUMENTS
LIST id: 3, nodes: 1
FUNCTION id: 4, function_name: toString, function_type: ordinary, result_type: String
ARGUMENTS
LIST id: 5, nodes: 1
FUNCTION id: 6, function_name: plus, function_type: ordinary, result_type: UInt64
ARGUMENTS
LIST id: 7, nodes: 2
COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9
CONSTANT id: 10, constant_value: UInt64_1, constant_value_type: UInt8
1 1
2 1
QUERY id: 0
PROJECTION COLUMNS
val String
count() UInt64
PROJECTION
LIST id: 1, nodes: 2
FUNCTION id: 2, function_name: toString, function_type: ordinary, result_type: String
ARGUMENTS
LIST id: 3, nodes: 1
FUNCTION id: 4, function_name: toString, function_type: ordinary, result_type: String
ARGUMENTS
LIST id: 5, nodes: 1
FUNCTION id: 6, function_name: plus, function_type: ordinary, result_type: UInt64
ARGUMENTS
LIST id: 7, nodes: 2
COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9
CONSTANT id: 10, constant_value: UInt64_1, constant_value_type: UInt8
FUNCTION id: 11, function_name: count, function_type: aggregate, result_type: UInt64
JOIN TREE
TABLE_FUNCTION id: 9, alias: __table1, table_function_name: numbers
ARGUMENTS
LIST id: 12, nodes: 1
CONSTANT id: 13, constant_value: UInt64_2, constant_value_type: UInt8
GROUP BY
LIST id: 14, nodes: 1
FUNCTION id: 6, function_name: plus, function_type: ordinary, result_type: UInt64
ARGUMENTS
LIST id: 7, nodes: 2
COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9
CONSTANT id: 10, constant_value: UInt64_1, constant_value_type: UInt8
ORDER BY
LIST id: 15, nodes: 1
SORT id: 16, sort_direction: ASCENDING, with_fill: 0
EXPRESSION
FUNCTION id: 2, function_name: toString, function_type: ordinary, result_type: String
ARGUMENTS
LIST id: 3, nodes: 1
FUNCTION id: 4, function_name: toString, function_type: ordinary, result_type: String
ARGUMENTS
LIST id: 5, nodes: 1
FUNCTION id: 6, function_name: plus, function_type: ordinary, result_type: UInt64
ARGUMENTS
LIST id: 7, nodes: 2
COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9
CONSTANT id: 10, constant_value: UInt64_1, constant_value_type: UInt8
CHECK WITH TOTALS
QUERY id: 0, is_group_by_with_totals: 1
PROJECTION COLUMNS
val String
count() UInt64
PROJECTION
LIST id: 1, nodes: 2
FUNCTION id: 2, function_name: toString, function_type: ordinary, result_type: String
ARGUMENTS
LIST id: 3, nodes: 1
FUNCTION id: 4, function_name: toString, function_type: ordinary, result_type: String
ARGUMENTS
LIST id: 5, nodes: 1
FUNCTION id: 6, function_name: plus, function_type: ordinary, result_type: UInt64
ARGUMENTS
LIST id: 7, nodes: 2
COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9
CONSTANT id: 10, constant_value: UInt64_1, constant_value_type: UInt8
FUNCTION id: 11, function_name: count, function_type: aggregate, result_type: UInt64
JOIN TREE
TABLE_FUNCTION id: 9, alias: __table1, table_function_name: numbers
ARGUMENTS
LIST id: 12, nodes: 1
CONSTANT id: 13, constant_value: UInt64_2, constant_value_type: UInt8
GROUP BY
LIST id: 14, nodes: 1
FUNCTION id: 6, function_name: plus, function_type: ordinary, result_type: UInt64
ARGUMENTS
LIST id: 7, nodes: 2
COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9
CONSTANT id: 10, constant_value: UInt64_1, constant_value_type: UInt8
ORDER BY
LIST id: 15, nodes: 1
SORT id: 16, sort_direction: ASCENDING, with_fill: 0
EXPRESSION
FUNCTION id: 2, function_name: toString, function_type: ordinary, result_type: String
ARGUMENTS
LIST id: 3, nodes: 1
FUNCTION id: 4, function_name: toString, function_type: ordinary, result_type: String
ARGUMENTS
LIST id: 5, nodes: 1
FUNCTION id: 6, function_name: plus, function_type: ordinary, result_type: UInt64
ARGUMENTS
LIST id: 7, nodes: 2
COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9
CONSTANT id: 10, constant_value: UInt64_1, constant_value_type: UInt8
1 1
2 1
0 2

View File

@ -0,0 +1,31 @@
-- Checks that the analyzer removes injective functions from GROUP BY keys.
-- toString is applied twice on top of `number + 1`; the reference output expects
-- the GROUP BY list of the query tree to contain only `number + 1`.
set allow_experimental_analyzer = 1;
-- Case 1: the key is referenced through the projection alias.
EXPLAIN QUERY TREE
SELECT toString(toString(number + 1)) as val, count()
FROM numbers(2)
GROUP BY val
ORDER BY val;
-- Case 2: GROUP BY ALL. First the query result, then the query tree of the same query.
SELECT toString(toString(number + 1)) as val, count()
FROM numbers(2)
GROUP BY ALL
ORDER BY val;
EXPLAIN QUERY TREE
SELECT toString(toString(number + 1)) as val, count()
FROM numbers(2)
GROUP BY ALL
ORDER BY val;
-- Marker line that separates the WITH TOTALS section in the reference output.
SELECT 'CHECK WITH TOTALS';
-- Case 3: WITH TOTALS. The query tree first, then the query result including the totals row.
EXPLAIN QUERY TREE
SELECT toString(toString(number + 1)) as val, count()
FROM numbers(2)
GROUP BY val WITH TOTALS
ORDER BY val;
SELECT toString(toString(number + 1)) as val, count()
FROM numbers(2)
GROUP BY val WITH TOTALS
ORDER BY val;