mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-23 08:02:02 +00:00
Merge pull request #43873 from canhld94/ch_canh_fix_normalize
Fix some incorrect logic in ast level optimization
This commit is contained in:
commit
0598ca92a3
@ -6,6 +6,7 @@
|
||||
#include <Analyzer/InDepthQueryTreeVisitor.h>
|
||||
#include <Analyzer/ConstantNode.h>
|
||||
#include <Analyzer/FunctionNode.h>
|
||||
#include <Interpreters/Context.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
@ -16,7 +17,8 @@ namespace
|
||||
class NormalizeCountVariantsVisitor : public InDepthQueryTreeVisitor<NormalizeCountVariantsVisitor>
|
||||
{
|
||||
public:
|
||||
static void visitImpl(QueryTreeNodePtr & node)
|
||||
explicit NormalizeCountVariantsVisitor(ContextPtr context_) : context(std::move(context_)) {}
|
||||
void visitImpl(QueryTreeNodePtr & node)
|
||||
{
|
||||
auto * function_node = node->as<FunctionNode>();
|
||||
if (!function_node || !function_node->isAggregateFunction() || (function_node->getFunctionName() != "count" && function_node->getFunctionName() != "sum"))
|
||||
@ -39,13 +41,16 @@ public:
|
||||
}
|
||||
else if (function_node->getFunctionName() == "sum" &&
|
||||
first_argument_constant_literal.getType() == Field::Types::UInt64 &&
|
||||
first_argument_constant_literal.get<UInt64>() == 1)
|
||||
first_argument_constant_literal.get<UInt64>() == 1 &&
|
||||
!context->getSettingsRef().aggregate_functions_null_for_empty)
|
||||
{
|
||||
resolveAsCountAggregateFunction(*function_node);
|
||||
function_node->getArguments().getNodes().clear();
|
||||
}
|
||||
}
|
||||
private:
|
||||
ContextPtr context;
|
||||
|
||||
static inline void resolveAsCountAggregateFunction(FunctionNode & function_node)
|
||||
{
|
||||
auto function_result_type = function_node.getResultType();
|
||||
@ -59,9 +64,9 @@ private:
|
||||
|
||||
}
|
||||
|
||||
void NormalizeCountVariantsPass::run(QueryTreeNodePtr query_tree_node, ContextPtr)
|
||||
void NormalizeCountVariantsPass::run(QueryTreeNodePtr query_tree_node, ContextPtr context)
|
||||
{
|
||||
NormalizeCountVariantsVisitor visitor;
|
||||
NormalizeCountVariantsVisitor visitor(context);
|
||||
visitor.visit(query_tree_node);
|
||||
}
|
||||
|
||||
|
@ -56,7 +56,7 @@ public:
|
||||
if (!isInt64OrUInt64FieldType(constant_value_literal.getType()))
|
||||
return;
|
||||
|
||||
if (constant_value_literal.get<UInt64>() != 1)
|
||||
if (constant_value_literal.get<UInt64>() != 1 || context->getSettingsRef().aggregate_functions_null_for_empty)
|
||||
return;
|
||||
|
||||
function_node_arguments_nodes[0] = std::move(function_node_arguments_nodes[1]);
|
||||
|
@ -6,6 +6,7 @@
|
||||
#include <Poco/String.h>
|
||||
#include <Common/typeid_cast.h>
|
||||
#include <Common/checkStackSize.h>
|
||||
#include <Interpreters/Context.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
@ -52,7 +53,7 @@ void RewriteCountVariantsVisitor::visit(ASTFunction & func)
|
||||
if (first_arg_literal->value.getType() == Field::Types::UInt64)
|
||||
{
|
||||
auto constant = first_arg_literal->value.get<UInt64>();
|
||||
if (constant == 1)
|
||||
if (constant == 1 && !context->getSettingsRef().aggregate_functions_null_for_empty)
|
||||
transform = true;
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <Parsers/IAST.h>
|
||||
#include <Interpreters/Context_fwd.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
@ -10,8 +11,11 @@ class ASTFunction;
|
||||
class RewriteCountVariantsVisitor
|
||||
{
|
||||
public:
|
||||
static void visit(ASTPtr &);
|
||||
static void visit(ASTFunction &);
|
||||
explicit RewriteCountVariantsVisitor(ContextPtr context_) : context(context_) {}
|
||||
void visit(ASTPtr &);
|
||||
void visit(ASTFunction &);
|
||||
private:
|
||||
ContextPtr context;
|
||||
};
|
||||
|
||||
}
|
||||
|
@ -758,9 +758,9 @@ void TreeOptimizer::optimizeIf(ASTPtr & query, Aliases & aliases, bool if_chain_
|
||||
OptimizeIfChainsVisitor().visit(query);
|
||||
}
|
||||
|
||||
void TreeOptimizer::optimizeCountConstantAndSumOne(ASTPtr & query)
|
||||
void TreeOptimizer::optimizeCountConstantAndSumOne(ASTPtr & query, ContextPtr context)
|
||||
{
|
||||
RewriteCountVariantsVisitor::visit(query);
|
||||
RewriteCountVariantsVisitor(context).visit(query);
|
||||
}
|
||||
|
||||
///eliminate functions of other GROUP BY keys
|
||||
@ -835,7 +835,7 @@ void TreeOptimizer::apply(ASTPtr & query, TreeRewriterResult & result,
|
||||
optimizeAnyFunctions(query);
|
||||
|
||||
if (settings.optimize_normalize_count_variants)
|
||||
optimizeCountConstantAndSumOne(query);
|
||||
optimizeCountConstantAndSumOne(query, context);
|
||||
|
||||
if (settings.optimize_multiif_to_if)
|
||||
optimizeMultiIfToIf(query);
|
||||
|
@ -24,7 +24,7 @@ public:
|
||||
ContextPtr context);
|
||||
|
||||
static void optimizeIf(ASTPtr & query, Aliases & aliases, bool if_chain_to_multiif);
|
||||
static void optimizeCountConstantAndSumOne(ASTPtr & query);
|
||||
static void optimizeCountConstantAndSumOne(ASTPtr & query, ContextPtr context);
|
||||
static void optimizeGroupByFunctionKeys(ASTSelectQuery * select_query);
|
||||
};
|
||||
|
||||
|
@ -149,7 +149,7 @@ struct CustomizeAggregateFunctionsSuffixData
|
||||
void visit(ASTFunction & func, ASTPtr &) const
|
||||
{
|
||||
const auto & instance = AggregateFunctionFactory::instance();
|
||||
if (instance.isAggregateFunctionName(func.name) && !endsWith(func.name, customized_func_suffix))
|
||||
if (instance.isAggregateFunctionName(func.name) && !endsWith(func.name, customized_func_suffix) && !endsWith(func.name, customized_func_suffix + "If"))
|
||||
{
|
||||
auto properties = instance.tryGetProperties(func.name);
|
||||
if (properties && !properties->returns_default_when_only_null)
|
||||
@ -1298,7 +1298,7 @@ TreeRewriterResultPtr TreeRewriter::analyzeSelect(
|
||||
|
||||
/// Perform it before analyzing JOINs, because it may change number of columns with names unique and break some logic inside JOINs
|
||||
if (settings.optimize_normalize_count_variants)
|
||||
TreeOptimizer::optimizeCountConstantAndSumOne(query);
|
||||
TreeOptimizer::optimizeCountConstantAndSumOne(query, getContext());
|
||||
|
||||
if (tables_with_columns.size() > 1)
|
||||
{
|
||||
|
@ -14,3 +14,6 @@
|
||||
45
|
||||
10
|
||||
10
|
||||
SELECT sumOrNullIf(1, number > 0)
|
||||
FROM numbers(10)
|
||||
WHERE 0
|
||||
|
@ -33,4 +33,7 @@ SELECT sumOrNull(n) FROM defaults;
|
||||
SELECT count(n) FROM defaults;
|
||||
SELECT countOrNull(n) FROM defaults;
|
||||
|
||||
|
||||
EXPLAIN SYNTAX SELECT sumIf(1, number > 0) FROM numbers(10) WHERE 0;
|
||||
|
||||
DROP TABLE defaults;
|
||||
|
@ -4,3 +4,6 @@ SELECT
|
||||
count(),
|
||||
count(),
|
||||
count(NULL)
|
||||
SELECT sumOrNull(1)
|
||||
FROM numbers(10)
|
||||
WHERE 0
|
||||
|
@ -2,3 +2,7 @@
|
||||
set optimize_normalize_count_variants = 1;
|
||||
|
||||
explain syntax select count(), count(1), count(-1), sum(1), count(null);
|
||||
|
||||
set aggregate_functions_null_for_empty = 1;
|
||||
|
||||
explain syntax select sum(1) from numbers(10) where 0;
|
||||
|
Loading…
Reference in New Issue
Block a user