Make optimization more general

This commit is contained in:
Dmitry Novik 2024-01-23 13:03:09 +00:00
parent f702afb72d
commit e9ea1307e0
5 changed files with 201 additions and 34 deletions

View File

@ -40,6 +40,9 @@ public:
void enterImpl(QueryTreeNodePtr & node)
{
if (!getSettings().optimize_injective_functions_in_group_by)
return;
auto * query = node->as<QueryNode>();
if (!query)
return;
@ -72,55 +75,42 @@ private:
new_group_by_keys.reserve(grouping_set.size());
for (auto & group_by_elem : grouping_set)
{
if (auto const * function_node = group_by_elem->as<FunctionNode>())
std::queue<QueryTreeNodePtr> nodes_to_process;
nodes_to_process.push(group_by_elem);
while (!nodes_to_process.empty())
{
bool can_be_eliminated = false;
if (possibly_injective_function_names.contains(function_node->getFunctionName()))
auto node_to_process = nodes_to_process.front();
nodes_to_process.pop();
auto const * function_node = node_to_process->as<FunctionNode>();
if (!function_node)
{
can_be_eliminated = canBeEliminated(function_node, context);
// Constant aggregation keys are removed in PlannerExpressionAnalysis.cpp
new_group_by_keys.push_back(node_to_process);
continue;
}
else
{
// Aggregate functions are not allowed in GROUP BY clause
auto function = function_node->getFunctionOrThrow();
can_be_eliminated = function->isInjective(function_node->getArgumentColumns());
}
bool can_be_eliminated = function->isInjective(function_node->getArgumentColumns());
if (can_be_eliminated)
{
for (auto const & argument : function_node->getArguments())
{
// We can skip constants here because aggregation key is already not a constant.
if (argument->getNodeType() != QueryTreeNodeType::CONSTANT)
new_group_by_keys.push_back(argument);
nodes_to_process.push(argument);
}
}
else
new_group_by_keys.push_back(group_by_elem);
new_group_by_keys.push_back(node_to_process);
}
else
new_group_by_keys.push_back(group_by_elem);
}
grouping_set = std::move(new_group_by_keys);
}
bool canBeEliminated(const FunctionNode * function_node, const ContextPtr & context)
{
const auto & function_arguments = function_node->getArguments().getNodes();
auto const * dict_name_arg = function_arguments[0]->as<ConstantNode>();
if (!dict_name_arg || !isString(dict_name_arg->getResultType()))
return false;
auto dict_name = dict_name_arg->getValue().safeGet<String>();
const auto & dict_ptr = context->getExternalDictionariesLoader().getDictionary(dict_name, context);
auto const * attr_name_arg = function_arguments[1]->as<ConstantNode>();
if (!attr_name_arg || !isString(attr_name_arg->getResultType()))
return false;
auto attr_name = attr_name_arg->getValue().safeGet<String>();
return dict_ptr->isInjective(attr_name);
}
};
}

View File

@ -2315,11 +2315,15 @@ std::pair<bool, UInt64> QueryAnalyzer::recursivelyCollectMaxOrdinaryExpressions(
*/
void QueryAnalyzer::expandGroupByAll(QueryNode & query_tree_node_typed)
{
if (!query_tree_node_typed.isGroupByAll())
return;
auto & group_by_nodes = query_tree_node_typed.getGroupBy().getNodes();
auto & projection_list = query_tree_node_typed.getProjection();
for (auto & node : projection_list.getNodes())
recursivelyCollectMaxOrdinaryExpressions(node, group_by_nodes);
query_tree_node_typed.setIsGroupByAll(false);
}
void QueryAnalyzer::expandOrderByAll(QueryNode & query_tree_node_typed)
@ -7380,7 +7384,6 @@ void QueryAnalyzer::resolveQuery(const QueryTreeNodePtr & query_node, Identifier
node->removeAlias();
}
if (query_node_typed.isGroupByAll())
expandGroupByAll(query_node_typed);
validateFilters(query_node);

View File

@ -699,6 +699,7 @@ class IColumn;
M(SetOperationMode, intersect_default_mode, SetOperationMode::ALL, "Set default mode in INTERSECT query. Possible values: empty string, 'ALL', 'DISTINCT'. If empty, query without mode will throw exception.", 0) \
M(SetOperationMode, except_default_mode, SetOperationMode::ALL, "Set default mode in EXCEPT query. Possible values: empty string, 'ALL', 'DISTINCT'. If empty, query without mode will throw exception.", 0) \
M(Bool, optimize_aggregators_of_group_by_keys, true, "Eliminates min/max/any/anyLast aggregators of GROUP BY keys in SELECT section", 0) \
M(Bool, optimize_injective_functions_in_group_by, true, "Replaces injective functions by it's arguments in GROUP BY section", 0) \
M(Bool, optimize_group_by_function_keys, true, "Eliminates functions of other keys in GROUP BY section", 0) \
M(Bool, optimize_group_by_constant_keys, true, "Optimize GROUP BY when all keys in block are constant", 0) \
M(Bool, legacy_column_name_of_tuple_literal, false, "List all names of element of large tuple literals in their column names instead of hash. This settings exists only for compatibility reasons. It makes sense to set to 'true', while doing rolling update of cluster from version lower than 21.7 to higher.", 0) \

View File

@ -0,0 +1,142 @@
QUERY id: 0
PROJECTION COLUMNS
val String
count() UInt64
PROJECTION
LIST id: 1, nodes: 2
FUNCTION id: 2, function_name: toString, function_type: ordinary, result_type: String
ARGUMENTS
LIST id: 3, nodes: 1
FUNCTION id: 4, function_name: toString, function_type: ordinary, result_type: String
ARGUMENTS
LIST id: 5, nodes: 1
FUNCTION id: 6, function_name: plus, function_type: ordinary, result_type: UInt64
ARGUMENTS
LIST id: 7, nodes: 2
COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9
CONSTANT id: 10, constant_value: UInt64_1, constant_value_type: UInt8
FUNCTION id: 11, function_name: count, function_type: aggregate, result_type: UInt64
JOIN TREE
TABLE_FUNCTION id: 9, alias: __table1, table_function_name: numbers
ARGUMENTS
LIST id: 12, nodes: 1
CONSTANT id: 13, constant_value: UInt64_2, constant_value_type: UInt8
GROUP BY
LIST id: 14, nodes: 1
FUNCTION id: 6, function_name: plus, function_type: ordinary, result_type: UInt64
ARGUMENTS
LIST id: 7, nodes: 2
COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9
CONSTANT id: 10, constant_value: UInt64_1, constant_value_type: UInt8
ORDER BY
LIST id: 15, nodes: 1
SORT id: 16, sort_direction: ASCENDING, with_fill: 0
EXPRESSION
FUNCTION id: 2, function_name: toString, function_type: ordinary, result_type: String
ARGUMENTS
LIST id: 3, nodes: 1
FUNCTION id: 4, function_name: toString, function_type: ordinary, result_type: String
ARGUMENTS
LIST id: 5, nodes: 1
FUNCTION id: 6, function_name: plus, function_type: ordinary, result_type: UInt64
ARGUMENTS
LIST id: 7, nodes: 2
COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9
CONSTANT id: 10, constant_value: UInt64_1, constant_value_type: UInt8
1 1
2 1
QUERY id: 0
PROJECTION COLUMNS
val String
count() UInt64
PROJECTION
LIST id: 1, nodes: 2
FUNCTION id: 2, function_name: toString, function_type: ordinary, result_type: String
ARGUMENTS
LIST id: 3, nodes: 1
FUNCTION id: 4, function_name: toString, function_type: ordinary, result_type: String
ARGUMENTS
LIST id: 5, nodes: 1
FUNCTION id: 6, function_name: plus, function_type: ordinary, result_type: UInt64
ARGUMENTS
LIST id: 7, nodes: 2
COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9
CONSTANT id: 10, constant_value: UInt64_1, constant_value_type: UInt8
FUNCTION id: 11, function_name: count, function_type: aggregate, result_type: UInt64
JOIN TREE
TABLE_FUNCTION id: 9, alias: __table1, table_function_name: numbers
ARGUMENTS
LIST id: 12, nodes: 1
CONSTANT id: 13, constant_value: UInt64_2, constant_value_type: UInt8
GROUP BY
LIST id: 14, nodes: 1
FUNCTION id: 6, function_name: plus, function_type: ordinary, result_type: UInt64
ARGUMENTS
LIST id: 7, nodes: 2
COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9
CONSTANT id: 10, constant_value: UInt64_1, constant_value_type: UInt8
ORDER BY
LIST id: 15, nodes: 1
SORT id: 16, sort_direction: ASCENDING, with_fill: 0
EXPRESSION
FUNCTION id: 2, function_name: toString, function_type: ordinary, result_type: String
ARGUMENTS
LIST id: 3, nodes: 1
FUNCTION id: 4, function_name: toString, function_type: ordinary, result_type: String
ARGUMENTS
LIST id: 5, nodes: 1
FUNCTION id: 6, function_name: plus, function_type: ordinary, result_type: UInt64
ARGUMENTS
LIST id: 7, nodes: 2
COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9
CONSTANT id: 10, constant_value: UInt64_1, constant_value_type: UInt8
CHECK WITH TOTALS
QUERY id: 0, is_group_by_with_totals: 1
PROJECTION COLUMNS
val String
count() UInt64
PROJECTION
LIST id: 1, nodes: 2
FUNCTION id: 2, function_name: toString, function_type: ordinary, result_type: String
ARGUMENTS
LIST id: 3, nodes: 1
FUNCTION id: 4, function_name: toString, function_type: ordinary, result_type: String
ARGUMENTS
LIST id: 5, nodes: 1
FUNCTION id: 6, function_name: plus, function_type: ordinary, result_type: UInt64
ARGUMENTS
LIST id: 7, nodes: 2
COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9
CONSTANT id: 10, constant_value: UInt64_1, constant_value_type: UInt8
FUNCTION id: 11, function_name: count, function_type: aggregate, result_type: UInt64
JOIN TREE
TABLE_FUNCTION id: 9, alias: __table1, table_function_name: numbers
ARGUMENTS
LIST id: 12, nodes: 1
CONSTANT id: 13, constant_value: UInt64_2, constant_value_type: UInt8
GROUP BY
LIST id: 14, nodes: 1
FUNCTION id: 6, function_name: plus, function_type: ordinary, result_type: UInt64
ARGUMENTS
LIST id: 7, nodes: 2
COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9
CONSTANT id: 10, constant_value: UInt64_1, constant_value_type: UInt8
ORDER BY
LIST id: 15, nodes: 1
SORT id: 16, sort_direction: ASCENDING, with_fill: 0
EXPRESSION
FUNCTION id: 2, function_name: toString, function_type: ordinary, result_type: String
ARGUMENTS
LIST id: 3, nodes: 1
FUNCTION id: 4, function_name: toString, function_type: ordinary, result_type: String
ARGUMENTS
LIST id: 5, nodes: 1
FUNCTION id: 6, function_name: plus, function_type: ordinary, result_type: UInt64
ARGUMENTS
LIST id: 7, nodes: 2
COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9
CONSTANT id: 10, constant_value: UInt64_1, constant_value_type: UInt8
1 1
2 1
0 2

View File

@ -0,0 +1,31 @@
set allow_experimental_analyzer = 1;
EXPLAIN QUERY TREE
SELECT toString(toString(number + 1)) as val, count()
FROM numbers(2)
GROUP BY val
ORDER BY val;
SELECT toString(toString(number + 1)) as val, count()
FROM numbers(2)
GROUP BY ALL
ORDER BY val;
EXPLAIN QUERY TREE
SELECT toString(toString(number + 1)) as val, count()
FROM numbers(2)
GROUP BY ALL
ORDER BY val;
SELECT 'CHECK WITH TOTALS';
EXPLAIN QUERY TREE
SELECT toString(toString(number + 1)) as val, count()
FROM numbers(2)
GROUP BY val WITH TOTALS
ORDER BY val;
SELECT toString(toString(number + 1)) as val, count()
FROM numbers(2)
GROUP BY val WITH TOTALS
ORDER BY val;