diff --git a/src/Analyzer/Passes/NormalizeCountVariantsPass.cpp b/src/Analyzer/Passes/NormalizeCountVariantsPass.cpp index 28e5af3f5db..cd6aa4d76f4 100644 --- a/src/Analyzer/Passes/NormalizeCountVariantsPass.cpp +++ b/src/Analyzer/Passes/NormalizeCountVariantsPass.cpp @@ -6,6 +6,7 @@ #include #include #include +#include namespace DB { @@ -16,7 +17,8 @@ namespace class NormalizeCountVariantsVisitor : public InDepthQueryTreeVisitor { public: - static void visitImpl(QueryTreeNodePtr & node) + explicit NormalizeCountVariantsVisitor(ContextPtr context_) : context(std::move(context_)) {} + void visitImpl(QueryTreeNodePtr & node) { auto * function_node = node->as(); if (!function_node || !function_node->isAggregateFunction() || (function_node->getFunctionName() != "count" && function_node->getFunctionName() != "sum")) @@ -39,13 +41,16 @@ public: } else if (function_node->getFunctionName() == "sum" && first_argument_constant_literal.getType() == Field::Types::UInt64 && - first_argument_constant_literal.get() == 1) + first_argument_constant_literal.get() == 1 && + !context->getSettingsRef().aggregate_functions_null_for_empty) { resolveAsCountAggregateFunction(*function_node); function_node->getArguments().getNodes().clear(); } } private: + ContextPtr context; + static inline void resolveAsCountAggregateFunction(FunctionNode & function_node) { auto function_result_type = function_node.getResultType(); @@ -59,9 +64,9 @@ private: } -void NormalizeCountVariantsPass::run(QueryTreeNodePtr query_tree_node, ContextPtr) +void NormalizeCountVariantsPass::run(QueryTreeNodePtr query_tree_node, ContextPtr context) { - NormalizeCountVariantsVisitor visitor; + NormalizeCountVariantsVisitor visitor(context); visitor.visit(query_tree_node); } diff --git a/src/Analyzer/Passes/SumIfToCountIfPass.cpp b/src/Analyzer/Passes/SumIfToCountIfPass.cpp index e40ba25a965..91c277d35b3 100644 --- a/src/Analyzer/Passes/SumIfToCountIfPass.cpp +++ b/src/Analyzer/Passes/SumIfToCountIfPass.cpp @@ -56,7 +56,7 @@ public: if (!isInt64OrUInt64FieldType(constant_value_literal.getType())) return; - if (constant_value_literal.get() != 1) + if (constant_value_literal.get() != 1 || context->getSettingsRef().aggregate_functions_null_for_empty) return; function_node_arguments_nodes[0] = std::move(function_node_arguments_nodes[1]); diff --git a/src/Interpreters/RewriteCountVariantsVisitor.cpp b/src/Interpreters/RewriteCountVariantsVisitor.cpp index 741dc3e8cb7..f207bc51527 100644 --- a/src/Interpreters/RewriteCountVariantsVisitor.cpp +++ b/src/Interpreters/RewriteCountVariantsVisitor.cpp @@ -6,6 +6,7 @@ #include #include #include +#include namespace DB @@ -52,7 +53,7 @@ void RewriteCountVariantsVisitor::visit(ASTFunction & func) if (first_arg_literal->value.getType() == Field::Types::UInt64) { auto constant = first_arg_literal->value.get(); - if (constant == 1) + if (constant == 1 && !context->getSettingsRef().aggregate_functions_null_for_empty) transform = true; } } diff --git a/src/Interpreters/RewriteCountVariantsVisitor.h b/src/Interpreters/RewriteCountVariantsVisitor.h index 6f731c8c463..36c026bdfd7 100644 --- a/src/Interpreters/RewriteCountVariantsVisitor.h +++ b/src/Interpreters/RewriteCountVariantsVisitor.h @@ -1,6 +1,7 @@ #pragma once #include +#include namespace DB { @@ -10,8 +11,11 @@ class ASTFunction; class RewriteCountVariantsVisitor { public: - static void visit(ASTPtr &); - static void visit(ASTFunction &); + explicit RewriteCountVariantsVisitor(ContextPtr context_) : context(context_) {} + void visit(ASTPtr &); + void visit(ASTFunction &); +private: + ContextPtr context; }; } diff --git a/src/Interpreters/TreeOptimizer.cpp b/src/Interpreters/TreeOptimizer.cpp index 182e9623c61..6461a35dae6 100644 --- a/src/Interpreters/TreeOptimizer.cpp +++ b/src/Interpreters/TreeOptimizer.cpp @@ -758,9 +758,9 @@ void TreeOptimizer::optimizeIf(ASTPtr & query, Aliases & aliases, bool if_chain_ OptimizeIfChainsVisitor().visit(query); } -void TreeOptimizer::optimizeCountConstantAndSumOne(ASTPtr & query) +void TreeOptimizer::optimizeCountConstantAndSumOne(ASTPtr & query, ContextPtr context) { - RewriteCountVariantsVisitor::visit(query); + RewriteCountVariantsVisitor(context).visit(query); } ///eliminate functions of other GROUP BY keys @@ -835,7 +835,7 @@ void TreeOptimizer::apply(ASTPtr & query, TreeRewriterResult & result, optimizeAnyFunctions(query); if (settings.optimize_normalize_count_variants) - optimizeCountConstantAndSumOne(query); + optimizeCountConstantAndSumOne(query, context); if (settings.optimize_multiif_to_if) optimizeMultiIfToIf(query); diff --git a/src/Interpreters/TreeOptimizer.h b/src/Interpreters/TreeOptimizer.h index 72a240d83b5..07ae2fbd12d 100644 --- a/src/Interpreters/TreeOptimizer.h +++ b/src/Interpreters/TreeOptimizer.h @@ -24,7 +24,7 @@ public: ContextPtr context); static void optimizeIf(ASTPtr & query, Aliases & aliases, bool if_chain_to_multiif); - static void optimizeCountConstantAndSumOne(ASTPtr & query); + static void optimizeCountConstantAndSumOne(ASTPtr & query, ContextPtr context); static void optimizeGroupByFunctionKeys(ASTSelectQuery * select_query); }; diff --git a/src/Interpreters/TreeRewriter.cpp b/src/Interpreters/TreeRewriter.cpp index e49ed73fc9a..5d3efaba996 100644 --- a/src/Interpreters/TreeRewriter.cpp +++ b/src/Interpreters/TreeRewriter.cpp @@ -149,7 +149,7 @@ struct CustomizeAggregateFunctionsSuffixData void visit(ASTFunction & func, ASTPtr &) const { const auto & instance = AggregateFunctionFactory::instance(); - if (instance.isAggregateFunctionName(func.name) && !endsWith(func.name, customized_func_suffix)) + if (instance.isAggregateFunctionName(func.name) && !endsWith(func.name, customized_func_suffix) && !endsWith(func.name, customized_func_suffix + "If")) { auto properties = instance.tryGetProperties(func.name); if (properties && !properties->returns_default_when_only_null) @@ -1298,7 +1298,7 @@ TreeRewriterResultPtr TreeRewriter::analyzeSelect( /// Perform it before analyzing JOINs, because it may change number of columns with names unique and break some logic inside JOINs if (settings.optimize_normalize_count_variants) - TreeOptimizer::optimizeCountConstantAndSumOne(query); + TreeOptimizer::optimizeCountConstantAndSumOne(query, getContext()); if (tables_with_columns.size() > 1) { diff --git a/tests/queries/0_stateless/01528_setting_aggregate_functions_null_for_empty.reference b/tests/queries/0_stateless/01528_setting_aggregate_functions_null_for_empty.reference index 9c6ae9c65ab..8b1aa83d73c 100644 --- a/tests/queries/0_stateless/01528_setting_aggregate_functions_null_for_empty.reference +++ b/tests/queries/0_stateless/01528_setting_aggregate_functions_null_for_empty.reference @@ -14,3 +14,6 @@ 45 10 10 +SELECT sumOrNullIf(1, number > 0) +FROM numbers(10) +WHERE 0 diff --git a/tests/queries/0_stateless/01528_setting_aggregate_functions_null_for_empty.sql b/tests/queries/0_stateless/01528_setting_aggregate_functions_null_for_empty.sql index e76ce667bbc..b57a492e375 100644 --- a/tests/queries/0_stateless/01528_setting_aggregate_functions_null_for_empty.sql +++ b/tests/queries/0_stateless/01528_setting_aggregate_functions_null_for_empty.sql @@ -33,4 +33,7 @@ SELECT sumOrNull(n) FROM defaults; SELECT count(n) FROM defaults; SELECT countOrNull(n) FROM defaults; + +EXPLAIN SYNTAX SELECT sumIf(1, number > 0) FROM numbers(10) WHERE 0; + DROP TABLE defaults; diff --git a/tests/queries/0_stateless/01706_optimize_normalize_count_variants.reference b/tests/queries/0_stateless/01706_optimize_normalize_count_variants.reference index 0343ad84abb..3080226da32 100644 --- a/tests/queries/0_stateless/01706_optimize_normalize_count_variants.reference +++ b/tests/queries/0_stateless/01706_optimize_normalize_count_variants.reference @@ -4,3 +4,6 @@ SELECT count(), count(), count(NULL) +SELECT sumOrNull(1) +FROM numbers(10) +WHERE 0 diff --git a/tests/queries/0_stateless/01706_optimize_normalize_count_variants.sql b/tests/queries/0_stateless/01706_optimize_normalize_count_variants.sql index d20f23feef8..9c85d6bc2fd 100644 --- a/tests/queries/0_stateless/01706_optimize_normalize_count_variants.sql +++ b/tests/queries/0_stateless/01706_optimize_normalize_count_variants.sql @@ -2,3 +2,7 @@ set optimize_normalize_count_variants = 1; explain syntax select count(), count(1), count(-1), sum(1), count(null); + +set aggregate_functions_null_for_empty = 1; + +explain syntax select sum(1) from numbers(10) where 0;