remove ast optimizer

This commit is contained in:
taiyang-li 2023-01-30 15:20:34 +08:00
parent a728dd71ac
commit a9bc770505
5 changed files with 212 additions and 179 deletions

View File

@ -1,74 +0,0 @@
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <Interpreters/RewriteAggregateFunctionWithIfVisitor.h>
#include <Parsers/ASTFunction.h>
#include <Parsers/ASTLiteral.h>
#include <Common/typeid_cast.h>
namespace DB
{
void RewriteAggregateFunctionWithIfMatcher::visit(ASTPtr & ast, Data & data)
{
if (auto * func = ast->as<ASTFunction>())
{
if (func->is_window_function)
return;
visit(*func, ast, data);
}
}
void RewriteAggregateFunctionWithIfMatcher::visit(const ASTFunction & func, ASTPtr & ast, Data &)
{
const auto & factory = AggregateFunctionFactory::instance();
if (!factory.isAggregateFunctionName(func.name))
return;
if (!func.arguments || func.arguments->children.size() != 1)
return;
auto * if_func = func.arguments->children.back()->as<ASTFunction>();
if (!if_func || Poco::toLower(if_func->name) != "if")
return;
auto lower_name = Poco::toLower(func.name);
const auto & if_arguments = if_func->arguments->children;
const auto * first_literal = if_arguments[1]->as<ASTLiteral>();
const auto * second_literal = if_arguments[2]->as<ASTLiteral>();
if (second_literal)
{
if (second_literal->value.isNull()
|| (lower_name == "sum" && isInt64OrUInt64FieldType(second_literal->value.getType())
&& second_literal->value.get<UInt64>() == 0))
{
/// avg(if(cond, a, null)) -> avgIf(a, cond)
/// sum(if(cond, a, 0)) -> sumIf(a, cond)
auto new_func
= makeASTFunction(func.name + (second_literal->value.isNull() ? "IfOrNull" : "If"), if_arguments[1], if_arguments[0]);
new_func->setAlias(func.alias);
new_func->parameters = func.parameters;
ast = std::move(new_func);
return;
}
}
else if (first_literal)
{
if (first_literal->value.isNull()
|| (lower_name == "sum" && isInt64OrUInt64FieldType(first_literal->value.getType()) && first_literal->value.get<UInt64>() == 0))
{
/// avg(if(cond, null, a) -> avgIf(a, !cond))
/// sum(if(cond, 0, a) -> sumIf(a, !cond))
auto not_func = makeASTFunction("not", if_arguments[0]);
auto new_func
= makeASTFunction(func.name + (first_literal->value.isNull() ? "IfOrNull" : "If"), if_arguments[2], std::move(not_func));
new_func->setAlias(func.alias);
new_func->parameters = func.parameters;
ast = std::move(new_func);
return;
}
}
}
}

View File

@ -1,31 +0,0 @@
#pragma once
#include <unordered_set>
#include <Parsers/IAST.h>
#include <Interpreters/InDepthNodeVisitor.h>
namespace DB
{
class ASTFunction;
/// Rewrite '<aggregate-function>(if())' to '<aggregate-function>If[OrNull]()'
/// sum(if(cond, a, 0)) -> sumIf[OrNull](a, cond)
/// sum(if(cond, a, null)) -> sumIf[OrNull](a, cond)
/// avg(if(cond, a, null)) -> avgIf[OrNull](a, cond)
/// ...
class RewriteAggregateFunctionWithIfMatcher
{
public:
struct Data
{
};
static void visit(ASTPtr & ast, Data &);
static void visit(const ASTFunction &, ASTPtr & ast, Data &);
static bool needChildVisit(const ASTPtr &, const ASTPtr &) { return true; }
};
using RewriteAggregateFunctionWithIfVisitor = InDepthNodeVisitor<RewriteAggregateFunctionWithIfMatcher, false>;
}

View File

@ -38,7 +38,6 @@
#include <Storages/IStorage.h>
#include <Interpreters/RewriteSumIfFunctionVisitor.h>
#include <Interpreters/RewriteAggregateFunctionWithIfVisitor.h>
namespace DB
{
@ -659,12 +658,6 @@ void optimizeSumIfFunctions(ASTPtr & query)
RewriteSumIfFunctionVisitor(data).visit(query);
}
void optimizeAggregateFunctionsWithIf(ASTPtr & query)
{
RewriteAggregateFunctionWithIfVisitor::Data data = {};
RewriteAggregateFunctionWithIfVisitor(data).visit(query);
}
void optimizeMultiIfToIf(ASTPtr & query)
{
OptimizeMultiIfToIfVisitor::Data data;
@ -794,9 +787,6 @@ void TreeOptimizer::apply(ASTPtr & query, TreeRewriterResult & result,
if (settings.optimize_multiif_to_if)
optimizeMultiIfToIf(query);
if (settings.optimize_rewrite_aggregate_function_with_if)
optimizeAggregateFunctionsWithIf(query);
if (settings.optimize_rewrite_sum_if_to_count_if)
optimizeSumIfFunctions(query);

View File

@ -1,35 +1,197 @@
SELECT sum(if(number % 2, number, 0))
FROM numbers(100)
SELECT sum(if(number % 2, 0, number))
FROM numbers(100)
SELECT sum(if(number % 2, number, NULL))
FROM numbers(100)
SELECT sum(if(number % 2, NULL, number))
FROM numbers(100)
SELECT avg(if(number % 2, number, NULL))
FROM numbers(100)
SELECT avg(if(number % 2, NULL, number))
FROM numbers(100)
SELECT quantiles(0.5, 0.9, 0.99)(if(number % 2, number, NULL))
FROM numbers(100)
SELECT quantiles(0.5, 0.9, 0.99)(if(number % 2, NULL, number))
FROM numbers(100)
SELECT sumIf(number, number % 2)
FROM numbers(100)
SELECT sumIf(number, NOT (number % 2))
FROM numbers(100)
SELECT sumIf(number, number % 2)
FROM numbers(100)
SELECT sumIf(number, NOT (number % 2))
FROM numbers(100)
SELECT avgIf(number, number % 2)
FROM numbers(100)
SELECT avgIf(number, NOT (number % 2))
FROM numbers(100)
SELECT quantilesIf(0.5, 0.9, 0.99)(number, number % 2)
FROM numbers(100)
SELECT quantilesIf(0.5, 0.9, 0.99)(number, NOT (number % 2))
FROM numbers(100)
QUERY id: 0
PROJECTION COLUMNS
sum(if(modulo(number, 2), number, 0)) UInt64
PROJECTION
LIST id: 1, nodes: 1
FUNCTION id: 2, function_name: sum, function_type: aggregate, result_type: UInt64
ARGUMENTS
LIST id: 3, nodes: 1
FUNCTION id: 4, function_name: if, function_type: ordinary, result_type: UInt64
ARGUMENTS
LIST id: 5, nodes: 3
FUNCTION id: 6, function_name: modulo, function_type: ordinary, result_type: UInt8
ARGUMENTS
LIST id: 7, nodes: 2
COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9
CONSTANT id: 10, constant_value: UInt64_2, constant_value_type: UInt8
COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9
CONSTANT id: 11, constant_value: UInt64_0, constant_value_type: UInt8
JOIN TREE
TABLE_FUNCTION id: 9, table_function_name: numbers
ARGUMENTS
LIST id: 12, nodes: 1
CONSTANT id: 13, constant_value: UInt64_100, constant_value_type: UInt8
QUERY id: 0
PROJECTION COLUMNS
sum(if(modulo(number, 2), 0, number)) UInt64
PROJECTION
LIST id: 1, nodes: 1
FUNCTION id: 2, function_name: sum, function_type: aggregate, result_type: UInt64
ARGUMENTS
LIST id: 3, nodes: 1
FUNCTION id: 4, function_name: if, function_type: ordinary, result_type: UInt64
ARGUMENTS
LIST id: 5, nodes: 3
FUNCTION id: 6, function_name: modulo, function_type: ordinary, result_type: UInt8
ARGUMENTS
LIST id: 7, nodes: 2
COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9
CONSTANT id: 10, constant_value: UInt64_2, constant_value_type: UInt8
CONSTANT id: 11, constant_value: UInt64_0, constant_value_type: UInt8
COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9
JOIN TREE
TABLE_FUNCTION id: 9, table_function_name: numbers
ARGUMENTS
LIST id: 12, nodes: 1
CONSTANT id: 13, constant_value: UInt64_100, constant_value_type: UInt8
QUERY id: 0
PROJECTION COLUMNS
sum(if(modulo(number, 2), number, NULL)) Nullable(UInt64)
PROJECTION
LIST id: 1, nodes: 1
FUNCTION id: 2, function_name: sum, function_type: aggregate, result_type: Nullable(UInt64)
ARGUMENTS
LIST id: 3, nodes: 1
FUNCTION id: 4, function_name: if, function_type: ordinary, result_type: Nullable(UInt64)
ARGUMENTS
LIST id: 5, nodes: 3
FUNCTION id: 6, function_name: modulo, function_type: ordinary, result_type: UInt8
ARGUMENTS
LIST id: 7, nodes: 2
COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9
CONSTANT id: 10, constant_value: UInt64_2, constant_value_type: UInt8
COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9
CONSTANT id: 11, constant_value: NULL, constant_value_type: Nullable(Nothing)
JOIN TREE
TABLE_FUNCTION id: 9, table_function_name: numbers
ARGUMENTS
LIST id: 12, nodes: 1
CONSTANT id: 13, constant_value: UInt64_100, constant_value_type: UInt8
QUERY id: 0
PROJECTION COLUMNS
sum(if(modulo(number, 2), NULL, number)) Nullable(UInt64)
PROJECTION
LIST id: 1, nodes: 1
FUNCTION id: 2, function_name: sum, function_type: aggregate, result_type: Nullable(UInt64)
ARGUMENTS
LIST id: 3, nodes: 1
FUNCTION id: 4, function_name: if, function_type: ordinary, result_type: Nullable(UInt64)
ARGUMENTS
LIST id: 5, nodes: 3
FUNCTION id: 6, function_name: modulo, function_type: ordinary, result_type: UInt8
ARGUMENTS
LIST id: 7, nodes: 2
COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9
CONSTANT id: 10, constant_value: UInt64_2, constant_value_type: UInt8
CONSTANT id: 11, constant_value: NULL, constant_value_type: Nullable(Nothing)
COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9
JOIN TREE
TABLE_FUNCTION id: 9, table_function_name: numbers
ARGUMENTS
LIST id: 12, nodes: 1
CONSTANT id: 13, constant_value: UInt64_100, constant_value_type: UInt8
QUERY id: 0
PROJECTION COLUMNS
avg(if(modulo(number, 2), number, NULL)) Nullable(Float64)
PROJECTION
LIST id: 1, nodes: 1
FUNCTION id: 2, function_name: avg, function_type: aggregate, result_type: Nullable(Float64)
ARGUMENTS
LIST id: 3, nodes: 1
FUNCTION id: 4, function_name: if, function_type: ordinary, result_type: Nullable(UInt64)
ARGUMENTS
LIST id: 5, nodes: 3
FUNCTION id: 6, function_name: modulo, function_type: ordinary, result_type: UInt8
ARGUMENTS
LIST id: 7, nodes: 2
COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9
CONSTANT id: 10, constant_value: UInt64_2, constant_value_type: UInt8
COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9
CONSTANT id: 11, constant_value: NULL, constant_value_type: Nullable(Nothing)
JOIN TREE
TABLE_FUNCTION id: 9, table_function_name: numbers
ARGUMENTS
LIST id: 12, nodes: 1
CONSTANT id: 13, constant_value: UInt64_100, constant_value_type: UInt8
QUERY id: 0
PROJECTION COLUMNS
avg(if(modulo(number, 2), NULL, number)) Nullable(Float64)
PROJECTION
LIST id: 1, nodes: 1
FUNCTION id: 2, function_name: avg, function_type: aggregate, result_type: Nullable(Float64)
ARGUMENTS
LIST id: 3, nodes: 1
FUNCTION id: 4, function_name: if, function_type: ordinary, result_type: Nullable(UInt64)
ARGUMENTS
LIST id: 5, nodes: 3
FUNCTION id: 6, function_name: modulo, function_type: ordinary, result_type: UInt8
ARGUMENTS
LIST id: 7, nodes: 2
COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9
CONSTANT id: 10, constant_value: UInt64_2, constant_value_type: UInt8
CONSTANT id: 11, constant_value: NULL, constant_value_type: Nullable(Nothing)
COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9
JOIN TREE
TABLE_FUNCTION id: 9, table_function_name: numbers
ARGUMENTS
LIST id: 12, nodes: 1
CONSTANT id: 13, constant_value: UInt64_100, constant_value_type: UInt8
QUERY id: 0
PROJECTION COLUMNS
quantiles(0.5, 0.9, 0.99)(if(modulo(number, 2), number, NULL)) Array(Float64)
PROJECTION
LIST id: 1, nodes: 1
FUNCTION id: 2, function_name: quantiles, function_type: aggregate, result_type: Array(Float64)
PARAMETERS
LIST id: 3, nodes: 3
CONSTANT id: 4, constant_value: Float64_0.5, constant_value_type: Float64
CONSTANT id: 5, constant_value: Float64_0.9, constant_value_type: Float64
CONSTANT id: 6, constant_value: Float64_0.99, constant_value_type: Float64
ARGUMENTS
LIST id: 7, nodes: 1
FUNCTION id: 8, function_name: if, function_type: ordinary, result_type: Nullable(UInt64)
ARGUMENTS
LIST id: 9, nodes: 3
FUNCTION id: 10, function_name: modulo, function_type: ordinary, result_type: UInt8
ARGUMENTS
LIST id: 11, nodes: 2
COLUMN id: 12, column_name: number, result_type: UInt64, source_id: 13
CONSTANT id: 14, constant_value: UInt64_2, constant_value_type: UInt8
COLUMN id: 12, column_name: number, result_type: UInt64, source_id: 13
CONSTANT id: 15, constant_value: NULL, constant_value_type: Nullable(Nothing)
JOIN TREE
TABLE_FUNCTION id: 13, table_function_name: numbers
ARGUMENTS
LIST id: 16, nodes: 1
CONSTANT id: 17, constant_value: UInt64_100, constant_value_type: UInt8
QUERY id: 0
PROJECTION COLUMNS
quantiles(0.5, 0.9, 0.99)(if(modulo(number, 2), NULL, number)) Array(Float64)
PROJECTION
LIST id: 1, nodes: 1
FUNCTION id: 2, function_name: quantiles, function_type: aggregate, result_type: Array(Float64)
PARAMETERS
LIST id: 3, nodes: 3
CONSTANT id: 4, constant_value: Float64_0.5, constant_value_type: Float64
CONSTANT id: 5, constant_value: Float64_0.9, constant_value_type: Float64
CONSTANT id: 6, constant_value: Float64_0.99, constant_value_type: Float64
ARGUMENTS
LIST id: 7, nodes: 1
FUNCTION id: 8, function_name: if, function_type: ordinary, result_type: Nullable(UInt64)
ARGUMENTS
LIST id: 9, nodes: 3
FUNCTION id: 10, function_name: modulo, function_type: ordinary, result_type: UInt8
ARGUMENTS
LIST id: 11, nodes: 2
COLUMN id: 12, column_name: number, result_type: UInt64, source_id: 13
CONSTANT id: 14, constant_value: UInt64_2, constant_value_type: UInt8
CONSTANT id: 15, constant_value: NULL, constant_value_type: Nullable(Nothing)
COLUMN id: 12, column_name: number, result_type: UInt64, source_id: 13
JOIN TREE
TABLE_FUNCTION id: 13, table_function_name: numbers
ARGUMENTS
LIST id: 16, nodes: 1
CONSTANT id: 17, constant_value: UInt64_100, constant_value_type: UInt8
QUERY id: 0
PROJECTION COLUMNS
sum(if(modulo(number, 2), number, 0)) UInt64
@ -76,7 +238,7 @@ QUERY id: 0
sum(if(modulo(number, 2), number, NULL)) Nullable(UInt64)
PROJECTION
LIST id: 1, nodes: 1
FUNCTION id: 2, function_name: sumIf, function_type: aggregate, result_type: UInt64
FUNCTION id: 2, function_name: sumOrNullIf, function_type: aggregate, result_type: Nullable(UInt64)
ARGUMENTS
LIST id: 3, nodes: 2
COLUMN id: 4, column_name: number, result_type: UInt64, source_id: 5
@ -95,7 +257,7 @@ QUERY id: 0
sum(if(modulo(number, 2), NULL, number)) Nullable(UInt64)
PROJECTION
LIST id: 1, nodes: 1
FUNCTION id: 2, function_name: sumIf, function_type: aggregate, result_type: UInt64
FUNCTION id: 2, function_name: sumOrNullIf, function_type: aggregate, result_type: Nullable(UInt64)
ARGUMENTS
LIST id: 3, nodes: 2
COLUMN id: 4, column_name: number, result_type: UInt64, source_id: 5
@ -117,7 +279,7 @@ QUERY id: 0
avg(if(modulo(number, 2), number, NULL)) Nullable(Float64)
PROJECTION
LIST id: 1, nodes: 1
FUNCTION id: 2, function_name: avgIf, function_type: aggregate, result_type: Float64
FUNCTION id: 2, function_name: avgOrNullIf, function_type: aggregate, result_type: Nullable(Float64)
ARGUMENTS
LIST id: 3, nodes: 2
COLUMN id: 4, column_name: number, result_type: UInt64, source_id: 5
@ -136,7 +298,7 @@ QUERY id: 0
avg(if(modulo(number, 2), NULL, number)) Nullable(Float64)
PROJECTION
LIST id: 1, nodes: 1
FUNCTION id: 2, function_name: avgIf, function_type: aggregate, result_type: Float64
FUNCTION id: 2, function_name: avgOrNullIf, function_type: aggregate, result_type: Nullable(Float64)
ARGUMENTS
LIST id: 3, nodes: 2
COLUMN id: 4, column_name: number, result_type: UInt64, source_id: 5

View File

@ -1,33 +1,19 @@
set optimize_rewrite_aggregate_function_with_if = false;
explain syntax select sum(if(number % 2, number, 0)) from numbers(100);
explain syntax select sum(if(number % 2, 0, number)) from numbers(100);
explain syntax select sum(if(number % 2, number, null)) from numbers(100);
explain syntax select sum(if(number % 2, null, number)) from numbers(100);
explain syntax select avg(if(number % 2, number, null)) from numbers(100);
explain syntax select avg(if(number % 2, null, number)) from numbers(100);
explain syntax select quantiles(0.5, 0.9, 0.99)(if(number % 2, number, null)) from numbers(100);
explain syntax select quantiles(0.5, 0.9, 0.99)(if(number % 2, null, number)) from numbers(100);
set optimize_rewrite_aggregate_function_with_if = true;
explain syntax select sum(if(number % 2, number, 0)) from numbers(100);
explain syntax select sum(if(number % 2, 0, number)) from numbers(100);
explain syntax select sum(if(number % 2, number, null)) from numbers(100);
explain syntax select sum(if(number % 2, null, number)) from numbers(100);
explain syntax select avg(if(number % 2, number, null)) from numbers(100);
explain syntax select avg(if(number % 2, null, number)) from numbers(100);
explain syntax select quantiles(0.5, 0.9, 0.99)(if(number % 2, number, null)) from numbers(100);
explain syntax select quantiles(0.5, 0.9, 0.99)(if(number % 2, null, number)) from numbers(100);
set allow_experimental_analyzer = true;
set optimize_rewrite_aggregate_function_with_if = false;
EXPLAIN QUERY TREE run_passes = 1 select sum(if(number % 2, number, 0)) from numbers(100);
EXPLAIN QUERY TREE run_passes = 1 select sum(if(number % 2, 0, number)) from numbers(100);
EXPLAIN QUERY TREE run_passes = 1 select sum(if(number % 2, number, null)) from numbers(100);
EXPLAIN QUERY TREE run_passes = 1 select sum(if(number % 2, null, number)) from numbers(100);
EXPLAIN QUERY TREE run_passes = 1 select avg(if(number % 2, number, null)) from numbers(100);
EXPLAIN QUERY TREE run_passes = 1 select avg(if(number % 2, null, number)) from numbers(100);
EXPLAIN QUERY TREE run_passes = 1 select quantiles(0.5, 0.9, 0.99)(if(number % 2, number, null)) from numbers(100);
EXPLAIN QUERY TREE run_passes = 1 select quantiles(0.5, 0.9, 0.99)(if(number % 2, null, number)) from numbers(100);
set optimize_rewrite_aggregate_function_with_if = true;
EXPLAIN QUERY TREE run_passes = 1 select sum(if(number % 2, number, 0)) from numbers(100);
EXPLAIN QUERY TREE run_passes = 1 select sum(if(number % 2, 0, number)) from numbers(100);