ClickHouse/src/Analyzer/Passes/AggregateFunctionsArithmericOperationsPass.cpp

210 lines
9.1 KiB
C++

#include <Analyzer/Passes/AggregateFunctionsArithmericOperationsPass.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <Functions/FunctionFactory.h>
#include <Functions/IFunction.h>
#include <Analyzer/InDepthQueryTreeVisitor.h>
#include <Analyzer/ConstantNode.h>
#include <Analyzer/FunctionNode.h>
namespace DB
{
namespace ErrorCodes
{
extern const int BAD_TYPE_OF_FIELD;
}
namespace
{
/// Produces a zero-valued Field that stores the same type as `value`,
/// so that the two can be compared to determine the sign of `value`.
///
/// Handled field types: UInt64, Int64, Float64, UInt128, Int128, UInt256, Int256.
/// Any other field type results in a BAD_TYPE_OF_FIELD exception.
Field zeroField(const Field & value)
{
    const auto field_type = value.getType();

    switch (field_type)
    {
        case Field::Types::UInt64:
            return static_cast<UInt64>(0);
        case Field::Types::Int64:
            return static_cast<Int64>(0);
        case Field::Types::Float64:
            return static_cast<Float64>(0);
        case Field::Types::UInt128:
            return static_cast<UInt128>(0);
        case Field::Types::Int128:
            return static_cast<Int128>(0);
        case Field::Types::UInt256:
            return static_cast<UInt256>(0);
        case Field::Types::Int256:
            return static_cast<Int256>(0);
        default:
            /// No zero representation is defined here for the remaining field types.
            throw Exception(ErrorCodes::BAD_TYPE_OF_FIELD, "Unexpected literal type in function");
    }
}
/** Pulls an arithmetic operation with a constant operand out of an aggregate function:
  *
  *     sum([multiply|divide])                      -> [multiply|divide](sum)
  *     [min|max|avg]([multiply|divide|plus|minus]) -> [multiply|divide|plus|minus]([min|max|avg])
  *
  * `min` and `max` are swapped whenever the extracted operation reverses the ordering of values:
  *
  *     min(-1 * x) -> -1 * max(x)
  *     max(x / -2) -> min(x) / -2
  *     min(1 - x)  -> 1 - max(x)
  *
  * `constant / argument` is never rewritten.
  * The rewrite is applied only if the new expression has the same result type as the original aggregate function.
  *
  * TODO: Support `groupBitAnd`, `groupBitOr`, `groupBitXor` functions.
  * TODO: Support rewrite `f((2 * n) * n)` into '2 * f(n * n)'.
  */
class AggregateFunctionsArithmericOperationsVisitor : public InDepthQueryTreeVisitorWithContext<AggregateFunctionsArithmericOperationsVisitor>
{
public:
    using Base = InDepthQueryTreeVisitorWithContext<AggregateFunctionsArithmericOperationsVisitor>;
    using Base::Base;

    /// Traverse tree bottom to top
    static bool shouldTraverseTopToBottom()
    {
        return false;
    }

    /// Replaces `node` in place with the rewritten expression if `node` is a supported aggregate function
    /// whose single argument is a supported binary arithmetic function with at least one constant operand.
    /// Does nothing if the setting `optimize_arithmetic_operations_in_aggregate_functions` is disabled.
    void visitImpl(QueryTreeNodePtr & node)
    {
        if (!getSettings().optimize_arithmetic_operations_in_aggregate_functions)
            return;

        auto * aggregate_function_node = node->as<FunctionNode>();
        if (!aggregate_function_node || !aggregate_function_node->isAggregateFunction())
            return;

        /// Aggregate function name -> arithmetic functions that may be pulled out of it.
        static const std::unordered_map<std::string_view, std::unordered_set<std::string_view>> supported_aggregate_functions
            = {{"sum", {"multiply", "divide"}},
               {"min", {"multiply", "divide", "plus", "minus"}},
               {"max", {"multiply", "divide", "plus", "minus"}},
               {"avg", {"multiply", "divide", "plus", "minus"}}};

        auto & aggregate_function_arguments_nodes = aggregate_function_node->getArguments().getNodes();
        if (aggregate_function_arguments_nodes.size() != 1)
            return;

        const auto & arithmetic_function_node = aggregate_function_arguments_nodes[0];
        auto * arithmetic_function_node_typed = arithmetic_function_node->as<FunctionNode>();
        if (!arithmetic_function_node_typed)
            return;

        const auto & arithmetic_function_arguments_nodes = arithmetic_function_node_typed->getArguments().getNodes();
        if (arithmetic_function_arguments_nodes.size() != 2)
            return;

        /// Aggregate functions [sum|min|max|avg] are case-insensitive, so the lower-cased name is used
        auto lower_aggregate_function_name = Poco::toLower(aggregate_function_node->getFunctionName());

        auto supported_aggregate_function_it = supported_aggregate_functions.find(lower_aggregate_function_name);
        if (supported_aggregate_function_it == supported_aggregate_functions.end())
            return;

        const auto & arithmetic_function_name = arithmetic_function_node_typed->getFunctionName();
        if (!supported_aggregate_function_it->second.contains(arithmetic_function_name))
            return;

        const auto * left_argument_constant_node = arithmetic_function_arguments_nodes[0]->as<ConstantNode>();
        const auto * right_argument_constant_node = arithmetic_function_arguments_nodes[1]->as<ConstantNode>();

        if (!left_argument_constant_node && !right_argument_constant_node)
            return;

        /// Only `min` and `max` depend on the ordering of values, so only they may need to be swapped.
        /// The sign of the constant is inspected solely for them, which keeps `zeroField` (and its exception
        /// for unsupported literal types) out of the `sum` and `avg` paths.
        const bool is_min_or_max = lower_aggregate_function_name == "min" || lower_aggregate_function_name == "max";
        const bool is_multiply_or_divide = arithmetic_function_name == "multiply" || arithmetic_function_name == "divide";

        bool reverse_min_max = false;
        size_t arithmetic_function_argument_index = 0;

        if (left_argument_constant_node && !right_argument_constant_node)
        {
            /// `constant / argument` cannot be expressed through the aggregate of `argument`
            /// (for example `sum(1 / n)` is not `1 / sum(n)`), so it is left untouched.
            if (arithmetic_function_name == "divide")
                return;

            /// Rewrite `aggregate_function(inner_function(constant, argument))` into `inner_function(constant, aggregate_function(argument))`
            if (is_min_or_max)
            {
                if (arithmetic_function_name == "minus")
                {
                    /// `constant - argument` is decreasing in `argument` regardless of the sign of the constant:
                    /// min(1 - x) -> 1 - max(x)
                    reverse_min_max = true;
                }
                else if (is_multiply_or_divide)
                {
                    /// Only `multiply` reaches this point: min(-1 * x) -> -1 * max(x)
                    reverse_min_max = isNegativeConstant(*left_argument_constant_node);
                }
            }

            arithmetic_function_argument_index = 1;
        }
        else if (right_argument_constant_node)
        {
            /// Rewrite `aggregate_function(inner_function(argument, constant))` into `inner_function(aggregate_function(argument), constant)`
            /// `argument + constant` and `argument - constant` preserve ordering, so only a negative
            /// multiplier or divisor requires the swap: max(x / -2) -> min(x) / -2
            if (is_min_or_max && is_multiply_or_divide)
                reverse_min_max = isNegativeConstant(*right_argument_constant_node);

            arithmetic_function_argument_index = 0;
        }

        /// `reverse_min_max` can be set only when the aggregate function is `min` or `max`.
        if (reverse_min_max)
            lower_aggregate_function_name = (lower_aggregate_function_name == "min") ? "max" : "min";

        auto optimized_function_node = cloneArithmeticFunctionAndWrapArgumentIntoAggregateFunction(arithmetic_function_node,
            arithmetic_function_argument_index,
            node,
            lower_aggregate_function_name);

        /// Keep the original node if the rewrite would change the result type of the expression.
        if (optimized_function_node->getResultType()->equals(*node->getResultType()))
            node = std::move(optimized_function_node);
    }

private:
    /// Returns true if the value of the constant is strictly less than zero of the same field type.
    /// Propagates the BAD_TYPE_OF_FIELD exception thrown by `zeroField` for field types it does not handle.
    static bool isNegativeConstant(const ConstantNode & constant_node)
    {
        const auto & constant_value = constant_node.getValue();
        return constant_value < zeroField(constant_value);
    }

    /** Builds `arithmetic_function(..., result_aggregate_function_name(argument), ...)` from clones of the given nodes.
      *
      * @param arithmetic_function arithmetic function node that is currently the argument of the aggregate function.
      * @param arithmetic_function_argument_index index of the arithmetic function argument that must be wrapped
      *        into the aggregate function.
      * @param aggregate_function original aggregate function node.
      * @param result_aggregate_function_name name of the aggregate function to resolve the wrapped node with.
      * @return the cloned arithmetic function node, re-resolved after one of its arguments was replaced.
      */
    QueryTreeNodePtr cloneArithmeticFunctionAndWrapArgumentIntoAggregateFunction(
        const QueryTreeNodePtr & arithmetic_function,
        size_t arithmetic_function_argument_index,
        const QueryTreeNodePtr & aggregate_function,
        const std::string & result_aggregate_function_name)
    {
        auto arithmetic_function_clone = arithmetic_function->clone();
        auto & arithmetic_function_clone_typed = arithmetic_function_clone->as<FunctionNode &>();
        auto & arithmetic_function_clone_arguments_nodes = arithmetic_function_clone_typed.getArguments().getNodes();
        /// Reference to the argument slot; the slot is overwritten below, after the aggregate function took a copy of the pointer.
        auto & arithmetic_function_clone_argument = arithmetic_function_clone_arguments_nodes[arithmetic_function_argument_index];

        auto aggregate_function_clone = aggregate_function->clone();
        auto & aggregate_function_clone_typed = aggregate_function_clone->as<FunctionNode &>();

        aggregate_function_clone_typed.getArguments().getNodes() = { arithmetic_function_clone_argument };
        resolveAggregateFunctionNode(aggregate_function_clone_typed, arithmetic_function_clone_argument, result_aggregate_function_name);

        arithmetic_function_clone_arguments_nodes[arithmetic_function_argument_index] = std::move(aggregate_function_clone);
        resolveOrdinaryFunctionNode(arithmetic_function_clone_typed, arithmetic_function_clone_typed.getFunctionName());

        return arithmetic_function_clone;
    }

    /// Resolves `function_node` as the ordinary function `function_name` built for the node's current argument columns.
    void resolveOrdinaryFunctionNode(FunctionNode & function_node, const String & function_name) const
    {
        auto function = FunctionFactory::instance().get(function_name, getContext());
        function_node.resolveAsFunction(function->build(function_node.getArgumentColumns()));
    }

    /// Resolves `function_node` as the aggregate function `aggregate_function_name` over the result type of `argument`,
    /// reusing the parameters of the aggregate function the node is currently resolved to.
    static void resolveAggregateFunctionNode(FunctionNode & function_node, const QueryTreeNodePtr & argument, const String & aggregate_function_name)
    {
        auto function_aggregate_function = function_node.getAggregateFunction();

        AggregateFunctionProperties properties;
        auto aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name,
            { argument->getResultType() },
            function_aggregate_function->getParameters(),
            properties);

        function_node.resolveAsAggregateFunction(std::move(aggregate_function));
    }
};
}
/// Entry point of the pass: runs AggregateFunctionsArithmericOperationsVisitor,
/// constructed from `context`, over the query tree rooted at `query_tree_node`.
void AggregateFunctionsArithmericOperationsPass::run(QueryTreeNodePtr query_tree_node, ContextPtr context)
{
    AggregateFunctionsArithmericOperationsVisitor arithmetic_operations_visitor(std::move(context));

    arithmetic_operations_visitor.visit(query_tree_node);
}
}