Fix UniqInjectiveFunctionsEliminationPass with uniqCombined

This commit is contained in:
Raúl Marín 2024-06-12 19:08:23 +02:00
parent f4c5d172ac
commit b7161b77d1
3 changed files with 31 additions and 6 deletions

View File

@ -41,9 +41,9 @@ public:
return;
bool replaced_argument = false;
auto & uniq_function_arguments_nodes = function_node->getArguments().getNodes();
auto replaced_uniq_function_arguments_nodes = function_node->getArguments().getNodes();
for (auto & uniq_function_argument_node : uniq_function_arguments_nodes)
for (auto & uniq_function_argument_node : replaced_uniq_function_arguments_nodes)
{
auto * uniq_function_argument_node_typed = uniq_function_argument_node->as<FunctionNode>();
if (!uniq_function_argument_node_typed || !uniq_function_argument_node_typed->isOrdinaryFunction())
@ -67,12 +67,10 @@ public:
if (!replaced_argument)
return;
const auto & function_node_argument_nodes = function_node->getArguments().getNodes();
DataTypes argument_types;
argument_types.reserve(function_node_argument_nodes.size());
argument_types.reserve(replaced_uniq_function_arguments_nodes.size());
for (const auto & function_node_argument : function_node_argument_nodes)
for (const auto & function_node_argument : replaced_uniq_function_arguments_nodes)
argument_types.emplace_back(function_node_argument->getResultType());
AggregateFunctionProperties properties;
@ -83,6 +81,10 @@ public:
function_node->getAggregateFunction()->getParameters(),
properties);
/// uniqCombined returns nullable with nullable arguments so the result type might change which breaks the pass
if (!aggregate_function->getResultType()->equals(*function_node->getAggregateFunction()->getResultType()))
return;
function_node->resolveAsAggregateFunction(std::move(aggregate_function));
}
};

View File

@ -0,0 +1,21 @@
SELECT sum(u)
FROM
(
SELECT
intDiv(number, 4096) AS k,
uniqCombined(tuple(materialize(toLowCardinality(toNullable(16))))) AS u
FROM numbers(4096 * 100)
GROUP BY k
)
SETTINGS allow_experimental_analyzer = 1, optimize_injective_functions_inside_uniq=0;
SELECT sum(u)
FROM
(
SELECT
intDiv(number, 4096) AS k,
uniqCombined(tuple(materialize(toLowCardinality(toNullable(16))))) AS u
FROM numbers(4096 * 100)
GROUP BY k
)
SETTINGS allow_experimental_analyzer = 1, optimize_injective_functions_inside_uniq=1;