diff --git a/src/Interpreters/AnyInputOptimize.cpp b/src/Interpreters/AnyInputOptimize.cpp deleted file mode 100644 index 1b31ea4024b..00000000000 --- a/src/Interpreters/AnyInputOptimize.cpp +++ /dev/null @@ -1,98 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include - -namespace DB -{ - -namespace ErrorCodes -{ - extern const int LOGICAL_ERROR; - extern const int ILLEGAL_AGGREGATION; -} - -namespace -{ - constexpr auto * any = "any"; - constexpr auto * anyLast = "anyLast"; -} - -ASTPtr * getExactChild(const ASTPtr & ast, const size_t ind) -{ - if (ast && ast->as()->arguments->children[ind]) - return &ast->as()->arguments->children[ind]; - return nullptr; -} - -///recursive searching of identifiers -void changeAllIdentifiers(ASTPtr & ast, size_t ind, const std::string & name) -{ - ASTPtr * exact_child = getExactChild(ast, ind); - if (!exact_child) - return; - - if ((*exact_child)->as()) - { - ///put new any - ASTPtr old_ast = *exact_child; - *exact_child = makeASTFunction(name); - (*exact_child)->as()->arguments->children.push_back(old_ast); - } - else if ((*exact_child)->as()) - { - if (AggregateFunctionFactory::instance().isAggregateFunctionName((*exact_child)->as()->name)) - throw Exception("Aggregate function " + (*exact_child)->as()->name + - " is found inside aggregate function " + name + " in query", ErrorCodes::ILLEGAL_AGGREGATION); - - for (size_t i = 0; i < (*exact_child)->as()->arguments->children.size(); i++) - changeAllIdentifiers(*exact_child, i, name); - } -} - - -///cut old any, put any to identifiers. any(functions(x)) -> functions(any(x)) -void AnyInputMatcher::visit(ASTPtr & current_ast, Data data) -{ - data = {}; - if (!current_ast) - return; - - auto * function_node = current_ast->as(); - if (!function_node || function_node->arguments->children.empty()) - return; - - const auto & function_argument = function_node->arguments->children[0]; - if ((function_node->name == any || function_node->name == anyLast) - && function_argument && function_argument->as()) - { - auto name = function_node->name; - auto alias = function_node->alias; - - ///cut any or anyLast - if (!function_argument->as()->arguments->children.empty()) - { - current_ast = function_argument->clone(); - current_ast->setAlias(alias); - for (size_t i = 0; i < current_ast->as()->arguments->children.size(); ++i) - changeAllIdentifiers(current_ast, i, name); - } - } -} - -bool AnyInputMatcher::needChildVisit(const ASTPtr & node, const ASTPtr & child) -{ - if (!child) - throw Exception("AST item should not have nullptr in children", ErrorCodes::LOGICAL_ERROR); - - if (node->as() || node->as()) - return false; // NOLINT - - return true; -} - -} diff --git a/src/Interpreters/AnyInputOptimize.h b/src/Interpreters/AnyInputOptimize.h deleted file mode 100644 index 6e782578e35..00000000000 --- a/src/Interpreters/AnyInputOptimize.h +++ /dev/null @@ -1,19 +0,0 @@ -#pragma once - -#include -#include - -namespace DB -{ - -///This optimiser is similar to ArithmeticOperationsInAgrFunc optimizer, but for function any we can extract any functions. -class AnyInputMatcher -{ -public: - struct Data {}; - - static void visit(ASTPtr & ast, Data data); - static bool needChildVisit(const ASTPtr & node, const ASTPtr & child); -}; -using AnyInputVisitor = InDepthNodeVisitor; -} diff --git a/src/Interpreters/ArithmeticOperationsInAgrFuncOptimize.cpp b/src/Interpreters/ArithmeticOperationsInAgrFuncOptimize.cpp index c67986d97c0..3ee7b59197a 100644 --- a/src/Interpreters/ArithmeticOperationsInAgrFuncOptimize.cpp +++ b/src/Interpreters/ArithmeticOperationsInAgrFuncOptimize.cpp @@ -153,7 +153,9 @@ void ArithmeticOperationsInAgrFuncMatcher::visit(ASTPtr & ast, Data & data) bool ArithmeticOperationsInAgrFuncMatcher::needChildVisit(const ASTPtr & node, const ASTPtr &) { - return !node->as() && !node->as(); + return !node->as() && + !node->as() && + !node->as(); } } diff --git a/src/Interpreters/RewriteAnyFunctionVisitor.cpp b/src/Interpreters/RewriteAnyFunctionVisitor.cpp new file mode 100644 index 00000000000..7d5d204499e --- /dev/null +++ b/src/Interpreters/RewriteAnyFunctionVisitor.cpp @@ -0,0 +1,96 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ + +namespace +{ + +bool extractIdentifiers(const ASTFunction & func, std::vector & identifiers) +{ + for (auto & arg : func.arguments->children) + { + if (const auto * arg_func = arg->as()) + { + if (arg_func->name == "lambda") + return false; + + if (AggregateFunctionFactory::instance().isAggregateFunctionName(arg_func->name)) + return false; + + if (!extractIdentifiers(*arg_func, identifiers)) + return false; + } + else if (arg->as()) + identifiers.emplace_back(&arg); + } + + return true; +} + +} + + +void RewriteAnyFunctionMatcher::visit(ASTPtr & ast, Data & data) +{ + if (auto * func = ast->as()) + visit(*func, ast, data); +} + +void RewriteAnyFunctionMatcher::visit(const ASTFunction & func, ASTPtr & ast, Data & data) +{ + if (func.arguments->children.empty() || !func.arguments->children[0]) + return; + + if (func.name != "any" && func.name != "anyLast") + return; + + auto & func_arguments = func.arguments->children; + + const auto * first_arg_func = func_arguments[0]->as(); + if (!first_arg_func || first_arg_func->arguments->children.empty()) + return; + + /// We have rewritten this function. Just unwrap its argument. + if (data.rewritten.count(ast.get())) + { + func_arguments[0]->setAlias(func.alias); + ast = func_arguments[0]; + return; + } + + std::vector identifiers; + if (!extractIdentifiers(func, identifiers)) + return; + + /// Wrap identifiers: any(f(x, y, g(z))) -> any(f(any(x), any(y), g(any(z)))) + for (auto * ast_to_change : identifiers) + { + ASTPtr identifier_ast = *ast_to_change; + *ast_to_change = makeASTFunction(func.name); + (*ast_to_change)->as()->arguments->children.emplace_back(identifier_ast); + } + + data.rewritten.insert(ast.get()); + + /// Unwrap function: any(f(any(x), any(y), g(any(z)))) -> f(any(x), any(y), g(any(z))) + func_arguments[0]->setAlias(func.alias); + ast = func_arguments[0]; +} + +bool RewriteAnyFunctionMatcher::needChildVisit(const ASTPtr & node, const ASTPtr &) +{ + return !node->as() && + !node->as() && + !node->as(); +} + +} diff --git a/src/Interpreters/RewriteAnyFunctionVisitor.h b/src/Interpreters/RewriteAnyFunctionVisitor.h new file mode 100644 index 00000000000..d29af322711 --- /dev/null +++ b/src/Interpreters/RewriteAnyFunctionVisitor.h @@ -0,0 +1,29 @@ +#pragma once + +#include + +#include +#include + +namespace DB +{ + +class ASTFunction; + +/// Rewrite 'any' and 'anyLast' functions pushing them inside original function. +/// any(f(x, y, g(z))) -> f(any(x), any(y), g(any(z))) +class RewriteAnyFunctionMatcher +{ +public: + struct Data + { + std::unordered_set rewritten; + }; + + static void visit(ASTPtr & ast, Data & data); + static void visit(const ASTFunction &, ASTPtr & ast, Data & data); + static bool needChildVisit(const ASTPtr & node, const ASTPtr & child); +}; +using RewriteAnyFunctionVisitor = InDepthNodeVisitor; + +} diff --git a/src/Interpreters/TreeOptimizer.cpp b/src/Interpreters/TreeOptimizer.cpp index 97922aba3f0..99eaf6e6736 100644 --- a/src/Interpreters/TreeOptimizer.cpp +++ b/src/Interpreters/TreeOptimizer.cpp @@ -8,7 +8,7 @@ #include #include #include -#include +#include #include #include #include @@ -458,11 +458,10 @@ void optimizeAggregationFunctions(ASTPtr & query) ArithmeticOperationsInAgrFuncVisitor(data).visit(query); } -void optimizeAnyInput(ASTPtr & query) +void optimizeAnyFunctions(ASTPtr & query) { - /// Removing arithmetic operations from functions - AnyInputVisitor::Data data = {}; - AnyInputVisitor(data).visit(query); + RewriteAnyFunctionVisitor::Data data = {}; + RewriteAnyFunctionVisitor(data).visit(query); } void optimizeInjectiveFunctionsInsideUniq(ASTPtr & query, const Context & context) @@ -520,9 +519,9 @@ void TreeOptimizer::apply(ASTPtr & query, Aliases & aliases, const NameSet & sou if (settings.optimize_group_by_function_keys) optimizeGroupByFunctionKeys(select_query); - ///Move all operations out of any function + /// Move all operations out of any function if (settings.optimize_move_functions_out_of_any) - optimizeAnyInput(query); + optimizeAnyFunctions(query); /// Remove injective functions inside uniq if (settings.optimize_injective_functions_inside_uniq) diff --git a/src/Interpreters/ya.make b/src/Interpreters/ya.make index 48be2fae1b9..c571b857365 100644 --- a/src/Interpreters/ya.make +++ b/src/Interpreters/ya.make @@ -20,7 +20,6 @@ SRCS( addTypeConversionToAST.cpp AggregateDescription.cpp Aggregator.cpp - AnyInputOptimize.cpp ArithmeticOperationsInAgrFuncOptimize.cpp ArithmeticOperationsInAgrFuncOptimize.h ArrayJoinAction.cpp @@ -127,6 +126,7 @@ SRCS( ReplaceQueryParameterVisitor.cpp RequiredSourceColumnsData.cpp RequiredSourceColumnsVisitor.cpp + RewriteAnyFunctionVisitor.cpp RowRefs.cpp Set.cpp SetVariants.cpp diff --git a/tests/queries/0_stateless/01322_any_input_optimize.reference b/tests/queries/0_stateless/01322_any_input_optimize.reference index 209a4afb3a2..4b6fbfd32c0 100644 --- a/tests/queries/0_stateless/01322_any_input_optimize.reference +++ b/tests/queries/0_stateless/01322_any_input_optimize.reference @@ -1,3 +1,30 @@ -9 SELECT any(number) + (any(number) * 2) -FROM numbers(3, 10) +FROM numbers(1, 2) +3 +SELECT anyLast(number) + (anyLast(number) * 2) +FROM numbers(1, 2) +6 +WITH any(number) * 3 AS x +SELECT x +FROM numbers(1, 2) +3 +SELECT + anyLast(number) * 3 AS x, + x +FROM numbers(1, 2) +6 6 +SELECT any(number + (number * 2)) +FROM numbers(1, 2) +3 +SELECT anyLast(number + (number * 2)) +FROM numbers(1, 2) +6 +WITH any(number * 3) AS x +SELECT x +FROM numbers(1, 2) +3 +SELECT + anyLast(number * 3) AS x, + x +FROM numbers(1, 2) +6 6 diff --git a/tests/queries/0_stateless/01322_any_input_optimize.sql b/tests/queries/0_stateless/01322_any_input_optimize.sql index 65f09d65738..4b8a55d4c7b 100644 --- a/tests/queries/0_stateless/01322_any_input_optimize.sql +++ b/tests/queries/0_stateless/01322_any_input_optimize.sql @@ -1,4 +1,32 @@ -SET optimize_move_functions_out_of_any=1; -SET enable_debug_queries=1; -SELECT any(number + number * 2) FROM numbers(3, 10); -ANALYZE SELECT any(number + number * 2) FROM numbers(3, 10); +SET enable_debug_queries = 1; +SET optimize_move_functions_out_of_any = 1; + +ANALYZE SELECT any(number + number * 2) FROM numbers(1, 2); +SELECT any(number + number * 2) FROM numbers(1, 2); + +ANALYZE SELECT anyLast(number + number * 2) FROM numbers(1, 2); +SELECT anyLast(number + number * 2) FROM numbers(1, 2); + +ANALYZE WITH any(number * 3) AS x SELECT x FROM numbers(1, 2); +WITH any(number * 3) AS x SELECT x FROM numbers(1, 2); + +ANALYZE SELECT anyLast(number * 3) AS x, x FROM numbers(1, 2); +SELECT anyLast(number * 3) AS x, x FROM numbers(1, 2); + +SELECT any(anyLast(number)) FROM numbers(1); -- { serverError 184 } + +SET optimize_move_functions_out_of_any = 0; + +ANALYZE SELECT any(number + number * 2) FROM numbers(1, 2); +SELECT any(number + number * 2) FROM numbers(1, 2); + +ANALYZE SELECT anyLast(number + number * 2) FROM numbers(1, 2); +SELECT anyLast(number + number * 2) FROM numbers(1, 2); + +ANALYZE WITH any(number * 3) AS x SELECT x FROM numbers(1, 2); +WITH any(number * 3) AS x SELECT x FROM numbers(1, 2); + +ANALYZE SELECT anyLast(number * 3) AS x, x FROM numbers(1, 2); +SELECT anyLast(number * 3) AS x, x FROM numbers(1, 2); + +SELECT any(anyLast(number)) FROM numbers(1); -- { serverError 184 } diff --git a/tests/queries/0_stateless/01414_optimize_any_bug.reference b/tests/queries/0_stateless/01414_optimize_any_bug.reference new file mode 100644 index 00000000000..573541ac970 --- /dev/null +++ b/tests/queries/0_stateless/01414_optimize_any_bug.reference @@ -0,0 +1 @@ +0 diff --git a/tests/queries/0_stateless/01414_optimize_any_bug.sql b/tests/queries/0_stateless/01414_optimize_any_bug.sql new file mode 100644 index 00000000000..6f6f291c504 --- /dev/null +++ b/tests/queries/0_stateless/01414_optimize_any_bug.sql @@ -0,0 +1,18 @@ +DROP TABLE IF EXISTS test; + +CREATE TABLE test +( + `Source.C1` Array(UInt64), + `Source.C2` Array(UInt64) +) +ENGINE = MergeTree() +ORDER BY tuple(); + +SET optimize_move_functions_out_of_any = 1; + +SELECT any(arrayFilter((c, d) -> (4 = d), `Source.C1`, `Source.C2`)[1]) AS x +FROM test +WHERE 0 +GROUP BY 42; + +DROP TABLE test;