Apply optimization only in case of type match

This commit is contained in:
Dmitry Novik 2024-01-03 16:18:00 +00:00
parent 1c541e3b00
commit 558d103f96

View File

@ -64,39 +64,43 @@ public:
auto lhs_argument_node_type = lhs_argument->getNodeType();
auto rhs_argument_node_type = rhs_argument->getNodeType();
QueryTreeNodePtr candidate;
if (lhs_argument_node_type == QueryTreeNodeType::FUNCTION && rhs_argument_node_type == QueryTreeNodeType::FUNCTION)
tryOptimizeComparisonTupleFunctions(node, lhs_argument, rhs_argument, comparison_function_name);
candidate = tryOptimizeComparisonTupleFunctions(lhs_argument, rhs_argument, comparison_function_name);
else if (lhs_argument_node_type == QueryTreeNodeType::FUNCTION && rhs_argument_node_type == QueryTreeNodeType::CONSTANT)
tryOptimizeComparisonTupleFunctionAndConstant(node, lhs_argument, rhs_argument, comparison_function_name);
candidate = tryOptimizeComparisonTupleFunctionAndConstant(lhs_argument, rhs_argument, comparison_function_name);
else if (lhs_argument_node_type == QueryTreeNodeType::CONSTANT && rhs_argument_node_type == QueryTreeNodeType::FUNCTION)
tryOptimizeComparisonTupleFunctionAndConstant(node, rhs_argument, lhs_argument, comparison_function_name);
candidate = tryOptimizeComparisonTupleFunctionAndConstant(rhs_argument, lhs_argument, comparison_function_name);
if (candidate != nullptr && node->getResultType()->equals(*candidate->getResultType()))
node = candidate;
}
private:
void tryOptimizeComparisonTupleFunctions(QueryTreeNodePtr & node,
QueryTreeNodePtr tryOptimizeComparisonTupleFunctions(
const QueryTreeNodePtr & lhs_function_node,
const QueryTreeNodePtr & rhs_function_node,
const std::string & comparison_function_name) const
{
const auto & lhs_function_node_typed = lhs_function_node->as<FunctionNode &>();
if (lhs_function_node_typed.getFunctionName() != "tuple")
return;
return {};
const auto & rhs_function_node_typed = rhs_function_node->as<FunctionNode &>();
if (rhs_function_node_typed.getFunctionName() != "tuple")
return;
return {};
const auto & lhs_tuple_function_arguments_nodes = lhs_function_node_typed.getArguments().getNodes();
size_t lhs_tuple_function_arguments_nodes_size = lhs_tuple_function_arguments_nodes.size();
const auto & rhs_tuple_function_arguments_nodes = rhs_function_node_typed.getArguments().getNodes();
if (lhs_tuple_function_arguments_nodes_size != rhs_tuple_function_arguments_nodes.size())
return;
return {};
if (lhs_tuple_function_arguments_nodes_size == 1)
{
node = makeComparisonFunction(lhs_tuple_function_arguments_nodes[0], rhs_tuple_function_arguments_nodes[0], comparison_function_name);
return;
return makeComparisonFunction(lhs_tuple_function_arguments_nodes[0], rhs_tuple_function_arguments_nodes[0], comparison_function_name);
}
QueryTreeNodes tuple_arguments_equals_functions;
@ -108,45 +112,44 @@ private:
tuple_arguments_equals_functions.push_back(std::move(equals_function));
}
node = makeEquivalentTupleComparisonFunction(std::move(tuple_arguments_equals_functions), comparison_function_name);
return makeEquivalentTupleComparisonFunction(std::move(tuple_arguments_equals_functions), comparison_function_name);
}
void tryOptimizeComparisonTupleFunctionAndConstant(QueryTreeNodePtr & node,
QueryTreeNodePtr tryOptimizeComparisonTupleFunctionAndConstant(
const QueryTreeNodePtr & function_node,
const QueryTreeNodePtr & constant_node,
const std::string & comparison_function_name) const
{
const auto & function_node_typed = function_node->as<FunctionNode &>();
if (function_node_typed.getFunctionName() != "tuple")
return;
return {};
auto & constant_node_typed = constant_node->as<ConstantNode &>();
const auto & constant_node_value = constant_node_typed.getValue();
if (constant_node_value.getType() != Field::Types::Which::Tuple)
return;
return {};
const auto & constant_tuple = constant_node_value.get<const Tuple &>();
const auto & function_arguments_nodes = function_node_typed.getArguments().getNodes();
size_t function_arguments_nodes_size = function_arguments_nodes.size();
if (function_arguments_nodes_size != constant_tuple.size())
return;
return {};
auto constant_node_result_type = constant_node_typed.getResultType();
const auto * tuple_data_type = typeid_cast<const DataTypeTuple *>(constant_node_result_type.get());
if (!tuple_data_type)
return;
return {};
const auto & tuple_data_type_elements = tuple_data_type->getElements();
if (tuple_data_type_elements.size() != function_arguments_nodes_size)
return;
return {};
if (function_arguments_nodes_size == 1)
{
auto comparison_argument_constant_value = std::make_shared<ConstantValue>(constant_tuple[0], tuple_data_type_elements[0]);
auto comparison_argument_constant_node = std::make_shared<ConstantNode>(std::move(comparison_argument_constant_value));
node = makeComparisonFunction(function_arguments_nodes[0], std::move(comparison_argument_constant_node), comparison_function_name);
return;
return makeComparisonFunction(function_arguments_nodes[0], std::move(comparison_argument_constant_node), comparison_function_name);
}
QueryTreeNodes tuple_arguments_equals_functions;
@ -160,7 +163,7 @@ private:
tuple_arguments_equals_functions.push_back(std::move(equals_function));
}
node = makeEquivalentTupleComparisonFunction(std::move(tuple_arguments_equals_functions), comparison_function_name);
return makeEquivalentTupleComparisonFunction(std::move(tuple_arguments_equals_functions), comparison_function_name);
}
QueryTreeNodePtr makeEquivalentTupleComparisonFunction(QueryTreeNodes tuple_arguments_equals_functions,