mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-24 08:32:02 +00:00
Fix move_functions_out_of_any optimisation with lambda (#12664)
This commit is contained in:
parent
de8328b2b0
commit
2041d7d0d8
@ -1,98 +0,0 @@
|
|||||||
#include <Common/typeid_cast.h>
|
|
||||||
#include <Parsers/ASTLiteral.h>
|
|
||||||
#include <Parsers/ASTFunction.h>
|
|
||||||
#include <Parsers/ASTIdentifier.h>
|
|
||||||
#include <Interpreters/AnyInputOptimize.h>
|
|
||||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
|
||||||
#include <IO/WriteHelpers.h>
|
|
||||||
#include <Parsers/ASTTablesInSelectQuery.h>
|
|
||||||
|
|
||||||
namespace DB
|
|
||||||
{
|
|
||||||
|
|
||||||
namespace ErrorCodes
|
|
||||||
{
|
|
||||||
extern const int LOGICAL_ERROR;
|
|
||||||
extern const int ILLEGAL_AGGREGATION;
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace
{
    /// Names of the aggregate functions this optimization recognizes.
    constexpr const char * any = "any";
    constexpr const char * anyLast = "anyLast";
}
|
|
||||||
|
|
||||||
/// Returns a pointer to the `ind`-th argument slot of the function node `ast`,
/// so that the caller can replace the argument in place.
///
/// Returns nullptr when the slot cannot be addressed safely:
///  - `ast` is null or is not an ASTFunction (as<>() yields nullptr in that case);
///  - the function has no argument list;
///  - `ind` is outside of the argument list;
///  - the argument stored in the slot is null.
ASTPtr * getExactChild(const ASTPtr & ast, const size_t ind)
{
    if (!ast)
        return nullptr;

    auto * function = ast->as<ASTFunction>();
    if (!function || !function->arguments)
        return nullptr;

    auto & arguments = function->arguments->children;

    /// operator[] performs no bounds checking, so validate the index explicitly.
    if (ind >= arguments.size() || !arguments[ind])
        return nullptr;

    return &arguments[ind];
}
|
|
||||||
|
|
||||||
///recursive searching of identifiers
|
|
||||||
void changeAllIdentifiers(ASTPtr & ast, size_t ind, const std::string & name)
|
|
||||||
{
|
|
||||||
ASTPtr * exact_child = getExactChild(ast, ind);
|
|
||||||
if (!exact_child)
|
|
||||||
return;
|
|
||||||
|
|
||||||
if ((*exact_child)->as<ASTIdentifier>())
|
|
||||||
{
|
|
||||||
///put new any
|
|
||||||
ASTPtr old_ast = *exact_child;
|
|
||||||
*exact_child = makeASTFunction(name);
|
|
||||||
(*exact_child)->as<ASTFunction>()->arguments->children.push_back(old_ast);
|
|
||||||
}
|
|
||||||
else if ((*exact_child)->as<ASTFunction>())
|
|
||||||
{
|
|
||||||
if (AggregateFunctionFactory::instance().isAggregateFunctionName((*exact_child)->as<ASTFunction>()->name))
|
|
||||||
throw Exception("Aggregate function " + (*exact_child)->as<ASTFunction>()->name +
|
|
||||||
" is found inside aggregate function " + name + " in query", ErrorCodes::ILLEGAL_AGGREGATION);
|
|
||||||
|
|
||||||
for (size_t i = 0; i < (*exact_child)->as<ASTFunction>()->arguments->children.size(); i++)
|
|
||||||
changeAllIdentifiers(*exact_child, i, name);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
///cut old any, put any to identifiers. any(functions(x)) -> functions(any(x))
|
|
||||||
void AnyInputMatcher::visit(ASTPtr & current_ast, Data data)
|
|
||||||
{
|
|
||||||
data = {};
|
|
||||||
if (!current_ast)
|
|
||||||
return;
|
|
||||||
|
|
||||||
auto * function_node = current_ast->as<ASTFunction>();
|
|
||||||
if (!function_node || function_node->arguments->children.empty())
|
|
||||||
return;
|
|
||||||
|
|
||||||
const auto & function_argument = function_node->arguments->children[0];
|
|
||||||
if ((function_node->name == any || function_node->name == anyLast)
|
|
||||||
&& function_argument && function_argument->as<ASTFunction>())
|
|
||||||
{
|
|
||||||
auto name = function_node->name;
|
|
||||||
auto alias = function_node->alias;
|
|
||||||
|
|
||||||
///cut any or anyLast
|
|
||||||
if (!function_argument->as<ASTFunction>()->arguments->children.empty())
|
|
||||||
{
|
|
||||||
current_ast = function_argument->clone();
|
|
||||||
current_ast->setAlias(alias);
|
|
||||||
for (size_t i = 0; i < current_ast->as<ASTFunction>()->arguments->children.size(); ++i)
|
|
||||||
changeAllIdentifiers(current_ast, i, name);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool AnyInputMatcher::needChildVisit(const ASTPtr & node, const ASTPtr & child)
|
|
||||||
{
|
|
||||||
if (!child)
|
|
||||||
throw Exception("AST item should not have nullptr in children", ErrorCodes::LOGICAL_ERROR);
|
|
||||||
|
|
||||||
if (node->as<ASTTableExpression>() || node->as<ASTArrayJoin>())
|
|
||||||
return false; // NOLINT
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
@ -1,19 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include <Parsers/IAST.h>
|
|
||||||
#include <Interpreters/InDepthNodeVisitor.h>
|
|
||||||
|
|
||||||
namespace DB
|
|
||||||
{
|
|
||||||
|
|
||||||
///This optimiser is similar to ArithmeticOperationsInAgrFunc optimizer, but for function any we can extract any functions.
|
|
||||||
class AnyInputMatcher
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
struct Data {};
|
|
||||||
|
|
||||||
static void visit(ASTPtr & ast, Data data);
|
|
||||||
static bool needChildVisit(const ASTPtr & node, const ASTPtr & child);
|
|
||||||
};
|
|
||||||
using AnyInputVisitor = InDepthNodeVisitor<AnyInputMatcher, true>;
|
|
||||||
}
|
|
@ -153,7 +153,9 @@ void ArithmeticOperationsInAgrFuncMatcher::visit(ASTPtr & ast, Data & data)
|
|||||||
|
|
||||||
bool ArithmeticOperationsInAgrFuncMatcher::needChildVisit(const ASTPtr & node, const ASTPtr &)
|
bool ArithmeticOperationsInAgrFuncMatcher::needChildVisit(const ASTPtr & node, const ASTPtr &)
|
||||||
{
|
{
|
||||||
return !node->as<ASTSubquery>() && !node->as<ASTTableExpression>();
|
return !node->as<ASTSubquery>() &&
|
||||||
|
!node->as<ASTTableExpression>() &&
|
||||||
|
!node->as<ASTArrayJoin>();
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
96
src/Interpreters/RewriteAnyFunctionVisitor.cpp
Normal file
96
src/Interpreters/RewriteAnyFunctionVisitor.cpp
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
#include <Common/typeid_cast.h>
|
||||||
|
#include <Parsers/ASTLiteral.h>
|
||||||
|
#include <Parsers/ASTFunction.h>
|
||||||
|
#include <Parsers/ASTIdentifier.h>
|
||||||
|
#include <Parsers/ASTSubquery.h>
|
||||||
|
#include <Interpreters/RewriteAnyFunctionVisitor.h>
|
||||||
|
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||||
|
#include <IO/WriteHelpers.h>
|
||||||
|
#include <Parsers/ASTTablesInSelectQuery.h>
|
||||||
|
|
||||||
|
namespace DB
|
||||||
|
{
|
||||||
|
|
||||||
|
namespace
|
||||||
|
{
|
||||||
|
|
||||||
|
bool extractIdentifiers(const ASTFunction & func, std::vector<ASTPtr *> & identifiers)
|
||||||
|
{
|
||||||
|
for (auto & arg : func.arguments->children)
|
||||||
|
{
|
||||||
|
if (const auto * arg_func = arg->as<ASTFunction>())
|
||||||
|
{
|
||||||
|
if (arg_func->name == "lambda")
|
||||||
|
return false;
|
||||||
|
|
||||||
|
if (AggregateFunctionFactory::instance().isAggregateFunctionName(arg_func->name))
|
||||||
|
return false;
|
||||||
|
|
||||||
|
if (!extractIdentifiers(*arg_func, identifiers))
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
else if (arg->as<ASTIdentifier>())
|
||||||
|
identifiers.emplace_back(&arg);
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/// Entry point for the visitor: only function nodes are of interest, everything else is ignored.
void RewriteAnyFunctionMatcher::visit(ASTPtr & ast, Data & data)
{
    auto * function = ast->as<ASTFunction>();
    if (!function)
        return;

    visit(*function, ast, data);
}
|
||||||
|
|
||||||
|
void RewriteAnyFunctionMatcher::visit(const ASTFunction & func, ASTPtr & ast, Data & data)
|
||||||
|
{
|
||||||
|
if (func.arguments->children.empty() || !func.arguments->children[0])
|
||||||
|
return;
|
||||||
|
|
||||||
|
if (func.name != "any" && func.name != "anyLast")
|
||||||
|
return;
|
||||||
|
|
||||||
|
auto & func_arguments = func.arguments->children;
|
||||||
|
|
||||||
|
const auto * first_arg_func = func_arguments[0]->as<ASTFunction>();
|
||||||
|
if (!first_arg_func || first_arg_func->arguments->children.empty())
|
||||||
|
return;
|
||||||
|
|
||||||
|
/// We have rewritten this function. Just unwrap its argument.
|
||||||
|
if (data.rewritten.count(ast.get()))
|
||||||
|
{
|
||||||
|
func_arguments[0]->setAlias(func.alias);
|
||||||
|
ast = func_arguments[0];
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<ASTPtr *> identifiers;
|
||||||
|
if (!extractIdentifiers(func, identifiers))
|
||||||
|
return;
|
||||||
|
|
||||||
|
/// Wrap identifiers: any(f(x, y, g(z))) -> any(f(any(x), any(y), g(any(z))))
|
||||||
|
for (auto * ast_to_change : identifiers)
|
||||||
|
{
|
||||||
|
ASTPtr identifier_ast = *ast_to_change;
|
||||||
|
*ast_to_change = makeASTFunction(func.name);
|
||||||
|
(*ast_to_change)->as<ASTFunction>()->arguments->children.emplace_back(identifier_ast);
|
||||||
|
}
|
||||||
|
|
||||||
|
data.rewritten.insert(ast.get());
|
||||||
|
|
||||||
|
/// Unwrap function: any(f(any(x), any(y), g(any(z)))) -> f(any(x), any(y), g(any(z)))
|
||||||
|
func_arguments[0]->setAlias(func.alias);
|
||||||
|
ast = func_arguments[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
bool RewriteAnyFunctionMatcher::needChildVisit(const ASTPtr & node, const ASTPtr &)
|
||||||
|
{
|
||||||
|
return !node->as<ASTSubquery>() &&
|
||||||
|
!node->as<ASTTableExpression>() &&
|
||||||
|
!node->as<ASTArrayJoin>();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
29
src/Interpreters/RewriteAnyFunctionVisitor.h
Normal file
29
src/Interpreters/RewriteAnyFunctionVisitor.h
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <unordered_set>
|
||||||
|
|
||||||
|
#include <Parsers/IAST.h>
|
||||||
|
#include <Interpreters/InDepthNodeVisitor.h>
|
||||||
|
|
||||||
|
namespace DB
|
||||||
|
{
|
||||||
|
|
||||||
|
class ASTFunction;
|
||||||
|
|
||||||
|
/// Rewrite 'any' and 'anyLast' functions pushing them inside original function.
|
||||||
|
/// any(f(x, y, g(z))) -> f(any(x), any(y), g(any(z)))
|
||||||
|
class RewriteAnyFunctionMatcher
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
struct Data
|
||||||
|
{
|
||||||
|
std::unordered_set<IAST *> rewritten;
|
||||||
|
};
|
||||||
|
|
||||||
|
static void visit(ASTPtr & ast, Data & data);
|
||||||
|
static void visit(const ASTFunction &, ASTPtr & ast, Data & data);
|
||||||
|
static bool needChildVisit(const ASTPtr & node, const ASTPtr & child);
|
||||||
|
};
|
||||||
|
using RewriteAnyFunctionVisitor = InDepthNodeVisitor<RewriteAnyFunctionMatcher, false>;
|
||||||
|
|
||||||
|
}
|
@ -8,7 +8,7 @@
|
|||||||
#include <Interpreters/DuplicateOrderByVisitor.h>
|
#include <Interpreters/DuplicateOrderByVisitor.h>
|
||||||
#include <Interpreters/GroupByFunctionKeysVisitor.h>
|
#include <Interpreters/GroupByFunctionKeysVisitor.h>
|
||||||
#include <Interpreters/AggregateFunctionOfGroupByKeysVisitor.h>
|
#include <Interpreters/AggregateFunctionOfGroupByKeysVisitor.h>
|
||||||
#include <Interpreters/AnyInputOptimize.h>
|
#include <Interpreters/RewriteAnyFunctionVisitor.h>
|
||||||
#include <Interpreters/RemoveInjectiveFunctionsVisitor.h>
|
#include <Interpreters/RemoveInjectiveFunctionsVisitor.h>
|
||||||
#include <Interpreters/RedundantFunctionsInOrderByVisitor.h>
|
#include <Interpreters/RedundantFunctionsInOrderByVisitor.h>
|
||||||
#include <Interpreters/MonotonicityCheckVisitor.h>
|
#include <Interpreters/MonotonicityCheckVisitor.h>
|
||||||
@ -458,11 +458,10 @@ void optimizeAggregationFunctions(ASTPtr & query)
|
|||||||
ArithmeticOperationsInAgrFuncVisitor(data).visit(query);
|
ArithmeticOperationsInAgrFuncVisitor(data).visit(query);
|
||||||
}
|
}
|
||||||
|
|
||||||
void optimizeAnyInput(ASTPtr & query)
|
void optimizeAnyFunctions(ASTPtr & query)
|
||||||
{
|
{
|
||||||
/// Removing arithmetic operations from functions
|
RewriteAnyFunctionVisitor::Data data = {};
|
||||||
AnyInputVisitor::Data data = {};
|
RewriteAnyFunctionVisitor(data).visit(query);
|
||||||
AnyInputVisitor(data).visit(query);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void optimizeInjectiveFunctionsInsideUniq(ASTPtr & query, const Context & context)
|
void optimizeInjectiveFunctionsInsideUniq(ASTPtr & query, const Context & context)
|
||||||
@ -522,7 +521,7 @@ void TreeOptimizer::apply(ASTPtr & query, Aliases & aliases, const NameSet & sou
|
|||||||
|
|
||||||
/// Move all operations out of any function
|
/// Move all operations out of any function
|
||||||
if (settings.optimize_move_functions_out_of_any)
|
if (settings.optimize_move_functions_out_of_any)
|
||||||
optimizeAnyInput(query);
|
optimizeAnyFunctions(query);
|
||||||
|
|
||||||
/// Remove injective functions inside uniq
|
/// Remove injective functions inside uniq
|
||||||
if (settings.optimize_injective_functions_inside_uniq)
|
if (settings.optimize_injective_functions_inside_uniq)
|
||||||
|
@ -20,7 +20,6 @@ SRCS(
|
|||||||
addTypeConversionToAST.cpp
|
addTypeConversionToAST.cpp
|
||||||
AggregateDescription.cpp
|
AggregateDescription.cpp
|
||||||
Aggregator.cpp
|
Aggregator.cpp
|
||||||
AnyInputOptimize.cpp
|
|
||||||
ArithmeticOperationsInAgrFuncOptimize.cpp
|
ArithmeticOperationsInAgrFuncOptimize.cpp
|
||||||
ArithmeticOperationsInAgrFuncOptimize.h
|
ArithmeticOperationsInAgrFuncOptimize.h
|
||||||
ArrayJoinAction.cpp
|
ArrayJoinAction.cpp
|
||||||
@ -127,6 +126,7 @@ SRCS(
|
|||||||
ReplaceQueryParameterVisitor.cpp
|
ReplaceQueryParameterVisitor.cpp
|
||||||
RequiredSourceColumnsData.cpp
|
RequiredSourceColumnsData.cpp
|
||||||
RequiredSourceColumnsVisitor.cpp
|
RequiredSourceColumnsVisitor.cpp
|
||||||
|
RewriteAnyFunctionVisitor.cpp
|
||||||
RowRefs.cpp
|
RowRefs.cpp
|
||||||
Set.cpp
|
Set.cpp
|
||||||
SetVariants.cpp
|
SetVariants.cpp
|
||||||
|
@ -1,3 +1,30 @@
|
|||||||
9
|
|
||||||
SELECT any(number) + (any(number) * 2)
|
SELECT any(number) + (any(number) * 2)
|
||||||
FROM numbers(3, 10)
|
FROM numbers(1, 2)
|
||||||
|
3
|
||||||
|
SELECT anyLast(number) + (anyLast(number) * 2)
|
||||||
|
FROM numbers(1, 2)
|
||||||
|
6
|
||||||
|
WITH any(number) * 3 AS x
|
||||||
|
SELECT x
|
||||||
|
FROM numbers(1, 2)
|
||||||
|
3
|
||||||
|
SELECT
|
||||||
|
anyLast(number) * 3 AS x,
|
||||||
|
x
|
||||||
|
FROM numbers(1, 2)
|
||||||
|
6 6
|
||||||
|
SELECT any(number + (number * 2))
|
||||||
|
FROM numbers(1, 2)
|
||||||
|
3
|
||||||
|
SELECT anyLast(number + (number * 2))
|
||||||
|
FROM numbers(1, 2)
|
||||||
|
6
|
||||||
|
WITH any(number * 3) AS x
|
||||||
|
SELECT x
|
||||||
|
FROM numbers(1, 2)
|
||||||
|
3
|
||||||
|
SELECT
|
||||||
|
anyLast(number * 3) AS x,
|
||||||
|
x
|
||||||
|
FROM numbers(1, 2)
|
||||||
|
6 6
|
||||||
|
@ -1,4 +1,32 @@
|
|||||||
SET optimize_move_functions_out_of_any=1;
|
|
||||||
SET enable_debug_queries = 1;
|
SET enable_debug_queries = 1;
|
||||||
SELECT any(number + number * 2) FROM numbers(3, 10);
|
SET optimize_move_functions_out_of_any = 1;
|
||||||
ANALYZE SELECT any(number + number * 2) FROM numbers(3, 10);
|
|
||||||
|
ANALYZE SELECT any(number + number * 2) FROM numbers(1, 2);
|
||||||
|
SELECT any(number + number * 2) FROM numbers(1, 2);
|
||||||
|
|
||||||
|
ANALYZE SELECT anyLast(number + number * 2) FROM numbers(1, 2);
|
||||||
|
SELECT anyLast(number + number * 2) FROM numbers(1, 2);
|
||||||
|
|
||||||
|
ANALYZE WITH any(number * 3) AS x SELECT x FROM numbers(1, 2);
|
||||||
|
WITH any(number * 3) AS x SELECT x FROM numbers(1, 2);
|
||||||
|
|
||||||
|
ANALYZE SELECT anyLast(number * 3) AS x, x FROM numbers(1, 2);
|
||||||
|
SELECT anyLast(number * 3) AS x, x FROM numbers(1, 2);
|
||||||
|
|
||||||
|
SELECT any(anyLast(number)) FROM numbers(1); -- { serverError 184 }
|
||||||
|
|
||||||
|
SET optimize_move_functions_out_of_any = 0;
|
||||||
|
|
||||||
|
ANALYZE SELECT any(number + number * 2) FROM numbers(1, 2);
|
||||||
|
SELECT any(number + number * 2) FROM numbers(1, 2);
|
||||||
|
|
||||||
|
ANALYZE SELECT anyLast(number + number * 2) FROM numbers(1, 2);
|
||||||
|
SELECT anyLast(number + number * 2) FROM numbers(1, 2);
|
||||||
|
|
||||||
|
ANALYZE WITH any(number * 3) AS x SELECT x FROM numbers(1, 2);
|
||||||
|
WITH any(number * 3) AS x SELECT x FROM numbers(1, 2);
|
||||||
|
|
||||||
|
ANALYZE SELECT anyLast(number * 3) AS x, x FROM numbers(1, 2);
|
||||||
|
SELECT anyLast(number * 3) AS x, x FROM numbers(1, 2);
|
||||||
|
|
||||||
|
SELECT any(anyLast(number)) FROM numbers(1); -- { serverError 184 }
|
||||||
|
@ -0,0 +1 @@
|
|||||||
|
0
|
18
tests/queries/0_stateless/01414_optimize_any_bug.sql
Normal file
18
tests/queries/0_stateless/01414_optimize_any_bug.sql
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
-- Runs any() over an expression that contains a lambda while
-- optimize_move_functions_out_of_any is enabled.
DROP TABLE IF EXISTS test;

CREATE TABLE test (`Source.C1` Array(UInt64), `Source.C2` Array(UInt64))
ENGINE = MergeTree()
ORDER BY tuple();

SET optimize_move_functions_out_of_any = 1;

-- The table is empty and the filter is always false; the query only has to be analyzed and executed.
SELECT any(arrayFilter((c, d) -> (4 = d), `Source.C1`, `Source.C2`)[1]) AS x
FROM test
WHERE 0
GROUP BY 42;

DROP TABLE test;
|
Loading…
Reference in New Issue
Block a user