Merge pull request #43543 from kitaisreal/analyzer-sum-if-to-count-if-fix

Analyzer SumIfToCountIfPass fix
This commit is contained in:
Maksim Kita 2022-11-25 21:02:20 +03:00 committed by GitHub
commit c46a659ad9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 111 additions and 11 deletions

View File

@ -61,7 +61,7 @@ public:
function_node_arguments_nodes[0] = std::move(function_node_arguments_nodes[1]); function_node_arguments_nodes[0] = std::move(function_node_arguments_nodes[1]);
function_node_arguments_nodes.resize(1); function_node_arguments_nodes.resize(1);
resolveAggregateFunctionNode(*function_node, "countIf"); resolveAsCountIfAggregateFunction(*function_node, function_node_arguments_nodes[0]->getResultType());
return; return;
} }
@ -102,15 +102,16 @@ public:
function_node_arguments_nodes[0] = std::move(nested_if_function_arguments_nodes[0]); function_node_arguments_nodes[0] = std::move(nested_if_function_arguments_nodes[0]);
function_node_arguments_nodes.resize(1); function_node_arguments_nodes.resize(1);
resolveAggregateFunctionNode(*function_node, "countIf"); resolveAsCountIfAggregateFunction(*function_node, function_node_arguments_nodes[0]->getResultType());
return; return;
} }
/// Rewrite `sum(if(cond, 0, 1))` into `countIf(not(cond))`. /// Rewrite `sum(if(cond, 0, 1))` into `countIf(not(cond))`.
if (if_true_condition_value == 0 && if_false_condition_value == 1) if (if_true_condition_value == 0 && if_false_condition_value == 1)
{ {
auto condition_result_type = nested_if_function_arguments_nodes[0]->getResultType();
DataTypePtr not_function_result_type = std::make_shared<DataTypeUInt8>(); DataTypePtr not_function_result_type = std::make_shared<DataTypeUInt8>();
const auto & condition_result_type = nested_if_function_arguments_nodes[0]->getResultType();
if (condition_result_type->isNullable()) if (condition_result_type->isNullable())
not_function_result_type = makeNullable(not_function_result_type); not_function_result_type = makeNullable(not_function_result_type);
@ -123,23 +124,21 @@ public:
function_node_arguments_nodes[0] = std::move(not_function); function_node_arguments_nodes[0] = std::move(not_function);
function_node_arguments_nodes.resize(1); function_node_arguments_nodes.resize(1);
resolveAggregateFunctionNode(*function_node, "countIf"); resolveAsCountIfAggregateFunction(*function_node, function_node_arguments_nodes[0]->getResultType());
return; return;
} }
} }
private: private:
static inline void resolveAggregateFunctionNode(FunctionNode & function_node, const String & aggregate_function_name) static inline void resolveAsCountIfAggregateFunction(FunctionNode & function_node, const DataTypePtr & argument_type)
{ {
auto function_result_type = function_node.getResultType();
auto function_aggregate_function = function_node.getAggregateFunction();
AggregateFunctionProperties properties; AggregateFunctionProperties properties;
auto aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name, auto aggregate_function = AggregateFunctionFactory::instance().get("countIf",
function_aggregate_function->getArgumentTypes(), {argument_type},
function_aggregate_function->getParameters(), function_node.getAggregateFunction()->getParameters(),
properties); properties);
auto function_result_type = function_node.getResultType();
function_node.resolveAsAggregateFunction(std::move(aggregate_function), std::move(function_result_type)); function_node.resolveAsAggregateFunction(std::move(aggregate_function), std::move(function_result_type));
} }

View File

@ -0,0 +1,77 @@
QUERY id: 0
PROJECTION COLUMNS
sumIf(1, equals(modulo(number, 2), 0)) UInt64
PROJECTION
LIST id: 1, nodes: 1
FUNCTION id: 2, function_name: countIf, function_type: aggregate, result_type: UInt64
ARGUMENTS
LIST id: 3, nodes: 1
FUNCTION id: 4, function_name: equals, function_type: ordinary, result_type: UInt8
ARGUMENTS
LIST id: 5, nodes: 2
FUNCTION id: 6, function_name: modulo, function_type: ordinary, result_type: UInt8
ARGUMENTS
LIST id: 7, nodes: 2
COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9
CONSTANT id: 10, constant_value: UInt64_2, constant_value_type: UInt8
CONSTANT id: 11, constant_value: UInt64_0, constant_value_type: UInt8
JOIN TREE
TABLE_FUNCTION id: 9, table_function_name: numbers
ARGUMENTS
LIST id: 12, nodes: 1
CONSTANT id: 13, constant_value: UInt64_10, constant_value_type: UInt8
--
5
--
QUERY id: 0
PROJECTION COLUMNS
sum(if(equals(modulo(number, 2), 0), 1, 0)) UInt64
PROJECTION
LIST id: 1, nodes: 1
FUNCTION id: 2, function_name: countIf, function_type: aggregate, result_type: UInt64
ARGUMENTS
LIST id: 3, nodes: 1
FUNCTION id: 4, function_name: equals, function_type: ordinary, result_type: UInt8
ARGUMENTS
LIST id: 5, nodes: 2
FUNCTION id: 6, function_name: modulo, function_type: ordinary, result_type: UInt8
ARGUMENTS
LIST id: 7, nodes: 2
COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9
CONSTANT id: 10, constant_value: UInt64_2, constant_value_type: UInt8
CONSTANT id: 11, constant_value: UInt64_0, constant_value_type: UInt8
JOIN TREE
TABLE_FUNCTION id: 9, table_function_name: numbers
ARGUMENTS
LIST id: 12, nodes: 1
CONSTANT id: 13, constant_value: UInt64_10, constant_value_type: UInt8
--
5
--
QUERY id: 0
PROJECTION COLUMNS
sum(if(equals(modulo(number, 2), 0), 0, 1)) UInt64
PROJECTION
LIST id: 1, nodes: 1
FUNCTION id: 2, function_name: countIf, function_type: aggregate, result_type: UInt64
ARGUMENTS
LIST id: 3, nodes: 1
FUNCTION id: 4, function_name: not, function_type: ordinary, result_type: UInt8
ARGUMENTS
LIST id: 5, nodes: 1
FUNCTION id: 6, function_name: equals, function_type: ordinary, result_type: UInt8
ARGUMENTS
LIST id: 7, nodes: 2
FUNCTION id: 8, function_name: modulo, function_type: ordinary, result_type: UInt8
ARGUMENTS
LIST id: 9, nodes: 2
COLUMN id: 10, column_name: number, result_type: UInt64, source_id: 11
CONSTANT id: 12, constant_value: UInt64_2, constant_value_type: UInt8
CONSTANT id: 13, constant_value: UInt64_0, constant_value_type: UInt8
JOIN TREE
TABLE_FUNCTION id: 11, table_function_name: numbers
ARGUMENTS
LIST id: 14, nodes: 1
CONSTANT id: 15, constant_value: UInt64_10, constant_value_type: UInt8
--
5

View File

@ -0,0 +1,24 @@
SET allow_experimental_analyzer = 1;
SET optimize_rewrite_sum_if_to_count_if = 1;
EXPLAIN QUERY TREE (SELECT sumIf(1, (number % 2) == 0) FROM numbers(10));
SELECT '--';
SELECT sumIf(1, (number % 2) == 0) FROM numbers(10);
SELECT '--';
EXPLAIN QUERY TREE (SELECT sum(if((number % 2) == 0, 1, 0)) FROM numbers(10));
SELECT '--';
SELECT sum(if((number % 2) == 0, 1, 0)) FROM numbers(10);
SELECT '--';
EXPLAIN QUERY TREE (SELECT sum(if((number % 2) == 0, 0, 1)) FROM numbers(10));
SELECT '--';
SELECT sum(if((number % 2) == 0, 0, 1)) FROM numbers(10);