diff --git a/src/Analyzer/Passes/RewriteAggregateFunctionWithIfPass.cpp b/src/Analyzer/Passes/RewriteAggregateFunctionWithIfPass.cpp index c09e12d4c1d..b7fb6b8ca7d 100644 --- a/src/Analyzer/Passes/RewriteAggregateFunctionWithIfPass.cpp +++ b/src/Analyzer/Passes/RewriteAggregateFunctionWithIfPass.cpp @@ -21,15 +21,17 @@ namespace DB namespace { -class RewriteAggregateFunctionWithIfVisitor : public InDepthQueryTreeVisitor +class RewriteAggregateFunctionWithIfVisitor : public InDepthQueryTreeVisitorWithContext { public: - explicit RewriteAggregateFunctionWithIfVisitor(ContextPtr & context_) - : context(context_) - {} + using Base = InDepthQueryTreeVisitorWithContext; + using Base::Base; void visitImpl(QueryTreeNodePtr & node) { + if (!getSettings().optimize_rewrite_aggregate_function_with_if) + return; + auto * function_node = node->as(); if (!function_node || !function_node->isAggregateFunction()) return; @@ -58,7 +60,9 @@ public: function_arguments_nodes[0] = std::move(if_arguments_nodes[1]); function_arguments_nodes[1] = std::move(if_arguments_nodes[0]); resolveAsAggregateFunctionWithIf( - *function_node, {function_arguments_nodes[0]->getResultType(), function_arguments_nodes[1]->getResultType()}); + *function_node, + {function_arguments_nodes[0]->getResultType(), function_arguments_nodes[1]->getResultType()}, + second_const_value.isNull()); } } else if (first_const_node) @@ -72,28 +76,34 @@ public: auto not_function = std::make_shared("not"); auto & not_function_arguments = not_function->getArguments().getNodes(); not_function_arguments.push_back(std::move(if_arguments_nodes[0])); - not_function->resolveAsFunction(FunctionFactory::instance().get("not", context)->build(not_function->getArgumentColumns())); + not_function->resolveAsFunction( + FunctionFactory::instance().get("not", getContext())->build(not_function->getArgumentColumns())); function_arguments_nodes.resize(2); function_arguments_nodes[0] = std::move(if_arguments_nodes[2]); function_arguments_nodes[1] = std::move(not_function); resolveAsAggregateFunctionWithIf( - *function_node, {function_arguments_nodes[0]->getResultType(), function_arguments_nodes[1]->getResultType()}); + *function_node, + {function_arguments_nodes[0]->getResultType(), function_arguments_nodes[1]->getResultType()}, + first_const_value.isNull()); } } } private: - static inline void resolveAsAggregateFunctionWithIf(FunctionNode & function_node, const DataTypes & argument_types) + static inline void resolveAsAggregateFunctionWithIf(FunctionNode & function_node, const DataTypes & argument_types, bool need_or_null) { AggregateFunctionProperties properties; auto aggregate_function = AggregateFunctionFactory::instance().get( - function_node.getFunctionName() + "If", argument_types, function_node.getAggregateFunction()->getParameters(), properties); + function_node.getFunctionName() + (need_or_null ? "IfOrNull" : "If"), + argument_types, + function_node.getAggregateFunction()->getParameters(), + properties); + std::cout << "functionname:" << aggregate_function->getName() << std::endl; function_node.resolveAsAggregateFunction(std::move(aggregate_function)); + std::cout << "functionnode:" << function_node.dumpTree() << std::endl; } - - ContextPtr & context; }; } diff --git a/src/Interpreters/RewriteAggregateFunctionWithIfVisitor.cpp b/src/Interpreters/RewriteAggregateFunctionWithIfVisitor.cpp index cb799f004ab..52976d8c31e 100644 --- a/src/Interpreters/RewriteAggregateFunctionWithIfVisitor.cpp +++ b/src/Interpreters/RewriteAggregateFunctionWithIfVisitor.cpp @@ -43,7 +43,8 @@ void RewriteAggregateFunctionWithIfMatcher::visit(const ASTFunction & func, ASTP { /// avg(if(cond, a, null)) -> avgIf(a, cond) /// sum(if(cond, a, 0)) -> sumIf(a, cond) - auto new_func = makeASTFunction(func.name + "If", if_arguments[1], if_arguments[0]); + auto new_func + = makeASTFunction(func.name + (second_literal->value.isNull() ? "IfOrNull" : "If"), if_arguments[1], if_arguments[0]); new_func->setAlias(func.alias); new_func->parameters = func.parameters; @@ -59,7 +60,8 @@ void RewriteAggregateFunctionWithIfMatcher::visit(const ASTFunction & func, ASTP /// avg(if(cond, null, a) -> avgIf(a, !cond)) /// sum(if(cond, 0, a) -> sumIf(a, !cond)) auto not_func = makeASTFunction("not", if_arguments[0]); - auto new_func = makeASTFunction(func.name + "If", if_arguments[2], std::move(not_func)); + auto new_func + = makeASTFunction(func.name + (first_literal->value.isNull() ? "IfOrNull" : "If"), if_arguments[2], std::move(not_func)); new_func->setAlias(func.alias); new_func->parameters = func.parameters;