mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-15 12:14:18 +00:00
better code in FuseFunctionsPass.cpp
This commit is contained in:
parent
97e7c505ad
commit
7daf5200f0
@ -18,7 +18,7 @@ namespace DB
|
||||
namespace
|
||||
{
|
||||
|
||||
class FuseFunctionsMatcher : public InDepthQueryTreeVisitor<FuseFunctionsMatcher>
|
||||
class FuseFunctionsVisitor : public InDepthQueryTreeVisitor<FuseFunctionsVisitor>
|
||||
{
|
||||
public:
|
||||
|
||||
@ -33,101 +33,117 @@ public:
|
||||
if (!function_node || !function_node->isAggregateFunction() || !matchFunctionName(function_node->getFunctionName()))
|
||||
return;
|
||||
|
||||
auto argument_hash = function_node->getArgumentsNode()->getTreeHash();
|
||||
mapping[argument_hash].push_back(&node);
|
||||
const auto & arguments = function_node->getArgumentsNode()->getChildren();
|
||||
if (arguments.size() != 1)
|
||||
throw Exception(ErrorCodes::LOGICAL_ERROR, "Aggregate function {} must have exactly one argument", function_node->getFunctionName());
|
||||
|
||||
mapping[QueryTreeNodeWithHash(arguments[0])].push_back(&node);
|
||||
}
|
||||
|
||||
struct QueryTreeHashForMap
|
||||
struct QueryTreeNodeWithHash
|
||||
{
|
||||
size_t operator()(const IQueryTreeNode::Hash & hash) const { return hash.first ^ hash.second; }
|
||||
const QueryTreeNodePtr & node;
|
||||
IQueryTreeNode::Hash hash;
|
||||
|
||||
explicit QueryTreeNodeWithHash(const QueryTreeNodePtr & node_)
|
||||
: node(node_)
|
||||
, hash(node->getTreeHash())
|
||||
{}
|
||||
|
||||
bool operator==(const QueryTreeNodeWithHash & rhs) const
|
||||
{
|
||||
return hash == rhs.hash && node->isEqual(*rhs.node);
|
||||
}
|
||||
|
||||
struct Hash
|
||||
{
|
||||
size_t operator() (const QueryTreeNodeWithHash & key) const { return key.hash.first ^ key.hash.second; }
|
||||
};
|
||||
};
|
||||
|
||||
/// argument -> list of sum/count/avg functions with this argument
|
||||
std::unordered_map<IQueryTreeNode::Hash, std::vector<QueryTreeNodePtr *>, QueryTreeHashForMap> mapping;
|
||||
std::unordered_map<QueryTreeNodeWithHash, std::vector<QueryTreeNodePtr *>, QueryTreeNodeWithHash::Hash> mapping;
|
||||
};
|
||||
|
||||
template <typename... Args>
|
||||
QueryTreeNodePtr createResolvedFunction(ContextPtr context, const String & name, DataTypePtr result_type, Args &&... args)
|
||||
QueryTreeNodePtr createResolvedFunction(ContextPtr context, const String & name, DataTypePtr result_type, QueryTreeNodes arguments)
|
||||
{
|
||||
auto function_node = std::make_shared<FunctionNode>(name);
|
||||
auto function = FunctionFactory::instance().get(name, context);
|
||||
function_node->resolveAsFunction(function, result_type);
|
||||
function_node->getArguments().getNodes() = { std::forward<Args>(args)... };
|
||||
function_node->getArgumentsNode() = std::make_shared<ListNode>(std::move(arguments));
|
||||
return function_node;
|
||||
}
|
||||
|
||||
QueryTreeNodePtr createTupleElement(ContextPtr context, DataTypePtr result_type, QueryTreeNodePtr argument, UInt64 idx)
|
||||
{
|
||||
return createResolvedFunction(context, "tupleElement", result_type, argument, std::make_shared<ConstantNode>(idx));
|
||||
}
|
||||
|
||||
QueryTreeNodePtr createSumCount(const FunctionNode & function_node, ContextPtr context)
|
||||
FunctionNodePtr createSumCoundNode(const QueryTreeNodePtr & argument)
|
||||
{
|
||||
auto sum_count_node = std::make_shared<FunctionNode>("sumCount");
|
||||
|
||||
DataTypePtr sum_return_type;
|
||||
DataTypePtr count_return_type;
|
||||
AggregateFunctionProperties properties;
|
||||
auto aggregate_function = AggregateFunctionFactory::instance().get("sumCount", {argument->getResultType()}, {}, properties);
|
||||
|
||||
sum_count_node->resolveAsAggregateFunction(aggregate_function, aggregate_function->getReturnType());
|
||||
|
||||
sum_count_node->getArgumentsNode() = std::make_shared<ListNode>(QueryTreeNodes{argument});
|
||||
return sum_count_node;
|
||||
}
|
||||
|
||||
QueryTreeNodePtr createTupleElementFunction(ContextPtr context, DataTypePtr result_type, QueryTreeNodePtr argument, UInt64 index)
|
||||
{
|
||||
return createResolvedFunction(context, "tupleElement", result_type, {argument, std::make_shared<ConstantNode>(index)});
|
||||
}
|
||||
|
||||
void replaceWithSumCount(QueryTreeNodePtr & node, const FunctionNodePtr & sum_count_node, ContextPtr context)
|
||||
{
|
||||
auto sum_count_result_type = std::dynamic_pointer_cast<const DataTypeTuple>(sum_count_node->getResultType());
|
||||
if (!sum_count_result_type || sum_count_result_type->getElements().size() != 2)
|
||||
{
|
||||
AggregateFunctionProperties properties;
|
||||
auto function_aggregate_function = function_node.getAggregateFunction();
|
||||
|
||||
auto aggregate_function = AggregateFunctionFactory::instance().get("sumCount",
|
||||
function_aggregate_function->getArgumentTypes(),
|
||||
function_aggregate_function->getParameters(),
|
||||
properties);
|
||||
|
||||
sum_count_node->resolveAsAggregateFunction(aggregate_function, aggregate_function->getReturnType());
|
||||
|
||||
if (auto ret_type = std::dynamic_pointer_cast<const DataTypeTuple>(aggregate_function->getReturnType()))
|
||||
{
|
||||
sum_return_type = ret_type->getElement(0);
|
||||
count_return_type = ret_type->getElement(1);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw Exception(ErrorCodes::LOGICAL_ERROR, "Unexpected return type '{}' of sumCount aggregate function",
|
||||
aggregate_function->getReturnType()->getName());
|
||||
}
|
||||
|
||||
sum_count_node->getArgumentsNode() = function_node.getArgumentsNode();
|
||||
throw Exception(ErrorCodes::LOGICAL_ERROR,
|
||||
"Unexpected return type '{}' of function {}, should be tuple of two elements",
|
||||
sum_count_node->getResultType(), sum_count_node->getFunctionName());
|
||||
}
|
||||
|
||||
/// TODO: function_node.getResultType() or sum_return_type/count_return_type.
|
||||
/// Should it be the same?
|
||||
String function_name = node->as<const FunctionNode &>().getFunctionName();
|
||||
|
||||
if (function_node.getFunctionName() == "sum")
|
||||
return createTupleElement(context, function_node.getResultType(), sum_count_node, 1);
|
||||
|
||||
if (function_node.getFunctionName() == "count")
|
||||
return createTupleElement(context, function_node.getResultType(), sum_count_node, 2);
|
||||
|
||||
if (function_node.getFunctionName() == "avg")
|
||||
if (function_name == "sum")
|
||||
{
|
||||
auto sum_result = createTupleElement(context, sum_return_type, sum_count_node, 1);
|
||||
auto count_result = createTupleElement(context, count_return_type, sum_count_node, 2);
|
||||
assert(node->getResultType() == sum_count_result_type->getElement(0));
|
||||
node = createTupleElementFunction(context, node->getResultType(), sum_count_node, 1);
|
||||
}
|
||||
else if (function_name == "count")
|
||||
{
|
||||
assert(node->getResultType() == sum_count_result_type->getElement(1));
|
||||
node = createTupleElementFunction(context, node->getResultType(), sum_count_node, 2);
|
||||
}
|
||||
else if (function_name == "avg")
|
||||
{
|
||||
auto sum_result = createTupleElementFunction(context, sum_count_result_type->getElement(0), sum_count_node, 1);
|
||||
auto count_result = createTupleElementFunction(context, sum_count_result_type->getElement(1), sum_count_node, 2);
|
||||
/// To avoid integer division by zero
|
||||
auto count_float_result = createResolvedFunction(context, "toFloat64", std::make_shared<DataTypeFloat64>(), count_result);
|
||||
return createResolvedFunction(context, "divide", function_node.getResultType(), sum_result, count_float_result);
|
||||
auto count_float_result = createResolvedFunction(context, "toFloat64", std::make_shared<DataTypeFloat64>(), {count_result});
|
||||
node = createResolvedFunction(context, "divide", node->getResultType(), {sum_result, count_float_result});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw Exception(ErrorCodes::LOGICAL_ERROR, "Unsupported function '{}'", function_name);
|
||||
}
|
||||
|
||||
throw Exception(ErrorCodes::LOGICAL_ERROR, "Unsupported function '{}'", function_node.getFunctionName());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void FuseFunctionsPass::run(QueryTreeNodePtr query_tree_node, ContextPtr context)
|
||||
{
|
||||
FuseFunctionsMatcher visitor;
|
||||
FuseFunctionsVisitor visitor;
|
||||
visitor.visit(query_tree_node);
|
||||
|
||||
for (auto & [_, nodes] : visitor.mapping)
|
||||
for (auto & [argument, nodes] : visitor.mapping)
|
||||
{
|
||||
if (nodes.size() < 2)
|
||||
continue;
|
||||
|
||||
auto sum_count_node = createSumCoundNode(argument.node);
|
||||
for (auto * node : nodes)
|
||||
{
|
||||
*node = createSumCount((*node)->as<const FunctionNode &>(), context);
|
||||
replaceWithSumCount(*node, sum_count_node, context);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user