Merge pull request #43873 from canhld94/ch_canh_fix_normalize

Fix some incorrect logic in ast level optimization
This commit is contained in:
Kruglov Pavel 2022-12-08 12:38:36 +01:00 committed by GitHub
commit 0598ca92a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 37 additions and 14 deletions

View File

@ -6,6 +6,7 @@
#include <Analyzer/InDepthQueryTreeVisitor.h>
#include <Analyzer/ConstantNode.h>
#include <Analyzer/FunctionNode.h>
#include <Interpreters/Context.h>
namespace DB
{
@ -16,7 +17,8 @@ namespace
class NormalizeCountVariantsVisitor : public InDepthQueryTreeVisitor<NormalizeCountVariantsVisitor>
{
public:
static void visitImpl(QueryTreeNodePtr & node)
explicit NormalizeCountVariantsVisitor(ContextPtr context_) : context(std::move(context_)) {}
void visitImpl(QueryTreeNodePtr & node)
{
auto * function_node = node->as<FunctionNode>();
if (!function_node || !function_node->isAggregateFunction() || (function_node->getFunctionName() != "count" && function_node->getFunctionName() != "sum"))
@ -39,13 +41,16 @@ public:
}
else if (function_node->getFunctionName() == "sum" &&
first_argument_constant_literal.getType() == Field::Types::UInt64 &&
first_argument_constant_literal.get<UInt64>() == 1)
first_argument_constant_literal.get<UInt64>() == 1 &&
!context->getSettingsRef().aggregate_functions_null_for_empty)
{
resolveAsCountAggregateFunction(*function_node);
function_node->getArguments().getNodes().clear();
}
}
private:
ContextPtr context;
static inline void resolveAsCountAggregateFunction(FunctionNode & function_node)
{
auto function_result_type = function_node.getResultType();
@ -59,9 +64,9 @@ private:
}
void NormalizeCountVariantsPass::run(QueryTreeNodePtr query_tree_node, ContextPtr)
void NormalizeCountVariantsPass::run(QueryTreeNodePtr query_tree_node, ContextPtr context)
{
NormalizeCountVariantsVisitor visitor;
NormalizeCountVariantsVisitor visitor(context);
visitor.visit(query_tree_node);
}

View File

@ -56,7 +56,7 @@ public:
if (!isInt64OrUInt64FieldType(constant_value_literal.getType()))
return;
if (constant_value_literal.get<UInt64>() != 1)
if (constant_value_literal.get<UInt64>() != 1 || context->getSettingsRef().aggregate_functions_null_for_empty)
return;
function_node_arguments_nodes[0] = std::move(function_node_arguments_nodes[1]);

View File

@ -6,6 +6,7 @@
#include <Poco/String.h>
#include <Common/typeid_cast.h>
#include <Common/checkStackSize.h>
#include <Interpreters/Context.h>
namespace DB
@ -52,7 +53,7 @@ void RewriteCountVariantsVisitor::visit(ASTFunction & func)
if (first_arg_literal->value.getType() == Field::Types::UInt64)
{
auto constant = first_arg_literal->value.get<UInt64>();
if (constant == 1)
if (constant == 1 && !context->getSettingsRef().aggregate_functions_null_for_empty)
transform = true;
}
}

View File

@ -1,6 +1,7 @@
#pragma once
#include <Parsers/IAST.h>
#include <Interpreters/Context_fwd.h>
namespace DB
{
@ -10,8 +11,11 @@ class ASTFunction;
class RewriteCountVariantsVisitor
{
public:
static void visit(ASTPtr &);
static void visit(ASTFunction &);
explicit RewriteCountVariantsVisitor(ContextPtr context_) : context(context_) {}
void visit(ASTPtr &);
void visit(ASTFunction &);
private:
ContextPtr context;
};
}

View File

@ -758,9 +758,9 @@ void TreeOptimizer::optimizeIf(ASTPtr & query, Aliases & aliases, bool if_chain_
OptimizeIfChainsVisitor().visit(query);
}
void TreeOptimizer::optimizeCountConstantAndSumOne(ASTPtr & query)
void TreeOptimizer::optimizeCountConstantAndSumOne(ASTPtr & query, ContextPtr context)
{
RewriteCountVariantsVisitor::visit(query);
RewriteCountVariantsVisitor(context).visit(query);
}
///eliminate functions of other GROUP BY keys
@ -835,7 +835,7 @@ void TreeOptimizer::apply(ASTPtr & query, TreeRewriterResult & result,
optimizeAnyFunctions(query);
if (settings.optimize_normalize_count_variants)
optimizeCountConstantAndSumOne(query);
optimizeCountConstantAndSumOne(query, context);
if (settings.optimize_multiif_to_if)
optimizeMultiIfToIf(query);

View File

@ -24,7 +24,7 @@ public:
ContextPtr context);
static void optimizeIf(ASTPtr & query, Aliases & aliases, bool if_chain_to_multiif);
static void optimizeCountConstantAndSumOne(ASTPtr & query);
static void optimizeCountConstantAndSumOne(ASTPtr & query, ContextPtr context);
static void optimizeGroupByFunctionKeys(ASTSelectQuery * select_query);
};

View File

@ -149,7 +149,7 @@ struct CustomizeAggregateFunctionsSuffixData
void visit(ASTFunction & func, ASTPtr &) const
{
const auto & instance = AggregateFunctionFactory::instance();
if (instance.isAggregateFunctionName(func.name) && !endsWith(func.name, customized_func_suffix))
if (instance.isAggregateFunctionName(func.name) && !endsWith(func.name, customized_func_suffix) && !endsWith(func.name, customized_func_suffix + "If"))
{
auto properties = instance.tryGetProperties(func.name);
if (properties && !properties->returns_default_when_only_null)
@ -1298,7 +1298,7 @@ TreeRewriterResultPtr TreeRewriter::analyzeSelect(
/// Perform it before analyzing JOINs, because it may change number of columns with names unique and break some logic inside JOINs
if (settings.optimize_normalize_count_variants)
TreeOptimizer::optimizeCountConstantAndSumOne(query);
TreeOptimizer::optimizeCountConstantAndSumOne(query, getContext());
if (tables_with_columns.size() > 1)
{

View File

@ -14,3 +14,6 @@
45
10
10
SELECT sumOrNullIf(1, number > 0)
FROM numbers(10)
WHERE 0

View File

@ -33,4 +33,7 @@ SELECT sumOrNull(n) FROM defaults;
SELECT count(n) FROM defaults;
SELECT countOrNull(n) FROM defaults;
EXPLAIN SYNTAX SELECT sumIf(1, number > 0) FROM numbers(10) WHERE 0;
DROP TABLE defaults;

View File

@ -4,3 +4,6 @@ SELECT
count(),
count(),
count(NULL)
SELECT sumOrNull(1)
FROM numbers(10)
WHERE 0

View File

@ -2,3 +2,7 @@
set optimize_normalize_count_variants = 1;
explain syntax select count(), count(1), count(-1), sum(1), count(null);
set aggregate_functions_null_for_empty = 1;
explain syntax select sum(1) from numbers(10) where 0;