This commit is contained in:
taiyang-li 2023-01-29 20:48:46 +08:00
parent 952058e69e
commit aa3d67e2d6
2 changed files with 25 additions and 13 deletions

View File

@ -21,15 +21,17 @@ namespace DB
namespace
{
class RewriteAggregateFunctionWithIfVisitor : public InDepthQueryTreeVisitor<RewriteAggregateFunctionWithIfVisitor>
class RewriteAggregateFunctionWithIfVisitor : public InDepthQueryTreeVisitorWithContext<RewriteAggregateFunctionWithIfVisitor>
{
public:
explicit RewriteAggregateFunctionWithIfVisitor(ContextPtr & context_)
: context(context_)
{}
using Base = InDepthQueryTreeVisitorWithContext<RewriteAggregateFunctionWithIfVisitor>;
using Base::Base;
void visitImpl(QueryTreeNodePtr & node)
{
if (!getSettings().optimize_rewrite_aggregate_function_with_if)
return;
auto * function_node = node->as<FunctionNode>();
if (!function_node || !function_node->isAggregateFunction())
return;
@ -58,7 +60,9 @@ public:
function_arguments_nodes[0] = std::move(if_arguments_nodes[1]);
function_arguments_nodes[1] = std::move(if_arguments_nodes[0]);
resolveAsAggregateFunctionWithIf(
*function_node, {function_arguments_nodes[0]->getResultType(), function_arguments_nodes[1]->getResultType()});
*function_node,
{function_arguments_nodes[0]->getResultType(), function_arguments_nodes[1]->getResultType()},
second_const_value.isNull());
}
}
else if (first_const_node)
@ -72,28 +76,34 @@ public:
auto not_function = std::make_shared<FunctionNode>("not");
auto & not_function_arguments = not_function->getArguments().getNodes();
not_function_arguments.push_back(std::move(if_arguments_nodes[0]));
not_function->resolveAsFunction(FunctionFactory::instance().get("not", context)->build(not_function->getArgumentColumns()));
not_function->resolveAsFunction(
FunctionFactory::instance().get("not", getContext())->build(not_function->getArgumentColumns()));
function_arguments_nodes.resize(2);
function_arguments_nodes[0] = std::move(if_arguments_nodes[2]);
function_arguments_nodes[1] = std::move(not_function);
resolveAsAggregateFunctionWithIf(
*function_node, {function_arguments_nodes[0]->getResultType(), function_arguments_nodes[1]->getResultType()});
*function_node,
{function_arguments_nodes[0]->getResultType(), function_arguments_nodes[1]->getResultType()},
first_const_value.isNull());
}
}
}
private:
static inline void resolveAsAggregateFunctionWithIf(FunctionNode & function_node, const DataTypes & argument_types)
static inline void resolveAsAggregateFunctionWithIf(FunctionNode & function_node, const DataTypes & argument_types, bool need_or_null)
{
AggregateFunctionProperties properties;
auto aggregate_function = AggregateFunctionFactory::instance().get(
function_node.getFunctionName() + "If", argument_types, function_node.getAggregateFunction()->getParameters(), properties);
function_node.getFunctionName() + (need_or_null ? "IfOrNull" : "If"),
argument_types,
function_node.getAggregateFunction()->getParameters(),
properties);
std::cout << "functionname:" << aggregate_function->getName() << std::endl;
function_node.resolveAsAggregateFunction(std::move(aggregate_function));
std::cout << "functionnode:" << function_node.dumpTree() << std::endl;
}
ContextPtr & context;
};
}

View File

@ -43,7 +43,8 @@ void RewriteAggregateFunctionWithIfMatcher::visit(const ASTFunction & func, ASTP
{
/// avg(if(cond, a, null)) -> avgIf(a, cond)
/// sum(if(cond, a, 0)) -> sumIf(a, cond)
auto new_func = makeASTFunction(func.name + "If", if_arguments[1], if_arguments[0]);
auto new_func
= makeASTFunction(func.name + (second_literal->value.isNull() ? "IfOrNull" : "If"), if_arguments[1], if_arguments[0]);
new_func->setAlias(func.alias);
new_func->parameters = func.parameters;
@ -59,7 +60,8 @@ void RewriteAggregateFunctionWithIfMatcher::visit(const ASTFunction & func, ASTP
/// avg(if(cond, null, a) -> avgIf(a, !cond))
/// sum(if(cond, 0, a) -> sumIf(a, !cond))
auto not_func = makeASTFunction("not", if_arguments[0]);
auto new_func = makeASTFunction(func.name + "If", if_arguments[2], std::move(not_func));
auto new_func
= makeASTFunction(func.name + (first_literal->value.isNull() ? "IfOrNull" : "If"), if_arguments[2], std::move(not_func));
new_func->setAlias(func.alias);
new_func->parameters = func.parameters;