From eeef2dae7748107f1b62c83e654ff2ea84056ddc Mon Sep 17 00:00:00 2001 From: taiyang-li <654010905@qq.com> Date: Sun, 29 Jan 2023 19:48:46 +0800 Subject: [PATCH] fix cores --- src/Analyzer/IQueryTreeNode.h | 3 + src/Analyzer/Passes/SumIfToCountIfPass.cpp | 45 +++++++++-- .../01646_rewrite_sum_if.reference | 81 +++++++++++++++++++ .../0_stateless/01646_rewrite_sum_if.sql | 6 ++ 4 files changed, 128 insertions(+), 7 deletions(-) diff --git a/src/Analyzer/IQueryTreeNode.h b/src/Analyzer/IQueryTreeNode.h index 8aa834e60b7..719f92d773e 100644 --- a/src/Analyzer/IQueryTreeNode.h +++ b/src/Analyzer/IQueryTreeNode.h @@ -72,6 +72,9 @@ class IQueryTreeNode : public TypePromotion public: virtual ~IQueryTreeNode() = default; + IQueryTreeNode & operator=(const IQueryTreeNode &) = default; + IQueryTreeNode & operator=(IQueryTreeNode &&) = default; + /// Get query tree node type virtual QueryTreeNodeType getNodeType() const = 0; diff --git a/src/Analyzer/Passes/SumIfToCountIfPass.cpp b/src/Analyzer/Passes/SumIfToCountIfPass.cpp index 4462131ed7a..d55af278152 100644 --- a/src/Analyzer/Passes/SumIfToCountIfPass.cpp +++ b/src/Analyzer/Passes/SumIfToCountIfPass.cpp @@ -58,13 +58,20 @@ public: if (!isInt64OrUInt64FieldType(constant_value_literal.getType())) return; - if (constant_value_literal.get() != 1 || getSettings().aggregate_functions_null_for_empty) + if (getSettings().aggregate_functions_null_for_empty) return; + /// Rewrite `sumIf(1, cond)` into `countIf(cond)` + auto multiplier_node = function_node_arguments_nodes[0]; function_node_arguments_nodes[0] = std::move(function_node_arguments_nodes[1]); function_node_arguments_nodes.resize(1); - resolveAsCountIfAggregateFunction(*function_node, function_node_arguments_nodes[0]->getResultType()); + + if (constant_value_literal.get() != 1) + { + /// Rewrite `sumIf(123, cond)` into `123 * countIf(cond)` + node = getMultiplyFunction(std::move(multiplier_node), node); + } return; } @@ -79,7 +86,7 @@ public: if (!nested_function || nested_function->getFunctionName() != "if") return; - const auto & nested_if_function_arguments_nodes = nested_function->getArguments().getNodes(); + const auto nested_if_function_arguments_nodes = nested_function->getArguments().getNodes(); if (nested_if_function_arguments_nodes.size() != 3) return; @@ -100,19 +107,25 @@ public: auto if_true_condition_value = if_true_condition_constant_value_literal.get(); auto if_false_condition_value = if_false_condition_constant_value_literal.get(); - /// Rewrite `sum(if(cond, 1, 0))` into `countIf(cond)`. - if (if_true_condition_value == 1 && if_false_condition_value == 0) + if (if_false_condition_value == 0) { + /// Rewrite `sum(if(cond, 1, 0))` into `countIf(cond)`. function_node_arguments_nodes[0] = nested_if_function_arguments_nodes[0]; function_node_arguments_nodes.resize(1); resolveAsCountIfAggregateFunction(*function_node, function_node_arguments_nodes[0]->getResultType()); + + if (if_true_condition_value != 1) + { + /// Rewrite `sum(if(cond, 123, 0))` into `123 * countIf(cond)`. + node = getMultiplyFunction(nested_if_function_arguments_nodes[1], node); + } return; } - /// Rewrite `sum(if(cond, 0, 1))` into `countIf(not(cond))` if condition is not Nullable (otherwise the result can be different). - if (if_true_condition_value == 0 && if_false_condition_value == 1 && !cond_argument->getResultType()->isNullable()) + if (if_true_condition_value == 0 && !cond_argument->getResultType()->isNullable()) { + /// Rewrite `sum(if(cond, 0, 1))` into `countIf(not(cond))` if condition is not Nullable (otherwise the result can be different). DataTypePtr not_function_result_type = std::make_shared(); const auto & condition_result_type = nested_if_function_arguments_nodes[0]->getResultType(); @@ -130,6 +143,12 @@ public: function_node_arguments_nodes.resize(1); resolveAsCountIfAggregateFunction(*function_node, function_node_arguments_nodes[0]->getResultType()); + + if (if_false_condition_value != 1) + { + /// Rewrite `sum(if(cond, 0, 123))` into `123 * countIf(not(cond))` if condition is not Nullable (otherwise the result can be different). + node = getMultiplyFunction(nested_if_function_arguments_nodes[2], node); + } return; } } @@ -145,6 +164,18 @@ private: function_node.resolveAsAggregateFunction(std::move(aggregate_function)); } + + inline QueryTreeNodePtr getMultiplyFunction(QueryTreeNodePtr left, QueryTreeNodePtr right) + { + auto multiply_function_node = std::make_shared("multiply"); + auto & multiply_arguments_nodes = multiply_function_node->getArguments().getNodes(); + multiply_arguments_nodes.push_back(std::move(left)); + multiply_arguments_nodes.push_back(std::move(right)); + + auto multiply_function_base = FunctionFactory::instance().get("multiply", getContext())->build(multiply_function_node->getArgumentColumns()); + multiply_function_node->resolveAsFunction(std::move(multiply_function_base)); + return std::move(multiply_function_node); + } }; } diff --git a/tests/queries/0_stateless/01646_rewrite_sum_if.reference b/tests/queries/0_stateless/01646_rewrite_sum_if.reference index 0e37c91578a..871c75737c6 100644 --- a/tests/queries/0_stateless/01646_rewrite_sum_if.reference +++ b/tests/queries/0_stateless/01646_rewrite_sum_if.reference @@ -34,3 +34,84 @@ SELECT 123 * countIf((number % 2) = 0) FROM numbers(100) SELECT 123 * countIf(NOT ((number % 2) = 0)) FROM numbers(100) +QUERY id: 0 + PROJECTION COLUMNS + sumIf(123, equals(modulo(number, 2), 0)) UInt64 + PROJECTION + LIST id: 1, nodes: 1 + FUNCTION id: 2, function_name: multiply, function_type: ordinary, result_type: UInt64 + ARGUMENTS + LIST id: 3, nodes: 2 + CONSTANT id: 4, constant_value: UInt64_123, constant_value_type: UInt8 + FUNCTION id: 5, function_name: countIf, function_type: aggregate, result_type: UInt64 + ARGUMENTS + LIST id: 6, nodes: 1 + FUNCTION id: 7, function_name: equals, function_type: ordinary, result_type: UInt8 + ARGUMENTS + LIST id: 8, nodes: 2 + FUNCTION id: 9, function_name: modulo, function_type: ordinary, result_type: UInt8 + ARGUMENTS + LIST id: 10, nodes: 2 + COLUMN id: 11, column_name: number, result_type: UInt64, source_id: 12 + CONSTANT id: 13, constant_value: UInt64_2, constant_value_type: UInt8 + CONSTANT id: 14, constant_value: UInt64_0, constant_value_type: UInt8 + JOIN TREE + TABLE_FUNCTION id: 12, table_function_name: numbers + ARGUMENTS + LIST id: 15, nodes: 1 + CONSTANT id: 16, constant_value: UInt64_100, constant_value_type: UInt8 +QUERY id: 0 + PROJECTION COLUMNS + sum(if(equals(modulo(number, 2), 0), 123, 0)) UInt64 + PROJECTION + LIST id: 1, nodes: 1 + FUNCTION id: 2, function_name: multiply, function_type: ordinary, result_type: UInt64 + ARGUMENTS + LIST id: 3, nodes: 2 + CONSTANT id: 4, constant_value: UInt64_123, constant_value_type: UInt8 + FUNCTION id: 5, function_name: countIf, function_type: aggregate, result_type: UInt64 + ARGUMENTS + LIST id: 6, nodes: 1 + FUNCTION id: 7, function_name: equals, function_type: ordinary, result_type: UInt8 + ARGUMENTS + LIST id: 8, nodes: 2 + FUNCTION id: 9, function_name: modulo, function_type: ordinary, result_type: UInt8 + ARGUMENTS + LIST id: 10, nodes: 2 + COLUMN id: 11, column_name: number, result_type: UInt64, source_id: 12 + CONSTANT id: 13, constant_value: UInt64_2, constant_value_type: UInt8 + CONSTANT id: 14, constant_value: UInt64_0, constant_value_type: UInt8 + JOIN TREE + TABLE_FUNCTION id: 12, table_function_name: numbers + ARGUMENTS + LIST id: 15, nodes: 1 + CONSTANT id: 16, constant_value: UInt64_100, constant_value_type: UInt8 +QUERY id: 0 + PROJECTION COLUMNS + sum(if(equals(modulo(number, 2), 0), 0, 123)) UInt64 + PROJECTION + LIST id: 1, nodes: 1 + FUNCTION id: 2, function_name: multiply, function_type: ordinary, result_type: UInt64 + ARGUMENTS + LIST id: 3, nodes: 2 + CONSTANT id: 4, constant_value: UInt64_123, constant_value_type: UInt8 + FUNCTION id: 5, function_name: countIf, function_type: aggregate, result_type: UInt64 + ARGUMENTS + LIST id: 6, nodes: 1 + FUNCTION id: 7, function_name: not, function_type: ordinary, result_type: UInt8 + ARGUMENTS + LIST id: 8, nodes: 1 + FUNCTION id: 9, function_name: equals, function_type: ordinary, result_type: UInt8 + ARGUMENTS + LIST id: 10, nodes: 2 + FUNCTION id: 11, function_name: modulo, function_type: ordinary, result_type: UInt8 + ARGUMENTS + LIST id: 12, nodes: 2 + COLUMN id: 13, column_name: number, result_type: UInt64, source_id: 14 + CONSTANT id: 15, constant_value: UInt64_2, constant_value_type: UInt8 + CONSTANT id: 16, constant_value: UInt64_0, constant_value_type: UInt8 + JOIN TREE + TABLE_FUNCTION id: 14, table_function_name: numbers + ARGUMENTS + LIST id: 17, nodes: 1 + CONSTANT id: 18, constant_value: UInt64_100, constant_value_type: UInt8 diff --git a/tests/queries/0_stateless/01646_rewrite_sum_if.sql b/tests/queries/0_stateless/01646_rewrite_sum_if.sql index 9abe7f35005..b2de98e9e07 100644 --- a/tests/queries/0_stateless/01646_rewrite_sum_if.sql +++ b/tests/queries/0_stateless/01646_rewrite_sum_if.sql @@ -41,3 +41,9 @@ SELECT countIf(number % 2 != 0) FROM numbers(100); EXPLAIN SYNTAX SELECT sumIf(123, number % 2 == 0) FROM numbers(100); EXPLAIN SYNTAX SELECT sum(if(number % 2 == 0, 123, 0)) FROM numbers(100); EXPLAIN SYNTAX SELECT sum(if(number % 2 == 0, 0, 123)) FROM numbers(100); + +set allow_experimental_analyzer = true; + +EXPLAIN QUERY TREE run_passes=1 SELECT sumIf(123, number % 2 == 0) FROM numbers(100); +EXPLAIN QUERY TREE run_passes=1 SELECT sum(if(number % 2 == 0, 123, 0)) FROM numbers(100); +EXPLAIN QUERY TREE run_passes=1 SELECT sum(if(number % 2 == 0, 0, 123)) FROM numbers(100); \ No newline at end of file