better code in FuseFunctionsPass.cpp

This commit is contained in:
vdimir 2022-11-04 11:27:46 +00:00
parent 97e7c505ad
commit 7daf5200f0
No known key found for this signature in database
GPG Key ID: 6EE4CE2BEDC51862

View File

@@ -18,7 +18,7 @@ namespace DB
namespace
{
class FuseFunctionsMatcher : public InDepthQueryTreeVisitor<FuseFunctionsMatcher>
class FuseFunctionsVisitor : public InDepthQueryTreeVisitor<FuseFunctionsVisitor>
{
public:
@@ -33,101 +33,117 @@ public:
if (!function_node || !function_node->isAggregateFunction() || !matchFunctionName(function_node->getFunctionName()))
return;
auto argument_hash = function_node->getArgumentsNode()->getTreeHash();
mapping[argument_hash].push_back(&node);
const auto & arguments = function_node->getArgumentsNode()->getChildren();
if (arguments.size() != 1)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Aggregate function {} must have exactly one argument", function_node->getFunctionName());
mapping[QueryTreeNodeWithHash(arguments[0])].push_back(&node);
}
/// NOTE(review): this is a diff rendered without +/- markers. The `QueryTreeHashForMap` line and
/// the one-line `operator()` taking `IQueryTreeNode::Hash` appear to be the pre-change hasher;
/// the remaining lines form the `QueryTreeNodeWithHash` key type that replaces it.
struct QueryTreeHashForMap
/// Hash-map key that pairs a query tree node with its tree hash. Two keys are equal only when
/// their hashes match and the nodes themselves compare equal via `isEqual`.
struct QueryTreeNodeWithHash
{
size_t operator()(const IQueryTreeNode::Hash & hash) const { return hash.first ^ hash.second; }
/// Stored as a reference: the key does not own the node pointer.
/// NOTE(review): the key is only usable while the referenced QueryTreeNodePtr object is alive —
/// confirm the lifetime against the code that fills and later reads the map.
const QueryTreeNodePtr & node;
/// Result of `node->getTreeHash()`, computed once in the constructor.
IQueryTreeNode::Hash hash;
explicit QueryTreeNodeWithHash(const QueryTreeNodePtr & node_)
: node(node_)
, hash(node->getTreeHash())
{}
/// The hash comparison is evaluated first; `isEqual` is only called when the hashes match.
bool operator==(const QueryTreeNodeWithHash & rhs) const
{
return hash == rhs.hash && node->isEqual(*rhs.node);
}
/// Hasher for QueryTreeNodeWithHash keys: XOR of the two halves of the stored tree hash.
struct Hash
{
size_t operator() (const QueryTreeNodeWithHash & key) const { return key.hash.first ^ key.hash.second; }
};
};
/// argument -> list of sum/count/avg functions with this argument
std::unordered_map<IQueryTreeNode::Hash, std::vector<QueryTreeNodePtr *>, QueryTreeHashForMap> mapping;
std::unordered_map<QueryTreeNodeWithHash, std::vector<QueryTreeNodePtr *>, QueryTreeNodeWithHash::Hash> mapping;
};
/// Creates a FunctionNode called `name`, looks the function up in FunctionFactory using `context`,
/// resolves the node with the caller-supplied `result_type` and attaches the given arguments.
/// The result type is not derived here; it is passed to `resolveAsFunction` exactly as given.
/// NOTE(review): this is a diff rendered without +/- markers. The variadic `template` signature and
/// the `std::forward<Args>(args)...` assignment appear to be the pre-change version; the
/// `QueryTreeNodes arguments` signature and the `ListNode` assignment appear to be the new one.
template <typename... Args>
QueryTreeNodePtr createResolvedFunction(ContextPtr context, const String & name, DataTypePtr result_type, Args &&... args)
QueryTreeNodePtr createResolvedFunction(ContextPtr context, const String & name, DataTypePtr result_type, QueryTreeNodes arguments)
{
auto function_node = std::make_shared<FunctionNode>(name);
auto function = FunctionFactory::instance().get(name, context);
function_node->resolveAsFunction(function, result_type);
function_node->getArguments().getNodes() = { std::forward<Args>(args)... };
/// Replaces the whole arguments node with a fresh ListNode that takes over `arguments`.
function_node->getArgumentsNode() = std::make_shared<ListNode>(std::move(arguments));
return function_node;
}
/// Builds a `tupleElement(argument, idx)` function node whose result type is `result_type`.
/// The index is wrapped in a ConstantNode; the calls visible in this file pass 1 and 2.
/// NOTE(review): in this marker-less diff this helper appears to be the pre-change version that
/// `createTupleElementFunction` supersedes.
QueryTreeNodePtr createTupleElement(ContextPtr context, DataTypePtr result_type, QueryTreeNodePtr argument, UInt64 idx)
{
return createResolvedFunction(context, "tupleElement", result_type, argument, std::make_shared<ConstantNode>(idx));
}
/// Builds a `sumCount` aggregate function node over a single `argument`.
/// The aggregate function is obtained from AggregateFunctionFactory for the argument's result type
/// with an empty parameter list, and the node's result type is the function's own return type.
/// NOTE(review): this is a diff rendered without +/- markers. The `createSumCount` signature and
/// the two `DataTypePtr` locals appear to be leftovers of the pre-change function.
/// NOTE(review): "Cound" looks like a typo of "Count"; renaming would also require changing the caller.
QueryTreeNodePtr createSumCount(const FunctionNode & function_node, ContextPtr context)
FunctionNodePtr createSumCoundNode(const QueryTreeNodePtr & argument)
{
auto sum_count_node = std::make_shared<FunctionNode>("sumCount");
DataTypePtr sum_return_type;
DataTypePtr count_return_type;
AggregateFunctionProperties properties;
auto aggregate_function = AggregateFunctionFactory::instance().get("sumCount", {argument->getResultType()}, {}, properties);
sum_count_node->resolveAsAggregateFunction(aggregate_function, aggregate_function->getReturnType());
/// The argument pointer is copied into a new one-element list, not moved out of the caller.
sum_count_node->getArgumentsNode() = std::make_shared<ListNode>(QueryTreeNodes{argument});
return sum_count_node;
}
/// Builds a `tupleElement(argument, index)` function node whose result type is `result_type`.
/// `index` is wrapped in a ConstantNode; the calls visible in this file pass 1 for the first
/// tuple element and 2 for the second.
QueryTreeNodePtr createTupleElementFunction(ContextPtr context, DataTypePtr result_type, QueryTreeNodePtr argument, UInt64 index)
{
return createResolvedFunction(context, "tupleElement", result_type, {argument, std::make_shared<ConstantNode>(index)});
}
/// Replaces `node` (a `sum`, `count` or `avg` function node) in place with an expression that
/// reads from the shared `sum_count_node`:
///   sum   -> tupleElement(sum_count_node, 1)
///   count -> tupleElement(sum_count_node, 2)
///   avg   -> divide(tupleElement(sum_count_node, 1), toFloat64(tupleElement(sum_count_node, 2)))
/// The replacement keeps the result type that `node` had before the rewrite.
/// Throws LOGICAL_ERROR when the result type of `sum_count_node` is not a tuple of two elements,
/// or when `node` has any other function name.
/// NOTE(review): this is a diff rendered without +/- markers. Lines that use `function_node`,
/// `sum_return_type`, `count_return_type`, `createTupleElement` or `return` a value appear to
/// belong to the body of the pre-change `createSumCount` and are interleaved with the new body.
void replaceWithSumCount(QueryTreeNodePtr & node, const FunctionNodePtr & sum_count_node, ContextPtr context)
{
auto sum_count_result_type = std::dynamic_pointer_cast<const DataTypeTuple>(sum_count_node->getResultType());
if (!sum_count_result_type || sum_count_result_type->getElements().size() != 2)
{
AggregateFunctionProperties properties;
auto function_aggregate_function = function_node.getAggregateFunction();
auto aggregate_function = AggregateFunctionFactory::instance().get("sumCount",
function_aggregate_function->getArgumentTypes(),
function_aggregate_function->getParameters(),
properties);
sum_count_node->resolveAsAggregateFunction(aggregate_function, aggregate_function->getReturnType());
if (auto ret_type = std::dynamic_pointer_cast<const DataTypeTuple>(aggregate_function->getReturnType()))
{
sum_return_type = ret_type->getElement(0);
count_return_type = ret_type->getElement(1);
}
else
{
throw Exception(ErrorCodes::LOGICAL_ERROR, "Unexpected return type '{}' of sumCount aggregate function",
aggregate_function->getReturnType()->getName());
}
sum_count_node->getArgumentsNode() = function_node.getArgumentsNode();
/// NOTE(review): the older throw above formats the type through `->getName()`, while this one
/// passes the value returned by `getResultType()` directly — confirm it formats as intended.
throw Exception(ErrorCodes::LOGICAL_ERROR,
"Unexpected return type '{}' of function {}, should be tuple of two elements",
sum_count_node->getResultType(), sum_count_node->getFunctionName());
}
/// TODO: function_node.getResultType() or sum_return_type/count_return_type.
/// Should it be the same?
/// Copied into a local String before `node` is reassigned in the branches below.
String function_name = node->as<const FunctionNode &>().getFunctionName();
if (function_node.getFunctionName() == "sum")
return createTupleElement(context, function_node.getResultType(), sum_count_node, 1);
if (function_node.getFunctionName() == "count")
return createTupleElement(context, function_node.getResultType(), sum_count_node, 2);
if (function_node.getFunctionName() == "avg")
if (function_name == "sum")
{
auto sum_result = createTupleElement(context, sum_return_type, sum_count_node, 1);
auto count_result = createTupleElement(context, count_return_type, sum_count_node, 2);
/// NOTE(review): `==` here is applied to the values returned by getResultType()/getElement();
/// if those are shared pointers this checks pointer identity, not type equality — confirm.
assert(node->getResultType() == sum_count_result_type->getElement(0));
node = createTupleElementFunction(context, node->getResultType(), sum_count_node, 1);
}
else if (function_name == "count")
{
assert(node->getResultType() == sum_count_result_type->getElement(1));
node = createTupleElementFunction(context, node->getResultType(), sum_count_node, 2);
}
else if (function_name == "avg")
{
auto sum_result = createTupleElementFunction(context, sum_count_result_type->getElement(0), sum_count_node, 1);
auto count_result = createTupleElementFunction(context, sum_count_result_type->getElement(1), sum_count_node, 2);
/// To avoid integer division by zero
auto count_float_result = createResolvedFunction(context, "toFloat64", std::make_shared<DataTypeFloat64>(), count_result);
return createResolvedFunction(context, "divide", function_node.getResultType(), sum_result, count_float_result);
auto count_float_result = createResolvedFunction(context, "toFloat64", std::make_shared<DataTypeFloat64>(), {count_result});
node = createResolvedFunction(context, "divide", node->getResultType(), {sum_result, count_float_result});
}
else
{
throw Exception(ErrorCodes::LOGICAL_ERROR, "Unsupported function '{}'", function_name);
}
throw Exception(ErrorCodes::LOGICAL_ERROR, "Unsupported function '{}'", function_node.getFunctionName());
}
}
/// Entry point of the pass. Runs the visitor over the query tree to collect matching aggregate
/// function nodes grouped by their argument, then for every group with at least two functions
/// builds one `sumCount` node and rewrites each function of the group to read from it.
/// NOTE(review): this is a diff rendered without +/- markers. The `FuseFunctionsMatcher`
/// declaration, the `[_, nodes]` loop header and the `createSumCount` assignment appear to be
/// the pre-change lines; each is followed by its replacement.
void FuseFunctionsPass::run(QueryTreeNodePtr query_tree_node, ContextPtr context)
{
FuseFunctionsMatcher visitor;
FuseFunctionsVisitor visitor;
visitor.visit(query_tree_node);
for (auto & [_, nodes] : visitor.mapping)
for (auto & [argument, nodes] : visitor.mapping)
{
/// Groups with fewer than two functions are left untouched.
if (nodes.size() < 2)
continue;
/// One sumCount node per group; the same pointer is handed to every replacement below.
auto sum_count_node = createSumCoundNode(argument.node);
for (auto * node : nodes)
{
*node = createSumCount((*node)->as<const FunctionNode &>(), context);
/// `nodes` holds pointers to QueryTreeNodePtr slots, so the rewrite goes through `*node`.
replaceWithSumCount(*node, sum_count_node, context);
}
}
}