Fix move_functions_out_of_any optimisation with lambda (#12664)

This commit is contained in:
Artem Zuikov 2020-07-23 18:15:22 +03:00 committed by GitHub
parent de8328b2b0
commit 2041d7d0d8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 215 additions and 132 deletions

View File

@ -1,98 +0,0 @@
#include <Common/typeid_cast.h>
#include <Parsers/ASTLiteral.h>
#include <Parsers/ASTFunction.h>
#include <Parsers/ASTIdentifier.h>
#include <Interpreters/AnyInputOptimize.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <IO/WriteHelpers.h>
#include <Parsers/ASTTablesInSelectQuery.h>
namespace DB
{
/// Error codes thrown from this file. They are only declared here (extern).
namespace ErrorCodes
{
/// Thrown by AnyInputMatcher::needChildVisit when a child pointer is null.
extern const int LOGICAL_ERROR;
/// Thrown by changeAllIdentifiers when an aggregate function is nested inside any / anyLast.
extern const int ILLEGAL_AGGREGATION;
}
namespace
{
/// Names of the aggregate functions that AnyInputMatcher::visit looks for.
constexpr auto * any = "any";
constexpr auto * anyLast = "anyLast";
}
/// Returns a pointer to the `ind`-th argument slot of the function node `ast`,
/// so that the caller can replace the argument in place.
/// Returns nullptr when `ast` is null or is not an ASTFunction, when the function has no
/// argument list, when `ind` is out of range, or when the slot itself holds a null pointer.
ASTPtr * getExactChild(const ASTPtr & ast, const size_t ind)
{
    if (!ast)
        return nullptr;

    /// as<>() yields nullptr for a node of another type, so the result is checked before use.
    auto * function = ast->as<ASTFunction>();
    if (!function || !function->arguments)
        return nullptr;

    auto & arguments = function->arguments->children;
    if (ind >= arguments.size() || !arguments[ind])
        return nullptr;

    return &arguments[ind];
}
/// Recursively walks the `ind`-th argument of the function node `ast` and wraps every
/// identifier found there into a call of the function `name`: x -> name(x).
/// Throws ILLEGAL_AGGREGATION when an aggregate function is met inside the argument.
void changeAllIdentifiers(ASTPtr & ast, size_t ind, const std::string & name)
{
    ASTPtr * slot = getExactChild(ast, ind);
    if (slot == nullptr)
        return;

    if ((*slot)->as<ASTIdentifier>())
    {
        /// Replace the identifier with `name(identifier)` right in the parent's argument list.
        ASTPtr identifier = *slot;
        *slot = makeASTFunction(name);
        (*slot)->as<ASTFunction>()->arguments->children.push_back(identifier);
        return;
    }

    auto * nested = (*slot)->as<ASTFunction>();
    if (nested == nullptr)
        return;

    if (AggregateFunctionFactory::instance().isAggregateFunctionName(nested->name))
        throw Exception("Aggregate function " + nested->name +
            " is found inside aggregate function " + name + " in query", ErrorCodes::ILLEGAL_AGGREGATION);

    /// The recursion only replaces slots inside `nested`; `*slot` itself stays the same node.
    for (size_t arg = 0; arg < nested->arguments->children.size(); ++arg)
        changeAllIdentifiers(*slot, arg, name);
}
/// Drops the outer any / anyLast and pushes it down to the identifiers:
/// any(functions(x)) -> functions(any(x)). The alias of the outer call is kept on the result.
/// Nodes that are not any / anyLast over a function with at least one argument are left as is.
void AnyInputMatcher::visit(ASTPtr & current_ast, Data data)
{
    data = {};

    if (!current_ast)
        return;

    auto * outer = current_ast->as<ASTFunction>();
    if (!outer || outer->arguments->children.empty())
        return;

    if (outer->name != any && outer->name != anyLast)
        return;

    const auto & inner_ast = outer->arguments->children[0];
    if (!inner_ast)
        return;

    const auto * inner = inner_ast->as<ASTFunction>();
    if (!inner || inner->arguments->children.empty())
        return;

    /// Copy what is needed from the outer node before `current_ast` is reassigned.
    const auto aggregate_name = outer->name;
    const auto outer_alias = outer->alias;

    current_ast = inner_ast->clone();
    current_ast->setAlias(outer_alias);

    const size_t argument_count = current_ast->as<ASTFunction>()->arguments->children.size();
    for (size_t pos = 0; pos < argument_count; ++pos)
        changeAllIdentifiers(current_ast, pos, aggregate_name);
}
bool AnyInputMatcher::needChildVisit(const ASTPtr & node, const ASTPtr & child)
{
if (!child)
throw Exception("AST item should not have nullptr in children", ErrorCodes::LOGICAL_ERROR);
if (node->as<ASTTableExpression>() || node->as<ASTArrayJoin>())
return false; // NOLINT
return true;
}
}

View File

@ -1,19 +0,0 @@
#pragma once
#include <Parsers/IAST.h>
#include <Interpreters/InDepthNodeVisitor.h>
namespace DB
{
/// This optimizer is similar to the ArithmeticOperationsInAgrFunc optimizer,
/// but for the functions any / anyLast every function can be moved out, not only arithmetic ones:
/// any(functions(x)) -> functions(any(x)).
class AnyInputMatcher
{
public:
/// Empty: the matcher keeps no state between visits.
struct Data {};
/// Rewrites `ast` in place when it is an any / anyLast call over a function; `data` is unused.
static void visit(ASTPtr & ast, Data data);
/// Returns false for table expressions and ARRAY JOIN nodes, true otherwise; throws on a null `child`.
static bool needChildVisit(const ASTPtr & node, const ASTPtr & child);
};
using AnyInputVisitor = InDepthNodeVisitor<AnyInputMatcher, true>;
}

View File

@ -153,7 +153,9 @@ void ArithmeticOperationsInAgrFuncMatcher::visit(ASTPtr & ast, Data & data)
/// Returns false when `node` is a subquery, a table expression or an ARRAY JOIN, true otherwise.
/// The child argument is not inspected.
bool ArithmeticOperationsInAgrFuncMatcher::needChildVisit(const ASTPtr & node, const ASTPtr &)
{
    return !node->as<ASTSubquery>() &&
        !node->as<ASTTableExpression>() &&
        !node->as<ASTArrayJoin>();
}
}

View File

@ -0,0 +1,96 @@
#include <Common/typeid_cast.h>
#include <Parsers/ASTLiteral.h>
#include <Parsers/ASTFunction.h>
#include <Parsers/ASTIdentifier.h>
#include <Parsers/ASTSubquery.h>
#include <Interpreters/RewriteAnyFunctionVisitor.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <IO/WriteHelpers.h>
#include <Parsers/ASTTablesInSelectQuery.h>
namespace DB
{
namespace
{
/// Collects pointers to all argument slots of `func` (recursively through nested functions)
/// that hold an identifier. The slots are appended to `identifiers`.
/// Returns false as soon as a lambda or an aggregate function is met among the arguments;
/// `identifiers` may already contain some slots in that case.
bool extractIdentifiers(const ASTFunction & func, std::vector<ASTPtr *> & identifiers)
{
    for (auto & child : func.arguments->children)
    {
        if (child->as<ASTIdentifier>())
        {
            identifiers.emplace_back(&child);
            continue;
        }

        const auto * nested = child->as<ASTFunction>();
        if (!nested)
            continue;

        const bool is_lambda = nested->name == "lambda";
        if (is_lambda || AggregateFunctionFactory::instance().isAggregateFunctionName(nested->name))
            return false;

        if (!extractIdentifiers(*nested, identifiers))
            return false;
    }

    return true;
}
}
/// Entry point: forwards function nodes to the ASTFunction overload, ignores everything else.
void RewriteAnyFunctionMatcher::visit(ASTPtr & ast, Data & data)
{
    auto * function = ast->as<ASTFunction>();
    if (!function)
        return;

    visit(*function, ast, data);
}
void RewriteAnyFunctionMatcher::visit(const ASTFunction & func, ASTPtr & ast, Data & data)
{
if (func.arguments->children.empty() || !func.arguments->children[0])
return;
if (func.name != "any" && func.name != "anyLast")
return;
auto & func_arguments = func.arguments->children;
const auto * first_arg_func = func_arguments[0]->as<ASTFunction>();
if (!first_arg_func || first_arg_func->arguments->children.empty())
return;
/// We have rewritten this function. Just unwrap its argument.
if (data.rewritten.count(ast.get()))
{
func_arguments[0]->setAlias(func.alias);
ast = func_arguments[0];
return;
}
std::vector<ASTPtr *> identifiers;
if (!extractIdentifiers(func, identifiers))
return;
/// Wrap identifiers: any(f(x, y, g(z))) -> any(f(any(x), any(y), g(any(z))))
for (auto * ast_to_change : identifiers)
{
ASTPtr identifier_ast = *ast_to_change;
*ast_to_change = makeASTFunction(func.name);
(*ast_to_change)->as<ASTFunction>()->arguments->children.emplace_back(identifier_ast);
}
data.rewritten.insert(ast.get());
/// Unwrap function: any(f(any(x), any(y), g(any(z)))) -> f(any(x), any(y), g(any(z)))
func_arguments[0]->setAlias(func.alias);
ast = func_arguments[0];
}
/// Returns false when `node` is a subquery, a table expression or an ARRAY JOIN, true otherwise.
/// The child argument is not inspected.
bool RewriteAnyFunctionMatcher::needChildVisit(const ASTPtr & node, const ASTPtr &)
{
    if (node->as<ASTSubquery>())
        return false;

    if (node->as<ASTTableExpression>())
        return false;

    return node->as<ASTArrayJoin>() == nullptr;
}
}

View File

@ -0,0 +1,29 @@
#pragma once
#include <unordered_set>
#include <Parsers/IAST.h>
#include <Interpreters/InDepthNodeVisitor.h>
namespace DB
{
class ASTFunction;
/// Rewrite 'any' and 'anyLast' functions pushing them inside original function.
/// any(f(x, y, g(z))) -> f(any(x), any(y), g(any(z)))
/// The expression is left untouched when it contains a lambda or an aggregate function.
class RewriteAnyFunctionMatcher
{
public:
struct Data
{
/// Addresses of the 'any' / 'anyLast' nodes that were already rewritten.
/// A node found here again is only unwrapped; its identifiers are not wrapped a second time.
std::unordered_set<IAST *> rewritten;
};
/// Dispatches to the ASTFunction overload when `ast` is a function; does nothing otherwise.
static void visit(ASTPtr & ast, Data & data);
/// Performs the rewrite; the first argument is the function node held by `ast`, and `ast` is replaced in place.
static void visit(const ASTFunction &, ASTPtr & ast, Data & data);
/// Returns false for subqueries, table expressions and ARRAY JOIN nodes, true otherwise.
static bool needChildVisit(const ASTPtr & node, const ASTPtr & child);
};
using RewriteAnyFunctionVisitor = InDepthNodeVisitor<RewriteAnyFunctionMatcher, false>;
}

View File

@ -8,7 +8,7 @@
#include <Interpreters/DuplicateOrderByVisitor.h>
#include <Interpreters/GroupByFunctionKeysVisitor.h>
#include <Interpreters/AggregateFunctionOfGroupByKeysVisitor.h>
#include <Interpreters/AnyInputOptimize.h>
#include <Interpreters/RewriteAnyFunctionVisitor.h>
#include <Interpreters/RemoveInjectiveFunctionsVisitor.h>
#include <Interpreters/RedundantFunctionsInOrderByVisitor.h>
#include <Interpreters/MonotonicityCheckVisitor.h>
@ -458,11 +458,10 @@ void optimizeAggregationFunctions(ASTPtr & query)
ArithmeticOperationsInAgrFuncVisitor(data).visit(query);
}
void optimizeAnyInput(ASTPtr & query)
void optimizeAnyFunctions(ASTPtr & query)
{
/// Removing arithmetic operations from functions
AnyInputVisitor::Data data = {};
AnyInputVisitor(data).visit(query);
RewriteAnyFunctionVisitor::Data data = {};
RewriteAnyFunctionVisitor(data).visit(query);
}
void optimizeInjectiveFunctionsInsideUniq(ASTPtr & query, const Context & context)
@ -522,7 +521,7 @@ void TreeOptimizer::apply(ASTPtr & query, Aliases & aliases, const NameSet & sou
/// Move all operations out of any function
if (settings.optimize_move_functions_out_of_any)
optimizeAnyInput(query);
optimizeAnyFunctions(query);
/// Remove injective functions inside uniq
if (settings.optimize_injective_functions_inside_uniq)

View File

@ -20,7 +20,6 @@ SRCS(
addTypeConversionToAST.cpp
AggregateDescription.cpp
Aggregator.cpp
AnyInputOptimize.cpp
ArithmeticOperationsInAgrFuncOptimize.cpp
ArithmeticOperationsInAgrFuncOptimize.h
ArrayJoinAction.cpp
@ -127,6 +126,7 @@ SRCS(
ReplaceQueryParameterVisitor.cpp
RequiredSourceColumnsData.cpp
RequiredSourceColumnsVisitor.cpp
RewriteAnyFunctionVisitor.cpp
RowRefs.cpp
Set.cpp
SetVariants.cpp

View File

@ -1,3 +1,30 @@
9
SELECT any(number) + (any(number) * 2)
FROM numbers(3, 10)
FROM numbers(1, 2)
3
SELECT anyLast(number) + (anyLast(number) * 2)
FROM numbers(1, 2)
6
WITH any(number) * 3 AS x
SELECT x
FROM numbers(1, 2)
3
SELECT
anyLast(number) * 3 AS x,
x
FROM numbers(1, 2)
6 6
SELECT any(number + (number * 2))
FROM numbers(1, 2)
3
SELECT anyLast(number + (number * 2))
FROM numbers(1, 2)
6
WITH any(number * 3) AS x
SELECT x
FROM numbers(1, 2)
3
SELECT
anyLast(number * 3) AS x,
x
FROM numbers(1, 2)
6 6

View File

@ -1,4 +1,32 @@
SET optimize_move_functions_out_of_any=1;
SET enable_debug_queries = 1;
SELECT any(number + number * 2) FROM numbers(3, 10);
ANALYZE SELECT any(number + number * 2) FROM numbers(3, 10);
-- Part 1: the optimisation is enabled. Every query is first shown with ANALYZE and then executed.
SET optimize_move_functions_out_of_any = 1;
ANALYZE SELECT any(number + number * 2) FROM numbers(1, 2);
SELECT any(number + number * 2) FROM numbers(1, 2);
ANALYZE SELECT anyLast(number + number * 2) FROM numbers(1, 2);
SELECT anyLast(number + number * 2) FROM numbers(1, 2);
-- The aggregate is defined through an alias (WITH ... AS x / ... AS x, x).
ANALYZE WITH any(number * 3) AS x SELECT x FROM numbers(1, 2);
WITH any(number * 3) AS x SELECT x FROM numbers(1, 2);
ANALYZE SELECT anyLast(number * 3) AS x, x FROM numbers(1, 2);
SELECT anyLast(number * 3) AS x, x FROM numbers(1, 2);
-- An aggregate function nested inside any() is expected to fail with server error 184.
SELECT any(anyLast(number)) FROM numbers(1); -- { serverError 184 }
-- Part 2: the same queries with the optimisation disabled.
SET optimize_move_functions_out_of_any = 0;
ANALYZE SELECT any(number + number * 2) FROM numbers(1, 2);
SELECT any(number + number * 2) FROM numbers(1, 2);
ANALYZE SELECT anyLast(number + number * 2) FROM numbers(1, 2);
SELECT anyLast(number + number * 2) FROM numbers(1, 2);
ANALYZE WITH any(number * 3) AS x SELECT x FROM numbers(1, 2);
WITH any(number * 3) AS x SELECT x FROM numbers(1, 2);
ANALYZE SELECT anyLast(number * 3) AS x, x FROM numbers(1, 2);
SELECT anyLast(number * 3) AS x, x FROM numbers(1, 2);
-- The nested aggregate is expected to fail with the same error when the optimisation is off.
SELECT any(anyLast(number)) FROM numbers(1); -- { serverError 184 }

View File

@ -0,0 +1 @@
0

View File

@ -0,0 +1,18 @@
-- Regression test: optimize_move_functions_out_of_any applied to any() whose argument contains a lambda.
DROP TABLE IF EXISTS test;
CREATE TABLE test
(
`Source.C1` Array(UInt64),
`Source.C2` Array(UInt64)
)
ENGINE = MergeTree()
ORDER BY tuple();
SET optimize_move_functions_out_of_any = 1;
-- The argument of any() is built from arrayFilter with a lambda over the two array columns.
SELECT any(arrayFilter((c, d) -> (4 = d), `Source.C1`, `Source.C2`)[1]) AS x
FROM test
WHERE 0
GROUP BY 42;
DROP TABLE test;