#include #include #include #include #include #include #include #include namespace DB { namespace ErrorCodes { extern const int BAD_TYPE_OF_FIELD; } namespace { const ASTFunction * getInternalFunction(const ASTFunction & func) { if (func.arguments && func.arguments->children.size() == 1) return func.arguments->children[0]->as(); return nullptr; } ASTPtr exchangeExtractFirstArgument(const String & func_name, const ASTFunction & child_func) { ASTs new_child_args; new_child_args.push_back(child_func.arguments->children[1]); auto new_child = makeASTFunction(func_name, new_child_args); ASTs new_args; new_args.push_back(child_func.arguments->children[0]); new_args.push_back(new_child); return makeASTFunction(child_func.name, new_args); } ASTPtr exchangeExtractSecondArgument(const String & func_name, const ASTFunction & child_func) { ASTs new_child_args; new_child_args.push_back(child_func.arguments->children[0]); auto new_child = makeASTFunction(func_name, new_child_args); ASTs new_args; new_args.push_back(new_child); new_args.push_back(child_func.arguments->children[1]); return makeASTFunction(child_func.name, new_args); } Field zeroField(const Field & value) { switch (value.getType()) { case Field::Types::UInt64: return static_cast(0); case Field::Types::Int64: return static_cast(0); case Field::Types::Float64: return static_cast(0); case Field::Types::UInt128: return static_cast(0); case Field::Types::Int128: return static_cast(0); case Field::Types::UInt256: return static_cast(0); case Field::Types::Int256: return static_cast(0); default: break; } throw Exception(ErrorCodes::BAD_TYPE_OF_FIELD, "Unexpected literal type in function"); } ASTPtr tryExchangeFunctions(const ASTFunction & func) { static const std::unordered_map> supported = {{"sum", {"multiply", "divide"}}, {"min", {"multiply", "divide", "plus", "minus"}}, {"max", {"multiply", "divide", "plus", "minus"}}, {"avg", {"multiply", "divide", "plus", "minus"}}}; /// Aggregate functions[sum|min|max|avg] is case-insensitive, so we use lower cases name auto lower_name = Poco::toLower(func.name); const ASTFunction * child_func = getInternalFunction(func); if (!child_func || !child_func->arguments || child_func->arguments->children.size() != 2 || !supported.contains(lower_name) || !supported.find(lower_name)->second.contains(child_func->name)) return {}; auto original_alias = func.tryGetAlias(); const auto & child_func_args = child_func->arguments->children; const auto * first_literal = child_func_args[0]->as(); const auto * second_literal = child_func_args[1]->as(); ASTPtr optimized_ast; /** Need reverse max <-> min for: * * max(-1*value) -> -1*min(value) * max(value/-2) -> min(value)/-2 * max(1-value) -> 1-min(value) */ auto get_reverse_aggregate_function_name = [](const std::string & aggregate_function_name) -> std::string { if (aggregate_function_name == "min") return "max"; else if (aggregate_function_name == "max") return "min"; else return aggregate_function_name; }; if (first_literal && !second_literal) { /// It's possible to rewrite 'sum(1/n)' with 'sum(1) * div(1/n)' but we lose accuracy. Ignored. if (child_func->name == "divide") return {}; bool need_reverse = (child_func->name == "multiply" && first_literal->value < zeroField(first_literal->value)) || child_func->name == "minus"; if (need_reverse) lower_name = get_reverse_aggregate_function_name(lower_name); optimized_ast = exchangeExtractFirstArgument(lower_name, *child_func); } else if (second_literal) /// second or both are consts { bool need_reverse = (child_func->name == "multiply" || child_func->name == "divide") && second_literal->value < zeroField(second_literal->value); if (need_reverse) lower_name = get_reverse_aggregate_function_name(lower_name); optimized_ast = exchangeExtractSecondArgument(lower_name, *child_func); } if (optimized_ast) { optimized_ast->setAlias(original_alias); return optimized_ast; } return {}; } } void ArithmeticOperationsInAgrFuncMatcher::visit(const ASTFunction & func, ASTPtr & ast, Data & data) { if (auto exchanged_funcs = tryExchangeFunctions(func)) { ast = exchanged_funcs; /// Main visitor is bottom-up. This is top-down part. /// We've found an aggregate function an now move it down through others: sum(mul(mul)) -> mul(mul(sum)). /// It's not dangerous cause main visitor already has visited this part of tree. auto & expression_list = ast->children[0]; visit(expression_list->children[0], data); } } void ArithmeticOperationsInAgrFuncMatcher::visit(ASTPtr & ast, Data & data) { if (const auto * function_node = ast->as()) { if (function_node->is_window_function) return; visit(*function_node, ast, data); } } bool ArithmeticOperationsInAgrFuncMatcher::needChildVisit(const ASTPtr & node, const ASTPtr &) { return !node->as() && !node->as() && !node->as(); } }