Rewrite sum(if()) and sumIf to countIf in special cases (#17041)

Co-authored-by: vdimir <vdimir@yandex-team.ru>
This commit is contained in:
flynn 2021-01-21 17:01:35 +08:00 committed by GitHub
parent 0424db3e67
commit e75b116466
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 199 additions and 2 deletions

View File

@ -416,6 +416,7 @@ class IColumn;
M(Bool, use_antlr_parser, false, "Parse incoming queries using ANTLR-generated experimental parser", 0) \
M(Bool, async_socket_for_remote, true, "Asynchronously read from socket executing remote query", 0) \
\
M(Bool, optimize_rewrite_sum_if_to_count_if, true, "Rewrite sumIf() and sum(if()) function countIf() function when logically equivalent", 0) \
/** Obsolete settings that do nothing but left for compatibility reasons. Remove each one after half a year of obsolescence. */ \
\
M(UInt64, max_memory_usage_for_all_queries, 0, "Obsolete. Will be removed after 2020-10-20", 0) \

View File

@ -68,11 +68,11 @@ struct NeedChild
};
/// Simple matcher for one node type. Use need_child function for complex traversal logic.
template <typename Data_, NeedChild::Condition need_child = NeedChild::all, typename T = ASTPtr>
template <typename DataImpl, NeedChild::Condition need_child = NeedChild::all, typename T = ASTPtr>
class OneTypeMatcher
{
public:
using Data = Data_;
using Data = DataImpl;
using TypeToVisit = typename Data::TypeToVisit;
static bool needChildVisit(const ASTPtr & node, const ASTPtr & child) { return need_child(node, child); }

View File

@ -0,0 +1,91 @@
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <Interpreters/RewriteSumIfFunctionVisitor.h>
#include <Parsers/ASTFunction.h>
#include <Parsers/ASTLiteral.h>
#include <Common/typeid_cast.h>
namespace DB
{
void RewriteSumIfFunctionMatcher::visit(ASTPtr & ast, Data & data)
{
if (auto * func = ast->as<ASTFunction>())
visit(*func, ast, data);
}
static ASTPtr createNewFunctionWithOneArgument(const String & func_name, const ASTPtr & argument)
{
auto new_func = std::make_shared<ASTFunction>();
new_func->name = func_name;
auto new_arguments = std::make_shared<ASTExpressionList>();
new_arguments->children.push_back(argument);
new_func->arguments = new_arguments;
new_func->children.push_back(new_arguments);
return new_func;
}
void RewriteSumIfFunctionMatcher::visit(const ASTFunction & func, ASTPtr & ast, Data &)
{
if (!func.arguments || func.arguments->children.empty())
return;
auto lower_name = Poco::toLower(func.name);
if (lower_name != "sum" && lower_name != "sumif")
return;
auto & func_arguments = func.arguments->children;
if (lower_name == "sumif")
{
/// sumIf(1, cond) -> countIf(cond)
const auto * literal = func_arguments[0]->as<ASTLiteral>();
if (func_arguments.size() == 2 && literal && literal->value.get<UInt64>() == 1)
{
auto new_func = createNewFunctionWithOneArgument("countIf", func_arguments[1]);
new_func->setAlias(func.alias);
ast = std::move(new_func);
return;
}
}
else
{
const auto * nested_func = func_arguments[0]->as<ASTFunction>();
if (!nested_func || Poco::toLower(nested_func->name) != "if" || nested_func->arguments->children.size() != 3)
return;
auto & if_arguments = nested_func->arguments->children;
const auto * first_literal = if_arguments[1]->as<ASTLiteral>();
const auto * second_literal = if_arguments[2]->as<ASTLiteral>();
if (first_literal && second_literal)
{
auto first_value = first_literal->value.get<UInt64>();
auto second_value = second_literal->value.get<UInt64>();
/// sum(if(cond, 1, 0)) -> countIf(cond)
if (first_value == 1 && second_value == 0)
{
auto new_func = createNewFunctionWithOneArgument("countIf", if_arguments[0]);
new_func->setAlias(func.alias);
ast = std::move(new_func);
return;
}
/// sum(if(cond, 0, 1)) -> countIf(not(cond))
if (first_value == 0 && second_value == 1)
{
auto not_func = createNewFunctionWithOneArgument("not", if_arguments[0]);
auto new_func = createNewFunctionWithOneArgument("countIf", not_func);
new_func->setAlias(func.alias);
ast = std::move(new_func);
return;
}
}
}
}
}

View File

@ -0,0 +1,30 @@
#pragma once
#include <unordered_set>
#include <Parsers/IAST.h>
#include <Interpreters/InDepthNodeVisitor.h>
namespace DB
{
class ASTFunction;
/// Rewrite 'sum(if())' and 'sumIf' functions to counIf.
/// sumIf(1, cond) -> countIf(1, cond)
/// sum(if(cond, 1, 0)) -> countIf(cond)
/// sum(if(cond, 0, 1)) -> countIf(not(cond))
class RewriteSumIfFunctionMatcher
{
public:
struct Data
{
};
static void visit(ASTPtr & ast, Data &);
static void visit(const ASTFunction &, ASTPtr & ast, Data &);
static bool needChildVisit(const ASTPtr &, const ASTPtr &) { return true; }
};
using RewriteSumIfFunctionVisitor = InDepthNodeVisitor<RewriteSumIfFunctionMatcher, false>;
}

View File

@ -28,6 +28,7 @@
#include <Functions/FunctionFactory.h>
#include <Storages/StorageInMemoryMetadata.h>
#include <Interpreters/RewriteSumIfFunctionVisitor.h>
namespace DB
{
@ -548,6 +549,13 @@ void optimizeAnyFunctions(ASTPtr & query)
RewriteAnyFunctionVisitor(data).visit(query);
}
void optimizeSumIfFunctions(ASTPtr & query)
{
RewriteSumIfFunctionVisitor::Data data = {};
RewriteSumIfFunctionVisitor(data).visit(query);
}
void optimizeInjectiveFunctionsInsideUniq(ASTPtr & query, const Context & context)
{
RemoveInjectiveFunctionsVisitor::Data data = {context};
@ -608,6 +616,9 @@ void TreeOptimizer::apply(ASTPtr & query, Aliases & aliases, const NameSet & sou
if (settings.optimize_move_functions_out_of_any)
optimizeAnyFunctions(query);
if (settings.optimize_rewrite_sum_if_to_count_if)
optimizeSumIfFunctions(query);
/// Remove injective functions inside uniq
if (settings.optimize_injective_functions_inside_uniq)
optimizeInjectiveFunctionsInsideUniq(query, context);

View File

@ -129,6 +129,7 @@ SRCS(
RequiredSourceColumnsData.cpp
RequiredSourceColumnsVisitor.cpp
RewriteAnyFunctionVisitor.cpp
RewriteSumIfFunctionVisitor.cpp
RowRefs.cpp
Set.cpp
SetVariants.cpp

View File

@ -0,0 +1,4 @@
<test>
<query>SELECT sumIf(1, 0) FROM numbers(100000000)</query>
<query>SELECT sumIf(1, 1) FROM numbers(100000000)</query>
</test>

View File

@ -0,0 +1,24 @@
0
0 0 1
0
50
50 50 1
50
50
50 50 50 1 0
50
50
50 50 50 1 0
50
0
0 0 1
0
50
50 50 1
50
50
50 50 50 1 0
50
50
50 50 50 1 0
50

View File

@ -0,0 +1,35 @@
SET optimize_rewrite_sum_if_to_count_if = 0;
SELECT sumIf(1, number % 2 > 2) FROM numbers(100);
SELECT sumIf(1 as one_expr, number % 2 > 2 as cond_expr), sum(cond_expr), one_expr FROM numbers(100);
SELECT countIf(number % 2 > 2) FROM numbers(100);
SELECT sumIf(1, number % 2 == 0) FROM numbers(100);
SELECT sumIf(1 as one_expr, number % 2 == 0 as cond_expr), sum(cond_expr), one_expr FROM numbers(100);
SELECT countIf(number % 2 == 0) FROM numbers(100);
SELECT sum(if(number % 2 == 0, 1, 0)) FROM numbers(100);
SELECT sum(if(number % 2 == 0 as cond_expr, 1 as one_expr, 0 as zero_expr) as if_expr), sum(cond_expr), sum(if_expr), one_expr, zero_expr FROM numbers(100);
SELECT countIf(number % 2 == 0) FROM numbers(100);
SELECT sum(if(number % 2 == 0, 0, 1)) FROM numbers(100);
SELECT sum(if(number % 2 == 0 as cond_expr, 0 as zero_expr, 1 as one_expr) as if_expr), sum(cond_expr), sum(if_expr), one_expr, zero_expr FROM numbers(100);
SELECT countIf(number % 2 != 0) FROM numbers(100);
SET optimize_rewrite_sum_if_to_count_if = 1;
SELECT sumIf(1, number % 2 > 2) FROM numbers(100);
SELECT sumIf(1 as one_expr, number % 2 > 2 as cond_expr), sum(cond_expr), one_expr FROM numbers(100);
SELECT countIf(number % 2 > 2) FROM numbers(100);
SELECT sumIf(1, number % 2 == 0) FROM numbers(100);
SELECT sumIf(1 as one_expr, number % 2 == 0 as cond_expr), sum(cond_expr), one_expr FROM numbers(100);
SELECT countIf(number % 2 == 0) FROM numbers(100);
SELECT sum(if(number % 2 == 0, 1, 0)) FROM numbers(100);
SELECT sum(if(number % 2 == 0 as cond_expr, 1 as one_expr, 0 as zero_expr) as if_expr), sum(cond_expr), sum(if_expr), one_expr, zero_expr FROM numbers(100);
SELECT countIf(number % 2 == 0) FROM numbers(100);
SELECT sum(if(number % 2 == 0, 0, 1)) FROM numbers(100);
SELECT sum(if(number % 2 == 0 as cond_expr, 0 as zero_expr, 1 as one_expr) as if_expr), sum(cond_expr), sum(if_expr), one_expr, zero_expr FROM numbers(100);
SELECT countIf(number % 2 != 0) FROM numbers(100);