From f732c97cf6d5b9b8d5edf39c660e6d8213567db4 Mon Sep 17 00:00:00 2001 From: vdimir Date: Mon, 7 Nov 2022 14:24:27 +0000 Subject: [PATCH] fixes for FuseFunctionsPass --- src/Analyzer/Passes/FuseFunctionsPass.cpp | 11 +- src/Analyzer/QueryTreePassManager.cpp | 3 +- .../02476_fuse_sum_count.reference | 293 ++++++++++++------ .../0_stateless/02476_fuse_sum_count.sql | 17 +- 4 files changed, 231 insertions(+), 93 deletions(-) diff --git a/src/Analyzer/Passes/FuseFunctionsPass.cpp b/src/Analyzer/Passes/FuseFunctionsPass.cpp index 3d448be349f..d58df9d954c 100644 --- a/src/Analyzer/Passes/FuseFunctionsPass.cpp +++ b/src/Analyzer/Passes/FuseFunctionsPass.cpp @@ -38,9 +38,13 @@ public: if (!function_node || !function_node->isAggregateFunction() || !matchFunctionName(function_node->getFunctionName())) return; + if (function_node->getResultType()->isNullable()) + /// Do not apply to functions with Nullable result type, because `sumCount` handles it different from `sum` and `avg`. + return; + const auto & arguments = function_node->getArgumentsNode()->getChildren(); if (arguments.size() != 1) - throw Exception(ErrorCodes::LOGICAL_ERROR, "Aggregate function {} must have exactly one argument", function_node->getFunctionName()); + throw Exception(ErrorCodes::LOGICAL_ERROR, "Aggregate function {} should have exactly one argument", function_node->getFunctionName()); mapping[QueryTreeNodeWithHash(arguments[0])].push_back(&node); } @@ -111,12 +115,12 @@ void replaceWithSumCount(QueryTreeNodePtr & node, const FunctionNodePtr & sum_co if (function_name == "sum") { - assert(node->getResultType() == sum_count_result_type->getElement(0)); + assert(node->getResultType()->equals(*sum_count_result_type->getElement(0))); node = createTupleElementFunction(context, node->getResultType(), sum_count_node, 1); } else if (function_name == "count") { - assert(node->getResultType() == sum_count_result_type->getElement(1)); + assert(node->getResultType()->equals(*sum_count_result_type->getElement(1))); node = createTupleElementFunction(context, node->getResultType(), sum_count_node, 2); } else if (function_name == "avg") @@ -148,6 +152,7 @@ void FuseFunctionsPass::run(QueryTreeNodePtr query_tree_node, ContextPtr context auto sum_count_node = createSumCoundNode(argument.node); for (auto * node : nodes) { + assert(node); replaceWithSumCount(*node, sum_count_node, context); } } diff --git a/src/Analyzer/QueryTreePassManager.cpp b/src/Analyzer/QueryTreePassManager.cpp index 08cd57d4c3d..126b2dc823b 100644 --- a/src/Analyzer/QueryTreePassManager.cpp +++ b/src/Analyzer/QueryTreePassManager.cpp @@ -67,7 +67,6 @@ public: * * TODO: Support _shard_num into shardNum() rewriting. * TODO: Support logical expressions optimizer. - * TODO: Support fuse sum count optimize_fuse_sum_count_avg, optimize_syntax_fuse_functions. * TODO: Support setting convert_query_to_cnf. * TODO: Support setting optimize_using_constraints. * TODO: Support setting optimize_substitute_columns. @@ -79,7 +78,7 @@ public: * TODO: Support setting optimize_redundant_functions_in_order_by. * TODO: Support setting optimize_monotonous_functions_in_order_by. * TODO: Support setting optimize_if_transform_strings_to_enum. - * TODO: Support settings.optimize_syntax_fuse_functions. + * TODO: Support fuse quantile functions optimize_syntax_fuse_functions. * TODO: Support settings.optimize_or_like_chain. * TODO: Add optimizations based on function semantics. Example: SELECT * FROM test_table WHERE id != id. (id is not nullable column). */ diff --git a/tests/queries/0_stateless/02476_fuse_sum_count.reference b/tests/queries/0_stateless/02476_fuse_sum_count.reference index 20f9c4c6d27..5b6936110ba 100644 --- a/tests/queries/0_stateless/02476_fuse_sum_count.reference +++ b/tests/queries/0_stateless/02476_fuse_sum_count.reference @@ -1,5 +1,64 @@ +1.5 3 +\N \N +1.5 3 +2 6 +6 10 9 5 6 3 2 2 7 2 5 6 3 2 2 7 2 +QUERY id: 0 + PROJECTION COLUMNS + sum(a) Nullable(Int64) + avg(a) Nullable(Float64) + PROJECTION + LIST id: 1, nodes: 2 + FUNCTION id: 2, function_name: sum, function_type: aggregate, result_type: Nullable(Int64) + ARGUMENTS + LIST id: 3, nodes: 1 + COLUMN id: 4, column_name: a, result_type: Nullable(Int8), source_id: 5 + FUNCTION id: 6, function_name: avg, function_type: aggregate, result_type: Nullable(Float64) + ARGUMENTS + LIST id: 7, nodes: 1 + COLUMN id: 4, column_name: a, result_type: Nullable(Int8), source_id: 5 + JOIN TREE + TABLE id: 5, table_name: default.fuse_tbl +QUERY id: 0 + PROJECTION COLUMNS + sum(b) Int64 + avg(b) Float64 + PROJECTION + LIST id: 1, nodes: 2 + FUNCTION id: 2, function_name: tupleElement, function_type: ordinary, result_type: Int64 + ARGUMENTS + LIST id: 3, nodes: 2 + FUNCTION id: 4, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64) + ARGUMENTS + LIST id: 5, nodes: 1 + COLUMN id: 6, column_name: b, result_type: Int8, source_id: 7 + CONSTANT id: 8, constant_value: UInt64_1, constant_value_type: UInt8 + FUNCTION id: 9, function_name: divide, function_type: ordinary, result_type: Float64 + ARGUMENTS + LIST id: 10, nodes: 2 + FUNCTION id: 11, function_name: tupleElement, function_type: ordinary, result_type: Int64 + ARGUMENTS + LIST id: 12, nodes: 2 + FUNCTION id: 4, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64) + ARGUMENTS + LIST id: 5, nodes: 1 + COLUMN id: 6, column_name: b, result_type: Int8, source_id: 7 + CONSTANT id: 13, constant_value: UInt64_1, constant_value_type: UInt8 + FUNCTION id: 14, function_name: toFloat64, function_type: ordinary, result_type: Float64 + ARGUMENTS + LIST id: 15, nodes: 1 + FUNCTION id: 16, function_name: tupleElement, function_type: ordinary, result_type: UInt64 + ARGUMENTS + LIST id: 17, nodes: 2 + FUNCTION id: 4, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64) + ARGUMENTS + LIST id: 5, nodes: 1 + COLUMN id: 6, column_name: b, result_type: Int8, source_id: 7 + CONSTANT id: 18, constant_value: UInt64_2, constant_value_type: UInt8 + JOIN TREE + TABLE id: 7, table_name: default.fuse_tbl QUERY id: 0 PROJECTION COLUMNS sum(plus(a, 1)) Nullable(Int64) @@ -11,98 +70,158 @@ QUERY id: 0 count(a) UInt64 PROJECTION LIST id: 1, nodes: 7 - FUNCTION id: 2, function_name: tupleElement, function_type: ordinary, result_type: Nullable(Int64) - ARGUMENTS - LIST id: 3, nodes: 2 - FUNCTION id: 4, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64) - ARGUMENTS - LIST id: 5, nodes: 1 - FUNCTION id: 6, function_name: plus, function_type: ordinary, result_type: Nullable(Int16) - ARGUMENTS - LIST id: 7, nodes: 2 - COLUMN id: 8, column_name: a, result_type: Nullable(Int8), source_id: 9 - CONSTANT id: 10, constant_value: UInt64_1, constant_value_type: UInt8 - CONSTANT id: 11, constant_value: UInt64_1, constant_value_type: UInt8 - FUNCTION id: 12, function_name: tupleElement, function_type: ordinary, result_type: Int64 - ARGUMENTS - LIST id: 13, nodes: 2 - FUNCTION id: 14, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64) - ARGUMENTS - LIST id: 15, nodes: 1 - COLUMN id: 16, column_name: b, result_type: Int8, source_id: 9 - CONSTANT id: 17, constant_value: UInt64_1, constant_value_type: UInt8 - FUNCTION id: 18, function_name: tupleElement, function_type: ordinary, result_type: UInt64 - ARGUMENTS - LIST id: 19, nodes: 2 - FUNCTION id: 14, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64) - ARGUMENTS - LIST id: 15, nodes: 1 - COLUMN id: 16, column_name: b, result_type: Int8, source_id: 9 - CONSTANT id: 20, constant_value: UInt64_2, constant_value_type: UInt8 - FUNCTION id: 21, function_name: divide, function_type: ordinary, result_type: Float64 - ARGUMENTS - LIST id: 22, nodes: 2 - FUNCTION id: 23, function_name: tupleElement, function_type: ordinary, result_type: Int64 - ARGUMENTS - LIST id: 24, nodes: 2 - FUNCTION id: 14, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64) - ARGUMENTS - LIST id: 15, nodes: 1 - COLUMN id: 16, column_name: b, result_type: Int8, source_id: 9 - CONSTANT id: 25, constant_value: UInt64_1, constant_value_type: UInt8 - FUNCTION id: 26, function_name: toFloat64, function_type: ordinary, result_type: Float64 - ARGUMENTS - LIST id: 27, nodes: 1 - FUNCTION id: 28, function_name: tupleElement, function_type: ordinary, result_type: UInt64 - ARGUMENTS - LIST id: 29, nodes: 2 - FUNCTION id: 14, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64) - ARGUMENTS - LIST id: 15, nodes: 1 - COLUMN id: 16, column_name: b, result_type: Int8, source_id: 9 - CONSTANT id: 30, constant_value: UInt64_2, constant_value_type: UInt8 - FUNCTION id: 31, function_name: tupleElement, function_type: ordinary, result_type: UInt64 - ARGUMENTS - LIST id: 32, nodes: 2 - FUNCTION id: 4, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64) - ARGUMENTS - LIST id: 5, nodes: 1 - FUNCTION id: 6, function_name: plus, function_type: ordinary, result_type: Nullable(Int16) - ARGUMENTS - LIST id: 7, nodes: 2 - COLUMN id: 8, column_name: a, result_type: Nullable(Int8), source_id: 9 - CONSTANT id: 10, constant_value: UInt64_1, constant_value_type: UInt8 - CONSTANT id: 33, constant_value: UInt64_2, constant_value_type: UInt8 - FUNCTION id: 34, function_name: sum, function_type: aggregate, result_type: Nullable(Int64) - ARGUMENTS - LIST id: 35, nodes: 1 - FUNCTION id: 36, function_name: plus, function_type: ordinary, result_type: Nullable(Int16) - ARGUMENTS - LIST id: 37, nodes: 2 - COLUMN id: 8, column_name: a, result_type: Nullable(Int8), source_id: 9 - CONSTANT id: 38, constant_value: UInt64_2, constant_value_type: UInt8 - FUNCTION id: 39, function_name: count, function_type: aggregate, result_type: UInt64 - ARGUMENTS - LIST id: 40, nodes: 1 - COLUMN id: 8, column_name: a, result_type: Nullable(Int8), source_id: 9 - JOIN TREE - TABLE id: 9, table_name: default.fuse_tbl -QUERY id: 0 - PROJECTION COLUMNS - sum(a) Nullable(Int64) - avg(b) Float64 - PROJECTION - LIST id: 1, nodes: 2 FUNCTION id: 2, function_name: sum, function_type: aggregate, result_type: Nullable(Int64) ARGUMENTS LIST id: 3, nodes: 1 - COLUMN id: 4, column_name: a, result_type: Nullable(Int8), source_id: 5 - FUNCTION id: 6, function_name: avg, function_type: aggregate, result_type: Float64 + FUNCTION id: 4, function_name: plus, function_type: ordinary, result_type: Nullable(Int16) + ARGUMENTS + LIST id: 5, nodes: 2 + COLUMN id: 6, column_name: a, result_type: Nullable(Int8), source_id: 7 + CONSTANT id: 8, constant_value: UInt64_1, constant_value_type: UInt8 + FUNCTION id: 9, function_name: tupleElement, function_type: ordinary, result_type: Int64 ARGUMENTS - LIST id: 7, nodes: 1 - COLUMN id: 8, column_name: b, result_type: Int8, source_id: 5 + LIST id: 10, nodes: 2 + FUNCTION id: 11, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64) + ARGUMENTS + LIST id: 12, nodes: 1 + COLUMN id: 13, column_name: b, result_type: Int8, source_id: 7 + CONSTANT id: 14, constant_value: UInt64_1, constant_value_type: UInt8 + FUNCTION id: 15, function_name: tupleElement, function_type: ordinary, result_type: UInt64 + ARGUMENTS + LIST id: 16, nodes: 2 + FUNCTION id: 11, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64) + ARGUMENTS + LIST id: 12, nodes: 1 + COLUMN id: 13, column_name: b, result_type: Int8, source_id: 7 + CONSTANT id: 17, constant_value: UInt64_2, constant_value_type: UInt8 + FUNCTION id: 18, function_name: divide, function_type: ordinary, result_type: Float64 + ARGUMENTS + LIST id: 19, nodes: 2 + FUNCTION id: 20, function_name: tupleElement, function_type: ordinary, result_type: Int64 + ARGUMENTS + LIST id: 21, nodes: 2 + FUNCTION id: 11, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64) + ARGUMENTS + LIST id: 12, nodes: 1 + COLUMN id: 13, column_name: b, result_type: Int8, source_id: 7 + CONSTANT id: 22, constant_value: UInt64_1, constant_value_type: UInt8 + FUNCTION id: 23, function_name: toFloat64, function_type: ordinary, result_type: Float64 + ARGUMENTS + LIST id: 24, nodes: 1 + FUNCTION id: 25, function_name: tupleElement, function_type: ordinary, result_type: UInt64 + ARGUMENTS + LIST id: 26, nodes: 2 + FUNCTION id: 11, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64) + ARGUMENTS + LIST id: 12, nodes: 1 + COLUMN id: 13, column_name: b, result_type: Int8, source_id: 7 + CONSTANT id: 27, constant_value: UInt64_2, constant_value_type: UInt8 + FUNCTION id: 28, function_name: count, function_type: aggregate, result_type: UInt64 + ARGUMENTS + LIST id: 29, nodes: 1 + FUNCTION id: 30, function_name: plus, function_type: ordinary, result_type: Nullable(Int16) + ARGUMENTS + LIST id: 31, nodes: 2 + COLUMN id: 6, column_name: a, result_type: Nullable(Int8), source_id: 7 + CONSTANT id: 32, constant_value: UInt64_1, constant_value_type: UInt8 + FUNCTION id: 33, function_name: sum, function_type: aggregate, result_type: Nullable(Int64) + ARGUMENTS + LIST id: 34, nodes: 1 + FUNCTION id: 35, function_name: plus, function_type: ordinary, result_type: Nullable(Int16) + ARGUMENTS + LIST id: 36, nodes: 2 + COLUMN id: 6, column_name: a, result_type: Nullable(Int8), source_id: 7 + CONSTANT id: 37, constant_value: UInt64_2, constant_value_type: UInt8 + FUNCTION id: 38, function_name: count, function_type: aggregate, result_type: UInt64 + ARGUMENTS + LIST id: 39, nodes: 1 + COLUMN id: 6, column_name: a, result_type: Nullable(Int8), source_id: 7 JOIN TREE - TABLE id: 5, table_name: default.fuse_tbl + TABLE id: 7, table_name: default.fuse_tbl +QUERY id: 0 + PROJECTION COLUMNS + multiply(avg(b), 3) Float64 + plus(plus(sum(b), 1), count(b)) Int64 + multiply(count(b), count(b)) UInt64 + PROJECTION + LIST id: 1, nodes: 3 + FUNCTION id: 2, function_name: multiply, function_type: ordinary, result_type: Float64 + ARGUMENTS + LIST id: 3, nodes: 2 + FUNCTION id: 4, function_name: divide, function_type: ordinary, result_type: Float64 + ARGUMENTS + LIST id: 5, nodes: 2 + FUNCTION id: 6, function_name: tupleElement, function_type: ordinary, result_type: Int64 + ARGUMENTS + LIST id: 7, nodes: 2 + FUNCTION id: 8, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64) + ARGUMENTS + LIST id: 9, nodes: 1 + COLUMN id: 10, column_name: b, result_type: Int8, source_id: 11 + CONSTANT id: 12, constant_value: UInt64_1, constant_value_type: UInt8 + FUNCTION id: 13, function_name: toFloat64, function_type: ordinary, result_type: Float64 + ARGUMENTS + LIST id: 14, nodes: 1 + FUNCTION id: 15, function_name: tupleElement, function_type: ordinary, result_type: UInt64 + ARGUMENTS + LIST id: 16, nodes: 2 + FUNCTION id: 8, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64) + ARGUMENTS + LIST id: 9, nodes: 1 + COLUMN id: 10, column_name: b, result_type: Int8, source_id: 11 + CONSTANT id: 17, constant_value: UInt64_2, constant_value_type: UInt8 + CONSTANT id: 18, constant_value: UInt64_3, constant_value_type: UInt8 + FUNCTION id: 19, function_name: plus, function_type: ordinary, result_type: Int64 + ARGUMENTS + LIST id: 20, nodes: 2 + FUNCTION id: 21, function_name: plus, function_type: ordinary, result_type: Int64 + ARGUMENTS + LIST id: 22, nodes: 2 + FUNCTION id: 23, function_name: tupleElement, function_type: ordinary, result_type: Int64 + ARGUMENTS + LIST id: 24, nodes: 2 + FUNCTION id: 8, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64) + ARGUMENTS + LIST id: 9, nodes: 1 + COLUMN id: 10, column_name: b, result_type: Int8, source_id: 11 + CONSTANT id: 25, constant_value: UInt64_1, constant_value_type: UInt8 + CONSTANT id: 26, constant_value: UInt64_1, constant_value_type: UInt8 + FUNCTION id: 27, function_name: tupleElement, function_type: ordinary, result_type: UInt64 + ARGUMENTS + LIST id: 28, nodes: 2 + FUNCTION id: 8, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64) + ARGUMENTS + LIST id: 9, nodes: 1 + COLUMN id: 10, column_name: b, result_type: Int8, source_id: 11 + CONSTANT id: 29, constant_value: UInt64_2, constant_value_type: UInt8 + FUNCTION id: 30, function_name: multiply, function_type: ordinary, result_type: UInt64 + ARGUMENTS + LIST id: 31, nodes: 2 + FUNCTION id: 32, function_name: tupleElement, function_type: ordinary, result_type: UInt64 + ARGUMENTS + LIST id: 33, nodes: 2 + FUNCTION id: 8, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64) + ARGUMENTS + LIST id: 9, nodes: 1 + COLUMN id: 10, column_name: b, result_type: Int8, source_id: 11 + CONSTANT id: 34, constant_value: UInt64_2, constant_value_type: UInt8 + FUNCTION id: 35, function_name: tupleElement, function_type: ordinary, result_type: UInt64 + ARGUMENTS + LIST id: 36, nodes: 2 + FUNCTION id: 8, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64) + ARGUMENTS + LIST id: 9, nodes: 1 + COLUMN id: 10, column_name: b, result_type: Int8, source_id: 11 + CONSTANT id: 37, constant_value: UInt64_2, constant_value_type: UInt8 + JOIN TREE + QUERY id: 11, is_subquery: 1 + PROJECTION COLUMNS + b Int8 + PROJECTION + LIST id: 38, nodes: 1 + COLUMN id: 39, column_name: b, result_type: Int8, source_id: 40 + JOIN TREE + TABLE id: 40, table_name: default.fuse_tbl 0 0 nan 0 0 nan 45 10 4.5 Decimal(38, 0) UInt64 Float64 diff --git a/tests/queries/0_stateless/02476_fuse_sum_count.sql b/tests/queries/0_stateless/02476_fuse_sum_count.sql index 985f1f4d2c3..fe4b196d4e5 100644 --- a/tests/queries/0_stateless/02476_fuse_sum_count.sql +++ b/tests/queries/0_stateless/02476_fuse_sum_count.sql @@ -7,11 +7,26 @@ CREATE TABLE fuse_tbl(a Nullable(Int8), b Int8) Engine = Log; INSERT INTO fuse_tbl VALUES (1, 1), (2, 2), (NULL, 3); +SELECT avg(a), sum(a) FROM (SELECT a FROM fuse_tbl); +SELECT avg(a), sum(a) FROM (SELECT a FROM fuse_tbl WHERE isNull(a)); +SELECT avg(a), sum(a) FROM (SELECT a FROM fuse_tbl WHERE isNotNull(a)); + +SELECT avg(b), sum(b) FROM (SELECT b FROM fuse_tbl); +SELECT avg(b) * 3, sum(b) + 1 + count(b), count(b) * count(b) FROM (SELECT b FROM fuse_tbl); + +-- TODO(@vdimir): uncomment after https://github.com/ClickHouse/ClickHouse/pull/42865 +-- SELECT sum(b), count(b) from (SELECT x as b FROM (SELECT sum(b) as x, count(b) FROM fuse_tbl) ); + SELECT sum(a + 1), sum(b), count(b), avg(b), count(a + 1), sum(a + 2), count(a) from fuse_tbl SETTINGS optimize_syntax_fuse_functions = 0; SELECT sum(a + 1), sum(b), count(b), avg(b), count(a + 1), sum(a + 2), count(a) from fuse_tbl; +EXPLAIN QUERY TREE run_passes = 1 SELECT sum(a), avg(a) from fuse_tbl; +EXPLAIN QUERY TREE run_passes = 1 SELECT sum(b), avg(b) from fuse_tbl; EXPLAIN QUERY TREE run_passes = 1 SELECT sum(a + 1), sum(b), count(b), avg(b), count(a + 1), sum(a + 2), count(a) from fuse_tbl; -EXPLAIN QUERY TREE run_passes = 1 SELECT sum(a), avg(b) from fuse_tbl; +EXPLAIN QUERY TREE run_passes = 1 SELECT avg(b) * 3, sum(b) + 1 + count(b), count(b) * count(b) FROM (SELECT b FROM fuse_tbl); + +-- TODO(@vdimir): uncomment after https://github.com/ClickHouse/ClickHouse/pull/42865 +-- EXPLAIN QUERY TREE run_passes = 1 SELECT sum(b), count(b) from (SELECT x as b FROM (SELECT sum(b) as x, count(b) FROM fuse_tbl) ); SELECT sum(x), count(x), avg(x) FROM (SELECT number :: Decimal32(0) AS x FROM numbers(0)) SETTINGS optimize_syntax_fuse_functions = 0; SELECT sum(x), count(x), avg(x) FROM (SELECT number :: Decimal32(0) AS x FROM numbers(0));