mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-21 23:21:59 +00:00
Rewrite sum(if())
and sumIf
to countIf
in special cases (#17041)
Co-authored-by: vdimir <vdimir@yandex-team.ru>
This commit is contained in:
parent
0424db3e67
commit
e75b116466
@ -416,6 +416,7 @@ class IColumn;
|
|||||||
M(Bool, use_antlr_parser, false, "Parse incoming queries using ANTLR-generated experimental parser", 0) \
|
M(Bool, use_antlr_parser, false, "Parse incoming queries using ANTLR-generated experimental parser", 0) \
|
||||||
M(Bool, async_socket_for_remote, true, "Asynchronously read from socket executing remote query", 0) \
|
M(Bool, async_socket_for_remote, true, "Asynchronously read from socket executing remote query", 0) \
|
||||||
\
|
\
|
||||||
|
M(Bool, optimize_rewrite_sum_if_to_count_if, true, "Rewrite sumIf() and sum(if()) function countIf() function when logically equivalent", 0) \
|
||||||
/** Obsolete settings that do nothing but left for compatibility reasons. Remove each one after half a year of obsolescence. */ \
|
/** Obsolete settings that do nothing but left for compatibility reasons. Remove each one after half a year of obsolescence. */ \
|
||||||
\
|
\
|
||||||
M(UInt64, max_memory_usage_for_all_queries, 0, "Obsolete. Will be removed after 2020-10-20", 0) \
|
M(UInt64, max_memory_usage_for_all_queries, 0, "Obsolete. Will be removed after 2020-10-20", 0) \
|
||||||
|
@ -68,11 +68,11 @@ struct NeedChild
|
|||||||
};
|
};
|
||||||
|
|
||||||
/// Simple matcher for one node type. Use need_child function for complex traversal logic.
|
/// Simple matcher for one node type. Use need_child function for complex traversal logic.
|
||||||
template <typename Data_, NeedChild::Condition need_child = NeedChild::all, typename T = ASTPtr>
|
template <typename DataImpl, NeedChild::Condition need_child = NeedChild::all, typename T = ASTPtr>
|
||||||
class OneTypeMatcher
|
class OneTypeMatcher
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
using Data = Data_;
|
using Data = DataImpl;
|
||||||
using TypeToVisit = typename Data::TypeToVisit;
|
using TypeToVisit = typename Data::TypeToVisit;
|
||||||
|
|
||||||
static bool needChildVisit(const ASTPtr & node, const ASTPtr & child) { return need_child(node, child); }
|
static bool needChildVisit(const ASTPtr & node, const ASTPtr & child) { return need_child(node, child); }
|
||||||
|
91
src/Interpreters/RewriteSumIfFunctionVisitor.cpp
Normal file
91
src/Interpreters/RewriteSumIfFunctionVisitor.cpp
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||||
|
#include <Interpreters/RewriteSumIfFunctionVisitor.h>
|
||||||
|
#include <Parsers/ASTFunction.h>
|
||||||
|
#include <Parsers/ASTLiteral.h>
|
||||||
|
#include <Common/typeid_cast.h>
|
||||||
|
|
||||||
|
namespace DB
|
||||||
|
{
|
||||||
|
|
||||||
|
/// Generic entry point of the matcher: nodes that are not functions are ignored,
/// function nodes are forwarded to the ASTFunction overload that performs the rewrite.
void RewriteSumIfFunctionMatcher::visit(ASTPtr & ast, Data & data)
{
    auto * function_node = ast->as<ASTFunction>();
    if (!function_node)
        return;

    visit(*function_node, ast, data);
}
|
||||||
|
|
||||||
|
/// Builds a function node called `func_name` whose argument list contains exactly `argument`.
static ASTPtr createNewFunctionWithOneArgument(const String & func_name, const ASTPtr & argument)
{
    auto argument_list = std::make_shared<ASTExpressionList>();
    argument_list->children.push_back(argument);

    auto function_node = std::make_shared<ASTFunction>();
    function_node->name = func_name;

    /// The same list is stored both in `arguments` and among the node's children.
    function_node->arguments = argument_list;
    function_node->children.push_back(argument_list);

    return function_node;
}
|
||||||
|
|
||||||
|
void RewriteSumIfFunctionMatcher::visit(const ASTFunction & func, ASTPtr & ast, Data &)
|
||||||
|
{
|
||||||
|
if (!func.arguments || func.arguments->children.empty())
|
||||||
|
return;
|
||||||
|
|
||||||
|
auto lower_name = Poco::toLower(func.name);
|
||||||
|
|
||||||
|
if (lower_name != "sum" && lower_name != "sumif")
|
||||||
|
return;
|
||||||
|
|
||||||
|
auto & func_arguments = func.arguments->children;
|
||||||
|
|
||||||
|
if (lower_name == "sumif")
|
||||||
|
{
|
||||||
|
/// sumIf(1, cond) -> countIf(cond)
|
||||||
|
const auto * literal = func_arguments[0]->as<ASTLiteral>();
|
||||||
|
if (func_arguments.size() == 2 && literal && literal->value.get<UInt64>() == 1)
|
||||||
|
{
|
||||||
|
auto new_func = createNewFunctionWithOneArgument("countIf", func_arguments[1]);
|
||||||
|
new_func->setAlias(func.alias);
|
||||||
|
ast = std::move(new_func);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
else
|
||||||
|
{
|
||||||
|
const auto * nested_func = func_arguments[0]->as<ASTFunction>();
|
||||||
|
|
||||||
|
if (!nested_func || Poco::toLower(nested_func->name) != "if" || nested_func->arguments->children.size() != 3)
|
||||||
|
return;
|
||||||
|
|
||||||
|
auto & if_arguments = nested_func->arguments->children;
|
||||||
|
|
||||||
|
const auto * first_literal = if_arguments[1]->as<ASTLiteral>();
|
||||||
|
const auto * second_literal = if_arguments[2]->as<ASTLiteral>();
|
||||||
|
|
||||||
|
if (first_literal && second_literal)
|
||||||
|
{
|
||||||
|
auto first_value = first_literal->value.get<UInt64>();
|
||||||
|
auto second_value = second_literal->value.get<UInt64>();
|
||||||
|
/// sum(if(cond, 1, 0)) -> countIf(cond)
|
||||||
|
if (first_value == 1 && second_value == 0)
|
||||||
|
{
|
||||||
|
auto new_func = createNewFunctionWithOneArgument("countIf", if_arguments[0]);
|
||||||
|
new_func->setAlias(func.alias);
|
||||||
|
ast = std::move(new_func);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
/// sum(if(cond, 0, 1)) -> countIf(not(cond))
|
||||||
|
if (first_value == 0 && second_value == 1)
|
||||||
|
{
|
||||||
|
auto not_func = createNewFunctionWithOneArgument("not", if_arguments[0]);
|
||||||
|
auto new_func = createNewFunctionWithOneArgument("countIf", not_func);
|
||||||
|
new_func->setAlias(func.alias);
|
||||||
|
ast = std::move(new_func);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
30
src/Interpreters/RewriteSumIfFunctionVisitor.h
Normal file
30
src/Interpreters/RewriteSumIfFunctionVisitor.h
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <unordered_set>
|
||||||
|
|
||||||
|
#include <Parsers/IAST.h>
|
||||||
|
#include <Interpreters/InDepthNodeVisitor.h>
|
||||||
|
|
||||||
|
namespace DB
|
||||||
|
{
|
||||||
|
|
||||||
|
class ASTFunction;
|
||||||
|
|
||||||
|
/// Rewrite 'sum(if())' and 'sumIf' functions to 'countIf'.
|
||||||
|
/// sumIf(1, cond) -> countIf(cond)
|
||||||
|
/// sum(if(cond, 1, 0)) -> countIf(cond)
|
||||||
|
/// sum(if(cond, 0, 1)) -> countIf(not(cond))
|
||||||
|
class RewriteSumIfFunctionMatcher
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
struct Data
|
||||||
|
{
|
||||||
|
};
|
||||||
|
|
||||||
|
static void visit(ASTPtr & ast, Data &);
|
||||||
|
static void visit(const ASTFunction &, ASTPtr & ast, Data &);
|
||||||
|
static bool needChildVisit(const ASTPtr &, const ASTPtr &) { return true; }
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Tree visitor that applies RewriteSumIfFunctionMatcher to the nodes of an AST.
/// NOTE(review): the `false` argument is presumably InDepthNodeVisitor's traversal-order flag — confirm in InDepthNodeVisitor.h.
using RewriteSumIfFunctionVisitor = InDepthNodeVisitor<RewriteSumIfFunctionMatcher, false>;
|
||||||
|
}
|
@ -28,6 +28,7 @@
|
|||||||
#include <Functions/FunctionFactory.h>
|
#include <Functions/FunctionFactory.h>
|
||||||
#include <Storages/StorageInMemoryMetadata.h>
|
#include <Storages/StorageInMemoryMetadata.h>
|
||||||
|
|
||||||
|
#include <Interpreters/RewriteSumIfFunctionVisitor.h>
|
||||||
|
|
||||||
namespace DB
|
namespace DB
|
||||||
{
|
{
|
||||||
@ -548,6 +549,13 @@ void optimizeAnyFunctions(ASTPtr & query)
|
|||||||
RewriteAnyFunctionVisitor(data).visit(query);
|
RewriteAnyFunctionVisitor(data).visit(query);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Runs RewriteSumIfFunctionVisitor over the whole query tree,
/// rewriting eligible sumIf() / sum(if()) calls in place.
void optimizeSumIfFunctions(ASTPtr & query)
{
    RewriteSumIfFunctionVisitor::Data visitor_data{};
    RewriteSumIfFunctionVisitor visitor(visitor_data);
    visitor.visit(query);
}
|
||||||
|
|
||||||
|
|
||||||
void optimizeInjectiveFunctionsInsideUniq(ASTPtr & query, const Context & context)
|
void optimizeInjectiveFunctionsInsideUniq(ASTPtr & query, const Context & context)
|
||||||
{
|
{
|
||||||
RemoveInjectiveFunctionsVisitor::Data data = {context};
|
RemoveInjectiveFunctionsVisitor::Data data = {context};
|
||||||
@ -608,6 +616,9 @@ void TreeOptimizer::apply(ASTPtr & query, Aliases & aliases, const NameSet & sou
|
|||||||
if (settings.optimize_move_functions_out_of_any)
|
if (settings.optimize_move_functions_out_of_any)
|
||||||
optimizeAnyFunctions(query);
|
optimizeAnyFunctions(query);
|
||||||
|
|
||||||
|
if (settings.optimize_rewrite_sum_if_to_count_if)
|
||||||
|
optimizeSumIfFunctions(query);
|
||||||
|
|
||||||
/// Remove injective functions inside uniq
|
/// Remove injective functions inside uniq
|
||||||
if (settings.optimize_injective_functions_inside_uniq)
|
if (settings.optimize_injective_functions_inside_uniq)
|
||||||
optimizeInjectiveFunctionsInsideUniq(query, context);
|
optimizeInjectiveFunctionsInsideUniq(query, context);
|
||||||
|
@ -129,6 +129,7 @@ SRCS(
|
|||||||
RequiredSourceColumnsData.cpp
|
RequiredSourceColumnsData.cpp
|
||||||
RequiredSourceColumnsVisitor.cpp
|
RequiredSourceColumnsVisitor.cpp
|
||||||
RewriteAnyFunctionVisitor.cpp
|
RewriteAnyFunctionVisitor.cpp
|
||||||
|
RewriteSumIfFunctionVisitor.cpp
|
||||||
RowRefs.cpp
|
RowRefs.cpp
|
||||||
Set.cpp
|
Set.cpp
|
||||||
SetVariants.cpp
|
SetVariants.cpp
|
||||||
|
4
tests/performance/rewrite_sumIf.xml
Normal file
4
tests/performance/rewrite_sumIf.xml
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
<test>
|
||||||
|
<query>SELECT sumIf(1, 0) FROM numbers(100000000)</query>
|
||||||
|
<query>SELECT sumIf(1, 1) FROM numbers(100000000)</query>
|
||||||
|
</test>
|
24
tests/queries/0_stateless/01646_rewrite_sum_if.reference
Normal file
24
tests/queries/0_stateless/01646_rewrite_sum_if.reference
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
0
|
||||||
|
0 0 1
|
||||||
|
0
|
||||||
|
50
|
||||||
|
50 50 1
|
||||||
|
50
|
||||||
|
50
|
||||||
|
50 50 50 1 0
|
||||||
|
50
|
||||||
|
50
|
||||||
|
50 50 50 1 0
|
||||||
|
50
|
||||||
|
0
|
||||||
|
0 0 1
|
||||||
|
0
|
||||||
|
50
|
||||||
|
50 50 1
|
||||||
|
50
|
||||||
|
50
|
||||||
|
50 50 50 1 0
|
||||||
|
50
|
||||||
|
50
|
||||||
|
50 50 50 1 0
|
||||||
|
50
|
35
tests/queries/0_stateless/01646_rewrite_sum_if.sql
Normal file
35
tests/queries/0_stateless/01646_rewrite_sum_if.sql
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
SET optimize_rewrite_sum_if_to_count_if = 0;
|
||||||
|
|
||||||
|
SELECT sumIf(1, number % 2 > 2) FROM numbers(100);
|
||||||
|
SELECT sumIf(1 as one_expr, number % 2 > 2 as cond_expr), sum(cond_expr), one_expr FROM numbers(100);
|
||||||
|
SELECT countIf(number % 2 > 2) FROM numbers(100);
|
||||||
|
|
||||||
|
SELECT sumIf(1, number % 2 == 0) FROM numbers(100);
|
||||||
|
SELECT sumIf(1 as one_expr, number % 2 == 0 as cond_expr), sum(cond_expr), one_expr FROM numbers(100);
|
||||||
|
SELECT countIf(number % 2 == 0) FROM numbers(100);
|
||||||
|
|
||||||
|
SELECT sum(if(number % 2 == 0, 1, 0)) FROM numbers(100);
|
||||||
|
SELECT sum(if(number % 2 == 0 as cond_expr, 1 as one_expr, 0 as zero_expr) as if_expr), sum(cond_expr), sum(if_expr), one_expr, zero_expr FROM numbers(100);
|
||||||
|
SELECT countIf(number % 2 == 0) FROM numbers(100);
|
||||||
|
|
||||||
|
SELECT sum(if(number % 2 == 0, 0, 1)) FROM numbers(100);
|
||||||
|
SELECT sum(if(number % 2 == 0 as cond_expr, 0 as zero_expr, 1 as one_expr) as if_expr), sum(cond_expr), sum(if_expr), one_expr, zero_expr FROM numbers(100);
|
||||||
|
SELECT countIf(number % 2 != 0) FROM numbers(100);
|
||||||
|
|
||||||
|
SET optimize_rewrite_sum_if_to_count_if = 1;
|
||||||
|
|
||||||
|
SELECT sumIf(1, number % 2 > 2) FROM numbers(100);
|
||||||
|
SELECT sumIf(1 as one_expr, number % 2 > 2 as cond_expr), sum(cond_expr), one_expr FROM numbers(100);
|
||||||
|
SELECT countIf(number % 2 > 2) FROM numbers(100);
|
||||||
|
|
||||||
|
SELECT sumIf(1, number % 2 == 0) FROM numbers(100);
|
||||||
|
SELECT sumIf(1 as one_expr, number % 2 == 0 as cond_expr), sum(cond_expr), one_expr FROM numbers(100);
|
||||||
|
SELECT countIf(number % 2 == 0) FROM numbers(100);
|
||||||
|
|
||||||
|
SELECT sum(if(number % 2 == 0, 1, 0)) FROM numbers(100);
|
||||||
|
SELECT sum(if(number % 2 == 0 as cond_expr, 1 as one_expr, 0 as zero_expr) as if_expr), sum(cond_expr), sum(if_expr), one_expr, zero_expr FROM numbers(100);
|
||||||
|
SELECT countIf(number % 2 == 0) FROM numbers(100);
|
||||||
|
|
||||||
|
SELECT sum(if(number % 2 == 0, 0, 1)) FROM numbers(100);
|
||||||
|
SELECT sum(if(number % 2 == 0 as cond_expr, 0 as zero_expr, 1 as one_expr) as if_expr), sum(cond_expr), sum(if_expr), one_expr, zero_expr FROM numbers(100);
|
||||||
|
SELECT countIf(number % 2 != 0) FROM numbers(100);
|
Loading…
Reference in New Issue
Block a user