mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-12-02 12:32:04 +00:00
fix bugs
This commit is contained in:
parent
952058e69e
commit
aa3d67e2d6
@ -21,15 +21,17 @@ namespace DB
|
||||
namespace
|
||||
{
|
||||
|
||||
class RewriteAggregateFunctionWithIfVisitor : public InDepthQueryTreeVisitor<RewriteAggregateFunctionWithIfVisitor>
|
||||
class RewriteAggregateFunctionWithIfVisitor : public InDepthQueryTreeVisitorWithContext<RewriteAggregateFunctionWithIfVisitor>
|
||||
{
|
||||
public:
|
||||
explicit RewriteAggregateFunctionWithIfVisitor(ContextPtr & context_)
|
||||
: context(context_)
|
||||
{}
|
||||
using Base = InDepthQueryTreeVisitorWithContext<RewriteAggregateFunctionWithIfVisitor>;
|
||||
using Base::Base;
|
||||
|
||||
void visitImpl(QueryTreeNodePtr & node)
|
||||
{
|
||||
if (!getSettings().optimize_rewrite_aggregate_function_with_if)
|
||||
return;
|
||||
|
||||
auto * function_node = node->as<FunctionNode>();
|
||||
if (!function_node || !function_node->isAggregateFunction())
|
||||
return;
|
||||
@ -58,7 +60,9 @@ public:
|
||||
function_arguments_nodes[0] = std::move(if_arguments_nodes[1]);
|
||||
function_arguments_nodes[1] = std::move(if_arguments_nodes[0]);
|
||||
resolveAsAggregateFunctionWithIf(
|
||||
*function_node, {function_arguments_nodes[0]->getResultType(), function_arguments_nodes[1]->getResultType()});
|
||||
*function_node,
|
||||
{function_arguments_nodes[0]->getResultType(), function_arguments_nodes[1]->getResultType()},
|
||||
second_const_value.isNull());
|
||||
}
|
||||
}
|
||||
else if (first_const_node)
|
||||
@ -72,28 +76,34 @@ public:
|
||||
auto not_function = std::make_shared<FunctionNode>("not");
|
||||
auto & not_function_arguments = not_function->getArguments().getNodes();
|
||||
not_function_arguments.push_back(std::move(if_arguments_nodes[0]));
|
||||
not_function->resolveAsFunction(FunctionFactory::instance().get("not", context)->build(not_function->getArgumentColumns()));
|
||||
not_function->resolveAsFunction(
|
||||
FunctionFactory::instance().get("not", getContext())->build(not_function->getArgumentColumns()));
|
||||
|
||||
function_arguments_nodes.resize(2);
|
||||
function_arguments_nodes[0] = std::move(if_arguments_nodes[2]);
|
||||
function_arguments_nodes[1] = std::move(not_function);
|
||||
resolveAsAggregateFunctionWithIf(
|
||||
*function_node, {function_arguments_nodes[0]->getResultType(), function_arguments_nodes[1]->getResultType()});
|
||||
*function_node,
|
||||
{function_arguments_nodes[0]->getResultType(), function_arguments_nodes[1]->getResultType()},
|
||||
first_const_value.isNull());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
static inline void resolveAsAggregateFunctionWithIf(FunctionNode & function_node, const DataTypes & argument_types)
|
||||
static inline void resolveAsAggregateFunctionWithIf(FunctionNode & function_node, const DataTypes & argument_types, bool need_or_null)
|
||||
{
|
||||
AggregateFunctionProperties properties;
|
||||
auto aggregate_function = AggregateFunctionFactory::instance().get(
|
||||
function_node.getFunctionName() + "If", argument_types, function_node.getAggregateFunction()->getParameters(), properties);
|
||||
function_node.getFunctionName() + (need_or_null ? "IfOrNull" : "If"),
|
||||
argument_types,
|
||||
function_node.getAggregateFunction()->getParameters(),
|
||||
properties);
|
||||
|
||||
std::cout << "functionname:" << aggregate_function->getName() << std::endl;
|
||||
function_node.resolveAsAggregateFunction(std::move(aggregate_function));
|
||||
std::cout << "functionnode:" << function_node.dumpTree() << std::endl;
|
||||
}
|
||||
|
||||
ContextPtr & context;
|
||||
};
|
||||
|
||||
}
|
||||
|
@ -43,7 +43,8 @@ void RewriteAggregateFunctionWithIfMatcher::visit(const ASTFunction & func, ASTP
|
||||
{
|
||||
/// avg(if(cond, a, null)) -> avgIf(a, cond)
|
||||
/// sum(if(cond, a, 0)) -> sumIf(a, cond)
|
||||
auto new_func = makeASTFunction(func.name + "If", if_arguments[1], if_arguments[0]);
|
||||
auto new_func
|
||||
= makeASTFunction(func.name + (second_literal->value.isNull() ? "IfOrNull" : "If"), if_arguments[1], if_arguments[0]);
|
||||
new_func->setAlias(func.alias);
|
||||
new_func->parameters = func.parameters;
|
||||
|
||||
@ -59,7 +60,8 @@ void RewriteAggregateFunctionWithIfMatcher::visit(const ASTFunction & func, ASTP
|
||||
/// avg(if(cond, null, a) -> avgIf(a, !cond))
|
||||
/// sum(if(cond, 0, a) -> sumIf(a, !cond))
|
||||
auto not_func = makeASTFunction("not", if_arguments[0]);
|
||||
auto new_func = makeASTFunction(func.name + "If", if_arguments[2], std::move(not_func));
|
||||
auto new_func
|
||||
= makeASTFunction(func.name + (first_literal->value.isNull() ? "IfOrNull" : "If"), if_arguments[2], std::move(not_func));
|
||||
new_func->setAlias(func.alias);
|
||||
new_func->parameters = func.parameters;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user