mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-21 15:12:02 +00:00
Rewrite sum(if())
and sumIf
to countIf
in special cases (#17041)
Co-authored-by: vdimir <vdimir@yandex-team.ru>
This commit is contained in:
parent
0424db3e67
commit
e75b116466
@ -416,6 +416,7 @@ class IColumn;
|
||||
M(Bool, use_antlr_parser, false, "Parse incoming queries using ANTLR-generated experimental parser", 0) \
|
||||
M(Bool, async_socket_for_remote, true, "Asynchronously read from socket executing remote query", 0) \
|
||||
\
|
||||
M(Bool, optimize_rewrite_sum_if_to_count_if, true, "Rewrite sumIf() and sum(if()) function countIf() function when logically equivalent", 0) \
|
||||
/** Obsolete settings that do nothing but left for compatibility reasons. Remove each one after half a year of obsolescence. */ \
|
||||
\
|
||||
M(UInt64, max_memory_usage_for_all_queries, 0, "Obsolete. Will be removed after 2020-10-20", 0) \
|
||||
|
@ -68,11 +68,11 @@ struct NeedChild
|
||||
};
|
||||
|
||||
/// Simple matcher for one node type. Use need_child function for complex traversal logic.
|
||||
template <typename Data_, NeedChild::Condition need_child = NeedChild::all, typename T = ASTPtr>
|
||||
template <typename DataImpl, NeedChild::Condition need_child = NeedChild::all, typename T = ASTPtr>
|
||||
class OneTypeMatcher
|
||||
{
|
||||
public:
|
||||
using Data = Data_;
|
||||
using Data = DataImpl;
|
||||
using TypeToVisit = typename Data::TypeToVisit;
|
||||
|
||||
static bool needChildVisit(const ASTPtr & node, const ASTPtr & child) { return need_child(node, child); }
|
||||
|
91
src/Interpreters/RewriteSumIfFunctionVisitor.cpp
Normal file
91
src/Interpreters/RewriteSumIfFunctionVisitor.cpp
Normal file
@ -0,0 +1,91 @@
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <Interpreters/RewriteSumIfFunctionVisitor.h>
|
||||
#include <Parsers/ASTFunction.h>
|
||||
#include <Parsers/ASTLiteral.h>
|
||||
#include <Common/typeid_cast.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
void RewriteSumIfFunctionMatcher::visit(ASTPtr & ast, Data & data)
{
    /// Only function nodes are rewrite candidates; any other node is left as is.
    auto * function_node = ast->as<ASTFunction>();
    if (!function_node)
        return;

    visit(*function_node, ast, data);
}
|
||||
|
||||
/// Builds an ASTFunction node named `func_name` whose argument list holds exactly `argument`.
/// The argument list is registered both as `arguments` and as a child of the new node.
static ASTPtr createNewFunctionWithOneArgument(const String & func_name, const ASTPtr & argument)
{
    auto argument_list = std::make_shared<ASTExpressionList>();
    argument_list->children.push_back(argument);

    auto function = std::make_shared<ASTFunction>();
    function->name = func_name;
    function->arguments = argument_list;
    function->children.push_back(argument_list);

    return function;
}
|
||||
|
||||
void RewriteSumIfFunctionMatcher::visit(const ASTFunction & func, ASTPtr & ast, Data &)
|
||||
{
|
||||
if (!func.arguments || func.arguments->children.empty())
|
||||
return;
|
||||
|
||||
auto lower_name = Poco::toLower(func.name);
|
||||
|
||||
if (lower_name != "sum" && lower_name != "sumif")
|
||||
return;
|
||||
|
||||
auto & func_arguments = func.arguments->children;
|
||||
|
||||
if (lower_name == "sumif")
|
||||
{
|
||||
/// sumIf(1, cond) -> countIf(cond)
|
||||
const auto * literal = func_arguments[0]->as<ASTLiteral>();
|
||||
if (func_arguments.size() == 2 && literal && literal->value.get<UInt64>() == 1)
|
||||
{
|
||||
auto new_func = createNewFunctionWithOneArgument("countIf", func_arguments[1]);
|
||||
new_func->setAlias(func.alias);
|
||||
ast = std::move(new_func);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
else
|
||||
{
|
||||
const auto * nested_func = func_arguments[0]->as<ASTFunction>();
|
||||
|
||||
if (!nested_func || Poco::toLower(nested_func->name) != "if" || nested_func->arguments->children.size() != 3)
|
||||
return;
|
||||
|
||||
auto & if_arguments = nested_func->arguments->children;
|
||||
|
||||
const auto * first_literal = if_arguments[1]->as<ASTLiteral>();
|
||||
const auto * second_literal = if_arguments[2]->as<ASTLiteral>();
|
||||
|
||||
if (first_literal && second_literal)
|
||||
{
|
||||
auto first_value = first_literal->value.get<UInt64>();
|
||||
auto second_value = second_literal->value.get<UInt64>();
|
||||
/// sum(if(cond, 1, 0)) -> countIf(cond)
|
||||
if (first_value == 1 && second_value == 0)
|
||||
{
|
||||
auto new_func = createNewFunctionWithOneArgument("countIf", if_arguments[0]);
|
||||
new_func->setAlias(func.alias);
|
||||
ast = std::move(new_func);
|
||||
return;
|
||||
}
|
||||
/// sum(if(cond, 0, 1)) -> countIf(not(cond))
|
||||
if (first_value == 0 && second_value == 1)
|
||||
{
|
||||
auto not_func = createNewFunctionWithOneArgument("not", if_arguments[0]);
|
||||
auto new_func = createNewFunctionWithOneArgument("countIf", not_func);
|
||||
new_func->setAlias(func.alias);
|
||||
ast = std::move(new_func);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
30
src/Interpreters/RewriteSumIfFunctionVisitor.h
Normal file
30
src/Interpreters/RewriteSumIfFunctionVisitor.h
Normal file
@ -0,0 +1,30 @@
|
||||
#pragma once
|
||||
|
||||
#include <unordered_set>
|
||||
|
||||
#include <Parsers/IAST.h>
|
||||
#include <Interpreters/InDepthNodeVisitor.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
class ASTFunction;
|
||||
|
||||
/// Rewrite 'sum(if())' and 'sumIf' functions to countIf.
|
||||
/// sumIf(1, cond) -> countIf(cond)
|
||||
/// sum(if(cond, 1, 0)) -> countIf(cond)
|
||||
/// sum(if(cond, 0, 1)) -> countIf(not(cond))
|
||||
class RewriteSumIfFunctionMatcher
|
||||
{
|
||||
public:
|
||||
struct Data
|
||||
{
|
||||
};
|
||||
|
||||
static void visit(ASTPtr & ast, Data &);
|
||||
static void visit(const ASTFunction &, ASTPtr & ast, Data &);
|
||||
static bool needChildVisit(const ASTPtr &, const ASTPtr &) { return true; }
|
||||
};
|
||||
|
||||
/// Visitor type that drives RewriteSumIfFunctionMatcher through InDepthNodeVisitor.
/// NOTE(review): the `false` template argument presumably selects bottom-up traversal — confirm against InDepthNodeVisitor.
using RewriteSumIfFunctionVisitor = InDepthNodeVisitor<RewriteSumIfFunctionMatcher, false>;
|
||||
}
|
@ -28,6 +28,7 @@
|
||||
#include <Functions/FunctionFactory.h>
|
||||
#include <Storages/StorageInMemoryMetadata.h>
|
||||
|
||||
#include <Interpreters/RewriteSumIfFunctionVisitor.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
@ -548,6 +549,13 @@ void optimizeAnyFunctions(ASTPtr & query)
|
||||
RewriteAnyFunctionVisitor(data).visit(query);
|
||||
}
|
||||
|
||||
/// Runs RewriteSumIfFunctionVisitor over `query`.
void optimizeSumIfFunctions(ASTPtr & query)
{
    RewriteSumIfFunctionVisitor::Data visitor_data{};
    RewriteSumIfFunctionVisitor visitor(visitor_data);
    visitor.visit(query);
}
|
||||
|
||||
|
||||
void optimizeInjectiveFunctionsInsideUniq(ASTPtr & query, const Context & context)
|
||||
{
|
||||
RemoveInjectiveFunctionsVisitor::Data data = {context};
|
||||
@ -608,6 +616,9 @@ void TreeOptimizer::apply(ASTPtr & query, Aliases & aliases, const NameSet & sou
|
||||
if (settings.optimize_move_functions_out_of_any)
|
||||
optimizeAnyFunctions(query);
|
||||
|
||||
if (settings.optimize_rewrite_sum_if_to_count_if)
|
||||
optimizeSumIfFunctions(query);
|
||||
|
||||
/// Remove injective functions inside uniq
|
||||
if (settings.optimize_injective_functions_inside_uniq)
|
||||
optimizeInjectiveFunctionsInsideUniq(query, context);
|
||||
|
@ -129,6 +129,7 @@ SRCS(
|
||||
RequiredSourceColumnsData.cpp
|
||||
RequiredSourceColumnsVisitor.cpp
|
||||
RewriteAnyFunctionVisitor.cpp
|
||||
RewriteSumIfFunctionVisitor.cpp
|
||||
RowRefs.cpp
|
||||
Set.cpp
|
||||
SetVariants.cpp
|
||||
|
4
tests/performance/rewrite_sumIf.xml
Normal file
4
tests/performance/rewrite_sumIf.xml
Normal file
@ -0,0 +1,4 @@
|
||||
<test>
|
||||
<query>SELECT sumIf(1, 0) FROM numbers(100000000)</query>
|
||||
<query>SELECT sumIf(1, 1) FROM numbers(100000000)</query>
|
||||
</test>
|
24
tests/queries/0_stateless/01646_rewrite_sum_if.reference
Normal file
24
tests/queries/0_stateless/01646_rewrite_sum_if.reference
Normal file
@ -0,0 +1,24 @@
|
||||
0
|
||||
0 0 1
|
||||
0
|
||||
50
|
||||
50 50 1
|
||||
50
|
||||
50
|
||||
50 50 50 1 0
|
||||
50
|
||||
50
|
||||
50 50 50 1 0
|
||||
50
|
||||
0
|
||||
0 0 1
|
||||
0
|
||||
50
|
||||
50 50 1
|
||||
50
|
||||
50
|
||||
50 50 50 1 0
|
||||
50
|
||||
50
|
||||
50 50 50 1 0
|
||||
50
|
35
tests/queries/0_stateless/01646_rewrite_sum_if.sql
Normal file
35
tests/queries/0_stateless/01646_rewrite_sum_if.sql
Normal file
@ -0,0 +1,35 @@
|
||||
SET optimize_rewrite_sum_if_to_count_if = 0;
|
||||
|
||||
SELECT sumIf(1, number % 2 > 2) FROM numbers(100);
|
||||
SELECT sumIf(1 as one_expr, number % 2 > 2 as cond_expr), sum(cond_expr), one_expr FROM numbers(100);
|
||||
SELECT countIf(number % 2 > 2) FROM numbers(100);
|
||||
|
||||
SELECT sumIf(1, number % 2 == 0) FROM numbers(100);
|
||||
SELECT sumIf(1 as one_expr, number % 2 == 0 as cond_expr), sum(cond_expr), one_expr FROM numbers(100);
|
||||
SELECT countIf(number % 2 == 0) FROM numbers(100);
|
||||
|
||||
SELECT sum(if(number % 2 == 0, 1, 0)) FROM numbers(100);
|
||||
SELECT sum(if(number % 2 == 0 as cond_expr, 1 as one_expr, 0 as zero_expr) as if_expr), sum(cond_expr), sum(if_expr), one_expr, zero_expr FROM numbers(100);
|
||||
SELECT countIf(number % 2 == 0) FROM numbers(100);
|
||||
|
||||
SELECT sum(if(number % 2 == 0, 0, 1)) FROM numbers(100);
|
||||
SELECT sum(if(number % 2 == 0 as cond_expr, 0 as zero_expr, 1 as one_expr) as if_expr), sum(cond_expr), sum(if_expr), one_expr, zero_expr FROM numbers(100);
|
||||
SELECT countIf(number % 2 != 0) FROM numbers(100);
|
||||
|
||||
SET optimize_rewrite_sum_if_to_count_if = 1;
|
||||
|
||||
SELECT sumIf(1, number % 2 > 2) FROM numbers(100);
|
||||
SELECT sumIf(1 as one_expr, number % 2 > 2 as cond_expr), sum(cond_expr), one_expr FROM numbers(100);
|
||||
SELECT countIf(number % 2 > 2) FROM numbers(100);
|
||||
|
||||
SELECT sumIf(1, number % 2 == 0) FROM numbers(100);
|
||||
SELECT sumIf(1 as one_expr, number % 2 == 0 as cond_expr), sum(cond_expr), one_expr FROM numbers(100);
|
||||
SELECT countIf(number % 2 == 0) FROM numbers(100);
|
||||
|
||||
SELECT sum(if(number % 2 == 0, 1, 0)) FROM numbers(100);
|
||||
SELECT sum(if(number % 2 == 0 as cond_expr, 1 as one_expr, 0 as zero_expr) as if_expr), sum(cond_expr), sum(if_expr), one_expr, zero_expr FROM numbers(100);
|
||||
SELECT countIf(number % 2 == 0) FROM numbers(100);
|
||||
|
||||
SELECT sum(if(number % 2 == 0, 0, 1)) FROM numbers(100);
|
||||
SELECT sum(if(number % 2 == 0 as cond_expr, 0 as zero_expr, 1 as one_expr) as if_expr), sum(cond_expr), sum(if_expr), one_expr, zero_expr FROM numbers(100);
|
||||
SELECT countIf(number % 2 != 0) FROM numbers(100);
|
Loading…
Reference in New Issue
Block a user