diff --git a/src/Analyzer/Passes/ComparisonTupleEliminationPass.cpp b/src/Analyzer/Passes/ComparisonTupleEliminationPass.cpp index 117e649ac88..7c38ba81c70 100644 --- a/src/Analyzer/Passes/ComparisonTupleEliminationPass.cpp +++ b/src/Analyzer/Passes/ComparisonTupleEliminationPass.cpp @@ -64,39 +64,43 @@ public: auto lhs_argument_node_type = lhs_argument->getNodeType(); auto rhs_argument_node_type = rhs_argument->getNodeType(); + QueryTreeNodePtr candidate; + if (lhs_argument_node_type == QueryTreeNodeType::FUNCTION && rhs_argument_node_type == QueryTreeNodeType::FUNCTION) - tryOptimizeComparisonTupleFunctions(node, lhs_argument, rhs_argument, comparison_function_name); + candidate = tryOptimizeComparisonTupleFunctions(lhs_argument, rhs_argument, comparison_function_name); else if (lhs_argument_node_type == QueryTreeNodeType::FUNCTION && rhs_argument_node_type == QueryTreeNodeType::CONSTANT) - tryOptimizeComparisonTupleFunctionAndConstant(node, lhs_argument, rhs_argument, comparison_function_name); + candidate = tryOptimizeComparisonTupleFunctionAndConstant(lhs_argument, rhs_argument, comparison_function_name); else if (lhs_argument_node_type == QueryTreeNodeType::CONSTANT && rhs_argument_node_type == QueryTreeNodeType::FUNCTION) - tryOptimizeComparisonTupleFunctionAndConstant(node, rhs_argument, lhs_argument, comparison_function_name); + candidate = tryOptimizeComparisonTupleFunctionAndConstant(rhs_argument, lhs_argument, comparison_function_name); + + if (candidate != nullptr && node->getResultType()->equals(*candidate->getResultType())) + node = candidate; } private: - void tryOptimizeComparisonTupleFunctions(QueryTreeNodePtr & node, + QueryTreeNodePtr tryOptimizeComparisonTupleFunctions( const QueryTreeNodePtr & lhs_function_node, const QueryTreeNodePtr & rhs_function_node, const std::string & comparison_function_name) const { const auto & lhs_function_node_typed = lhs_function_node->as(); if (lhs_function_node_typed.getFunctionName() != "tuple") - return; + return {}; const auto & rhs_function_node_typed = rhs_function_node->as(); if (rhs_function_node_typed.getFunctionName() != "tuple") - return; + return {}; const auto & lhs_tuple_function_arguments_nodes = lhs_function_node_typed.getArguments().getNodes(); size_t lhs_tuple_function_arguments_nodes_size = lhs_tuple_function_arguments_nodes.size(); const auto & rhs_tuple_function_arguments_nodes = rhs_function_node_typed.getArguments().getNodes(); if (lhs_tuple_function_arguments_nodes_size != rhs_tuple_function_arguments_nodes.size()) - return; + return {}; if (lhs_tuple_function_arguments_nodes_size == 1) { - node = makeComparisonFunction(lhs_tuple_function_arguments_nodes[0], rhs_tuple_function_arguments_nodes[0], comparison_function_name); - return; + return makeComparisonFunction(lhs_tuple_function_arguments_nodes[0], rhs_tuple_function_arguments_nodes[0], comparison_function_name); } QueryTreeNodes tuple_arguments_equals_functions; @@ -108,45 +112,44 @@ private: tuple_arguments_equals_functions.push_back(std::move(equals_function)); } - node = makeEquivalentTupleComparisonFunction(std::move(tuple_arguments_equals_functions), comparison_function_name); + return makeEquivalentTupleComparisonFunction(std::move(tuple_arguments_equals_functions), comparison_function_name); } - void tryOptimizeComparisonTupleFunctionAndConstant(QueryTreeNodePtr & node, + QueryTreeNodePtr tryOptimizeComparisonTupleFunctionAndConstant( const QueryTreeNodePtr & function_node, const QueryTreeNodePtr & constant_node, const std::string & comparison_function_name) const { const auto & function_node_typed = function_node->as(); if (function_node_typed.getFunctionName() != "tuple") - return; + return {}; auto & constant_node_typed = constant_node->as(); const auto & constant_node_value = constant_node_typed.getValue(); if (constant_node_value.getType() != Field::Types::Which::Tuple) - return; + return {}; const auto & constant_tuple = constant_node_value.get(); const auto & function_arguments_nodes = function_node_typed.getArguments().getNodes(); size_t function_arguments_nodes_size = function_arguments_nodes.size(); if (function_arguments_nodes_size != constant_tuple.size()) - return; + return {}; auto constant_node_result_type = constant_node_typed.getResultType(); const auto * tuple_data_type = typeid_cast(constant_node_result_type.get()); if (!tuple_data_type) - return; + return {}; const auto & tuple_data_type_elements = tuple_data_type->getElements(); if (tuple_data_type_elements.size() != function_arguments_nodes_size) - return; + return {}; if (function_arguments_nodes_size == 1) { auto comparison_argument_constant_value = std::make_shared(constant_tuple[0], tuple_data_type_elements[0]); auto comparison_argument_constant_node = std::make_shared(std::move(comparison_argument_constant_value)); - node = makeComparisonFunction(function_arguments_nodes[0], std::move(comparison_argument_constant_node), comparison_function_name); - return; + return makeComparisonFunction(function_arguments_nodes[0], std::move(comparison_argument_constant_node), comparison_function_name); } QueryTreeNodes tuple_arguments_equals_functions; @@ -160,7 +163,7 @@ private: tuple_arguments_equals_functions.push_back(std::move(equals_function)); } - node = makeEquivalentTupleComparisonFunction(std::move(tuple_arguments_equals_functions), comparison_function_name); + return makeEquivalentTupleComparisonFunction(std::move(tuple_arguments_equals_functions), comparison_function_name); } QueryTreeNodePtr makeEquivalentTupleComparisonFunction(QueryTreeNodes tuple_arguments_equals_functions,