mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-12-03 13:02:00 +00:00
Fix move_functions_out_of_any optimisation with lambda (#12664)
This commit is contained in:
parent
de8328b2b0
commit
2041d7d0d8
@ -1,98 +0,0 @@
|
||||
#include <Common/typeid_cast.h>
|
||||
#include <Parsers/ASTLiteral.h>
|
||||
#include <Parsers/ASTFunction.h>
|
||||
#include <Parsers/ASTIdentifier.h>
|
||||
#include <Interpreters/AnyInputOptimize.h>
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <Parsers/ASTTablesInSelectQuery.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int LOGICAL_ERROR;
|
||||
extern const int ILLEGAL_AGGREGATION;
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
constexpr auto * any = "any";
|
||||
constexpr auto * anyLast = "anyLast";
|
||||
}
|
||||
|
||||
ASTPtr * getExactChild(const ASTPtr & ast, const size_t ind)
|
||||
{
|
||||
if (ast && ast->as<ASTFunction>()->arguments->children[ind])
|
||||
return &ast->as<ASTFunction>()->arguments->children[ind];
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
///recursive searching of identifiers
|
||||
void changeAllIdentifiers(ASTPtr & ast, size_t ind, const std::string & name)
|
||||
{
|
||||
ASTPtr * exact_child = getExactChild(ast, ind);
|
||||
if (!exact_child)
|
||||
return;
|
||||
|
||||
if ((*exact_child)->as<ASTIdentifier>())
|
||||
{
|
||||
///put new any
|
||||
ASTPtr old_ast = *exact_child;
|
||||
*exact_child = makeASTFunction(name);
|
||||
(*exact_child)->as<ASTFunction>()->arguments->children.push_back(old_ast);
|
||||
}
|
||||
else if ((*exact_child)->as<ASTFunction>())
|
||||
{
|
||||
if (AggregateFunctionFactory::instance().isAggregateFunctionName((*exact_child)->as<ASTFunction>()->name))
|
||||
throw Exception("Aggregate function " + (*exact_child)->as<ASTFunction>()->name +
|
||||
" is found inside aggregate function " + name + " in query", ErrorCodes::ILLEGAL_AGGREGATION);
|
||||
|
||||
for (size_t i = 0; i < (*exact_child)->as<ASTFunction>()->arguments->children.size(); i++)
|
||||
changeAllIdentifiers(*exact_child, i, name);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
///cut old any, put any to identifiers. any(functions(x)) -> functions(any(x))
|
||||
void AnyInputMatcher::visit(ASTPtr & current_ast, Data data)
|
||||
{
|
||||
data = {};
|
||||
if (!current_ast)
|
||||
return;
|
||||
|
||||
auto * function_node = current_ast->as<ASTFunction>();
|
||||
if (!function_node || function_node->arguments->children.empty())
|
||||
return;
|
||||
|
||||
const auto & function_argument = function_node->arguments->children[0];
|
||||
if ((function_node->name == any || function_node->name == anyLast)
|
||||
&& function_argument && function_argument->as<ASTFunction>())
|
||||
{
|
||||
auto name = function_node->name;
|
||||
auto alias = function_node->alias;
|
||||
|
||||
///cut any or anyLast
|
||||
if (!function_argument->as<ASTFunction>()->arguments->children.empty())
|
||||
{
|
||||
current_ast = function_argument->clone();
|
||||
current_ast->setAlias(alias);
|
||||
for (size_t i = 0; i < current_ast->as<ASTFunction>()->arguments->children.size(); ++i)
|
||||
changeAllIdentifiers(current_ast, i, name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool AnyInputMatcher::needChildVisit(const ASTPtr & node, const ASTPtr & child)
|
||||
{
|
||||
if (!child)
|
||||
throw Exception("AST item should not have nullptr in children", ErrorCodes::LOGICAL_ERROR);
|
||||
|
||||
if (node->as<ASTTableExpression>() || node->as<ASTArrayJoin>())
|
||||
return false; // NOLINT
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
}
|
@ -1,19 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <Parsers/IAST.h>
|
||||
#include <Interpreters/InDepthNodeVisitor.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
///This optimiser is similar to ArithmeticOperationsInAgrFunc optimizer, but for function any we can extract any functions.
|
||||
class AnyInputMatcher
|
||||
{
|
||||
public:
|
||||
struct Data {};
|
||||
|
||||
static void visit(ASTPtr & ast, Data data);
|
||||
static bool needChildVisit(const ASTPtr & node, const ASTPtr & child);
|
||||
};
|
||||
using AnyInputVisitor = InDepthNodeVisitor<AnyInputMatcher, true>;
|
||||
}
|
@ -153,7 +153,9 @@ void ArithmeticOperationsInAgrFuncMatcher::visit(ASTPtr & ast, Data & data)
|
||||
|
||||
bool ArithmeticOperationsInAgrFuncMatcher::needChildVisit(const ASTPtr & node, const ASTPtr &)
|
||||
{
|
||||
return !node->as<ASTSubquery>() && !node->as<ASTTableExpression>();
|
||||
return !node->as<ASTSubquery>() &&
|
||||
!node->as<ASTTableExpression>() &&
|
||||
!node->as<ASTArrayJoin>();
|
||||
}
|
||||
|
||||
}
|
||||
|
96
src/Interpreters/RewriteAnyFunctionVisitor.cpp
Normal file
96
src/Interpreters/RewriteAnyFunctionVisitor.cpp
Normal file
@ -0,0 +1,96 @@
|
||||
#include <Common/typeid_cast.h>
|
||||
#include <Parsers/ASTLiteral.h>
|
||||
#include <Parsers/ASTFunction.h>
|
||||
#include <Parsers/ASTIdentifier.h>
|
||||
#include <Parsers/ASTSubquery.h>
|
||||
#include <Interpreters/RewriteAnyFunctionVisitor.h>
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <Parsers/ASTTablesInSelectQuery.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
bool extractIdentifiers(const ASTFunction & func, std::vector<ASTPtr *> & identifiers)
|
||||
{
|
||||
for (auto & arg : func.arguments->children)
|
||||
{
|
||||
if (const auto * arg_func = arg->as<ASTFunction>())
|
||||
{
|
||||
if (arg_func->name == "lambda")
|
||||
return false;
|
||||
|
||||
if (AggregateFunctionFactory::instance().isAggregateFunctionName(arg_func->name))
|
||||
return false;
|
||||
|
||||
if (!extractIdentifiers(*arg_func, identifiers))
|
||||
return false;
|
||||
}
|
||||
else if (arg->as<ASTIdentifier>())
|
||||
identifiers.emplace_back(&arg);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
void RewriteAnyFunctionMatcher::visit(ASTPtr & ast, Data & data)
|
||||
{
|
||||
if (auto * func = ast->as<ASTFunction>())
|
||||
visit(*func, ast, data);
|
||||
}
|
||||
|
||||
void RewriteAnyFunctionMatcher::visit(const ASTFunction & func, ASTPtr & ast, Data & data)
|
||||
{
|
||||
if (func.arguments->children.empty() || !func.arguments->children[0])
|
||||
return;
|
||||
|
||||
if (func.name != "any" && func.name != "anyLast")
|
||||
return;
|
||||
|
||||
auto & func_arguments = func.arguments->children;
|
||||
|
||||
const auto * first_arg_func = func_arguments[0]->as<ASTFunction>();
|
||||
if (!first_arg_func || first_arg_func->arguments->children.empty())
|
||||
return;
|
||||
|
||||
/// We have rewritten this function. Just unwrap its argument.
|
||||
if (data.rewritten.count(ast.get()))
|
||||
{
|
||||
func_arguments[0]->setAlias(func.alias);
|
||||
ast = func_arguments[0];
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<ASTPtr *> identifiers;
|
||||
if (!extractIdentifiers(func, identifiers))
|
||||
return;
|
||||
|
||||
/// Wrap identifiers: any(f(x, y, g(z))) -> any(f(any(x), any(y), g(any(z))))
|
||||
for (auto * ast_to_change : identifiers)
|
||||
{
|
||||
ASTPtr identifier_ast = *ast_to_change;
|
||||
*ast_to_change = makeASTFunction(func.name);
|
||||
(*ast_to_change)->as<ASTFunction>()->arguments->children.emplace_back(identifier_ast);
|
||||
}
|
||||
|
||||
data.rewritten.insert(ast.get());
|
||||
|
||||
/// Unwrap function: any(f(any(x), any(y), g(any(z)))) -> f(any(x), any(y), g(any(z)))
|
||||
func_arguments[0]->setAlias(func.alias);
|
||||
ast = func_arguments[0];
|
||||
}
|
||||
|
||||
bool RewriteAnyFunctionMatcher::needChildVisit(const ASTPtr & node, const ASTPtr &)
|
||||
{
|
||||
return !node->as<ASTSubquery>() &&
|
||||
!node->as<ASTTableExpression>() &&
|
||||
!node->as<ASTArrayJoin>();
|
||||
}
|
||||
|
||||
}
|
29
src/Interpreters/RewriteAnyFunctionVisitor.h
Normal file
29
src/Interpreters/RewriteAnyFunctionVisitor.h
Normal file
@ -0,0 +1,29 @@
|
||||
#pragma once
|
||||
|
||||
#include <unordered_set>
|
||||
|
||||
#include <Parsers/IAST.h>
|
||||
#include <Interpreters/InDepthNodeVisitor.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
class ASTFunction;
|
||||
|
||||
/// Rewrite 'any' and 'anyLast' functions pushing them inside original function.
|
||||
/// any(f(x, y, g(z))) -> f(any(x), any(y), g(any(z)))
|
||||
class RewriteAnyFunctionMatcher
|
||||
{
|
||||
public:
|
||||
struct Data
|
||||
{
|
||||
std::unordered_set<IAST *> rewritten;
|
||||
};
|
||||
|
||||
static void visit(ASTPtr & ast, Data & data);
|
||||
static void visit(const ASTFunction &, ASTPtr & ast, Data & data);
|
||||
static bool needChildVisit(const ASTPtr & node, const ASTPtr & child);
|
||||
};
|
||||
using RewriteAnyFunctionVisitor = InDepthNodeVisitor<RewriteAnyFunctionMatcher, false>;
|
||||
|
||||
}
|
@ -8,7 +8,7 @@
|
||||
#include <Interpreters/DuplicateOrderByVisitor.h>
|
||||
#include <Interpreters/GroupByFunctionKeysVisitor.h>
|
||||
#include <Interpreters/AggregateFunctionOfGroupByKeysVisitor.h>
|
||||
#include <Interpreters/AnyInputOptimize.h>
|
||||
#include <Interpreters/RewriteAnyFunctionVisitor.h>
|
||||
#include <Interpreters/RemoveInjectiveFunctionsVisitor.h>
|
||||
#include <Interpreters/RedundantFunctionsInOrderByVisitor.h>
|
||||
#include <Interpreters/MonotonicityCheckVisitor.h>
|
||||
@ -458,11 +458,10 @@ void optimizeAggregationFunctions(ASTPtr & query)
|
||||
ArithmeticOperationsInAgrFuncVisitor(data).visit(query);
|
||||
}
|
||||
|
||||
void optimizeAnyInput(ASTPtr & query)
|
||||
void optimizeAnyFunctions(ASTPtr & query)
|
||||
{
|
||||
/// Removing arithmetic operations from functions
|
||||
AnyInputVisitor::Data data = {};
|
||||
AnyInputVisitor(data).visit(query);
|
||||
RewriteAnyFunctionVisitor::Data data = {};
|
||||
RewriteAnyFunctionVisitor(data).visit(query);
|
||||
}
|
||||
|
||||
void optimizeInjectiveFunctionsInsideUniq(ASTPtr & query, const Context & context)
|
||||
@ -520,9 +519,9 @@ void TreeOptimizer::apply(ASTPtr & query, Aliases & aliases, const NameSet & sou
|
||||
if (settings.optimize_group_by_function_keys)
|
||||
optimizeGroupByFunctionKeys(select_query);
|
||||
|
||||
///Move all operations out of any function
|
||||
/// Move all operations out of any function
|
||||
if (settings.optimize_move_functions_out_of_any)
|
||||
optimizeAnyInput(query);
|
||||
optimizeAnyFunctions(query);
|
||||
|
||||
/// Remove injective functions inside uniq
|
||||
if (settings.optimize_injective_functions_inside_uniq)
|
||||
|
@ -20,7 +20,6 @@ SRCS(
|
||||
addTypeConversionToAST.cpp
|
||||
AggregateDescription.cpp
|
||||
Aggregator.cpp
|
||||
AnyInputOptimize.cpp
|
||||
ArithmeticOperationsInAgrFuncOptimize.cpp
|
||||
ArithmeticOperationsInAgrFuncOptimize.h
|
||||
ArrayJoinAction.cpp
|
||||
@ -127,6 +126,7 @@ SRCS(
|
||||
ReplaceQueryParameterVisitor.cpp
|
||||
RequiredSourceColumnsData.cpp
|
||||
RequiredSourceColumnsVisitor.cpp
|
||||
RewriteAnyFunctionVisitor.cpp
|
||||
RowRefs.cpp
|
||||
Set.cpp
|
||||
SetVariants.cpp
|
||||
|
@ -1,3 +1,30 @@
|
||||
9
|
||||
SELECT any(number) + (any(number) * 2)
|
||||
FROM numbers(3, 10)
|
||||
FROM numbers(1, 2)
|
||||
3
|
||||
SELECT anyLast(number) + (anyLast(number) * 2)
|
||||
FROM numbers(1, 2)
|
||||
6
|
||||
WITH any(number) * 3 AS x
|
||||
SELECT x
|
||||
FROM numbers(1, 2)
|
||||
3
|
||||
SELECT
|
||||
anyLast(number) * 3 AS x,
|
||||
x
|
||||
FROM numbers(1, 2)
|
||||
6 6
|
||||
SELECT any(number + (number * 2))
|
||||
FROM numbers(1, 2)
|
||||
3
|
||||
SELECT anyLast(number + (number * 2))
|
||||
FROM numbers(1, 2)
|
||||
6
|
||||
WITH any(number * 3) AS x
|
||||
SELECT x
|
||||
FROM numbers(1, 2)
|
||||
3
|
||||
SELECT
|
||||
anyLast(number * 3) AS x,
|
||||
x
|
||||
FROM numbers(1, 2)
|
||||
6 6
|
||||
|
@ -1,4 +1,32 @@
|
||||
SET optimize_move_functions_out_of_any=1;
|
||||
SET enable_debug_queries=1;
|
||||
SELECT any(number + number * 2) FROM numbers(3, 10);
|
||||
ANALYZE SELECT any(number + number * 2) FROM numbers(3, 10);
|
||||
SET enable_debug_queries = 1;
|
||||
SET optimize_move_functions_out_of_any = 1;
|
||||
|
||||
ANALYZE SELECT any(number + number * 2) FROM numbers(1, 2);
|
||||
SELECT any(number + number * 2) FROM numbers(1, 2);
|
||||
|
||||
ANALYZE SELECT anyLast(number + number * 2) FROM numbers(1, 2);
|
||||
SELECT anyLast(number + number * 2) FROM numbers(1, 2);
|
||||
|
||||
ANALYZE WITH any(number * 3) AS x SELECT x FROM numbers(1, 2);
|
||||
WITH any(number * 3) AS x SELECT x FROM numbers(1, 2);
|
||||
|
||||
ANALYZE SELECT anyLast(number * 3) AS x, x FROM numbers(1, 2);
|
||||
SELECT anyLast(number * 3) AS x, x FROM numbers(1, 2);
|
||||
|
||||
SELECT any(anyLast(number)) FROM numbers(1); -- { serverError 184 }
|
||||
|
||||
SET optimize_move_functions_out_of_any = 0;
|
||||
|
||||
ANALYZE SELECT any(number + number * 2) FROM numbers(1, 2);
|
||||
SELECT any(number + number * 2) FROM numbers(1, 2);
|
||||
|
||||
ANALYZE SELECT anyLast(number + number * 2) FROM numbers(1, 2);
|
||||
SELECT anyLast(number + number * 2) FROM numbers(1, 2);
|
||||
|
||||
ANALYZE WITH any(number * 3) AS x SELECT x FROM numbers(1, 2);
|
||||
WITH any(number * 3) AS x SELECT x FROM numbers(1, 2);
|
||||
|
||||
ANALYZE SELECT anyLast(number * 3) AS x, x FROM numbers(1, 2);
|
||||
SELECT anyLast(number * 3) AS x, x FROM numbers(1, 2);
|
||||
|
||||
SELECT any(anyLast(number)) FROM numbers(1); -- { serverError 184 }
|
||||
|
@ -0,0 +1 @@
|
||||
0
|
18
tests/queries/0_stateless/01414_optimize_any_bug.sql
Normal file
18
tests/queries/0_stateless/01414_optimize_any_bug.sql
Normal file
@ -0,0 +1,18 @@
|
||||
DROP TABLE IF EXISTS test;
|
||||
|
||||
CREATE TABLE test
|
||||
(
|
||||
`Source.C1` Array(UInt64),
|
||||
`Source.C2` Array(UInt64)
|
||||
)
|
||||
ENGINE = MergeTree()
|
||||
ORDER BY tuple();
|
||||
|
||||
SET optimize_move_functions_out_of_any = 1;
|
||||
|
||||
SELECT any(arrayFilter((c, d) -> (4 = d), `Source.C1`, `Source.C2`)[1]) AS x
|
||||
FROM test
|
||||
WHERE 0
|
||||
GROUP BY 42;
|
||||
|
||||
DROP TABLE test;
|
Loading…
Reference in New Issue
Block a user