support alias for new analyzer

This commit is contained in:
JackyWoo 2023-07-19 18:46:20 +08:00
parent 71c12bbdc4
commit 08409059cc
2 changed files with 90 additions and 34 deletions

View File

@ -21,36 +21,82 @@ bool matchFnUniq(String func_name)
|| name == "uniqCombined64";
}
bool nodeEquals(const QueryTreeNodePtr & lhs, const QueryTreeNodePtr & rhs)
/// Extract the corresponding projection columns for group by node list.
/// For example:
/// SELECT a as aa, any(b) FROM table group by a; -> aa(ColumnNode)
NamesAndTypes extractProjectionColumnsForGroupBy(const QueryNode * query_node)
{
auto * lhs_node = lhs->as<ColumnNode>();
auto * rhs_node = rhs->as<ColumnNode>();
if (!query_node->hasGroupBy())
return {};
if (lhs_node && rhs_node && lhs_node->getColumn() == rhs_node->getColumn())
return true;
return false;
NamesAndTypes result;
for (const auto & group_by_ele : query_node->getGroupByNode()->getChildren())
{
const auto & projection_columns = query_node->getProjectionColumns();
const auto & projection_nodes = query_node->getProjection().getNodes();
assert(projection_columns.size() == projection_nodes.size());
for (size_t i = 0; i < projection_columns.size(); i++)
{
if (projection_nodes[i]->isEqual(*group_by_ele))
result.push_back(projection_columns[i]);
}
}
return result;
}
bool nodeListEquals(const QueryTreeNodes & lhs, const QueryTreeNodes & rhs)
/// Whether query_columns equals subquery_columns.
/// query_columns: query columns from query
/// subquery_columns: projection columns from subquery
bool nodeListEquals(const QueryTreeNodes & query_columns, const NamesAndTypes & subquery_columns)
{
if (lhs.size() != rhs.size())
if (query_columns.size() != subquery_columns.size())
return false;
for (size_t i = 0; i < lhs.size(); i++)
for (const auto & query_column : query_columns)
{
if (!nodeEquals(lhs[i], rhs[i]))
auto find = std::find_if(
subquery_columns.begin(),
subquery_columns.end(),
[&](const auto & subquery_column) -> bool
{
if (auto * column_node = query_column->as<ColumnNode>())
{
return subquery_column == column_node->getColumn();
}
return false;
});
if (find == subquery_columns.end())
return false;
}
return true;
}
bool nodeListContainsAll(const QueryTreeNodes & lhs, const QueryTreeNodes & rhs)
/// Whether subquery_columns contains all columns in subquery_columns.
/// query_columns: query columns from query
/// subquery_columns: projection columns from subquery
bool nodeListContainsAll(const QueryTreeNodes & query_columns, const NamesAndTypes & subquery_columns)
{
if (lhs.size() < rhs.size())
if (query_columns.size() > subquery_columns.size())
return false;
for (const auto & re : rhs)
for (const auto & query_column : query_columns)
{
auto predicate = [&](const QueryTreeNodePtr & le) { return nodeEquals(le, re); };
if (std::find_if(lhs.begin(), lhs.end(), predicate) == lhs.end())
auto find = std::find_if(
subquery_columns.begin(),
subquery_columns.end(),
[&](const auto & subquery_column) -> bool
{
if (auto * column_node = query_column->as<ColumnNode>())
{
return subquery_column == column_node->getColumn();
}
return false;
});
if (find == subquery_columns.end())
return false;
}
return true;
@ -58,17 +104,14 @@ bool nodeListContainsAll(const QueryTreeNodes & lhs, const QueryTreeNodes & rhs)
}
class UniqToCountVisitor : public InDepthQueryTreeVisitorWithContext<UniqToCountVisitor>
class UniqToCountVisitor : public InDepthQueryTreeVisitor<UniqToCountVisitor>
{
public:
using Base = InDepthQueryTreeVisitorWithContext<UniqToCountVisitor>;
using Base = InDepthQueryTreeVisitor<UniqToCountVisitor>;
using Base::Base;
void visitImpl(QueryTreeNodePtr & node)
{
if (!getSettings().optimize_uniq_to_count)
return;
auto * query_node = node->as<QueryNode>();
if (!query_node)
return;
@ -100,9 +143,11 @@ public:
{
if (!subquery_node->isDistinct())
return false;
/// uniq expression list == subquery group by expression list
if (!nodeListEquals(uniq_arguments_nodes, subquery_node->getProjection().getNodes()))
/// uniq expression list == subquery projection columns
if (!nodeListEquals(uniq_arguments_nodes, subquery_node->getProjectionColumns()))
return false;
return true;
};
@ -111,12 +156,17 @@ public:
{
if (!subquery_node->hasGroupBy())
return false;
/// uniq argument node list == subquery group by node list
if (!nodeListEquals(uniq_arguments_nodes, subquery_node->getGroupByNode()->getChildren()))
auto group_by_columns = extractProjectionColumnsForGroupBy(subquery_node);
if (!nodeListEquals(uniq_arguments_nodes, group_by_columns))
return false;
/// subquery select node list must contain all columns in uniq argument node list
if (!nodeListContainsAll(subquery_node->getProjection().getNodes(), uniq_arguments_nodes))
/// subquery projection columns must contain all columns in uniq argument node list
if (!nodeListContainsAll(uniq_arguments_nodes, subquery_node->getProjectionColumns()))
return false;
return true;
};
@ -125,8 +175,11 @@ public:
{
AggregateFunctionProperties properties;
auto aggregate_function = AggregateFunctionFactory::instance().get("count", {}, {}, properties);
function_node->resolveAsAggregateFunction(std::move(aggregate_function));
function_node->getArguments().getNodes().clear();
/// Update projection columns
query_node->resolveProjectionColumns({{"count()", function_node->getResultType()}});
}
}
@ -135,7 +188,10 @@ public:
void UniqToCountPass::run(QueryTreeNodePtr query_tree_node, ContextPtr context)
{
UniqToCountVisitor visitor(std::move(context));
if (!context->getSettings().optimize_uniq_to_count)
return;
UniqToCountVisitor visitor;
visitor.visit(query_tree_node);
}

View File

@ -83,13 +83,13 @@ def test_rewrite_distinct(started_cluster):
)
# test select expression alias
check_by_old_analyzer(
"SELECT uniq(a) FROM (SELECT DISTINCT test_rewrite_uniq_to_count.a as alias_of_a FROM test_rewrite_uniq_to_count) t",
check(
"SELECT uniq(alias_of_a) FROM (SELECT DISTINCT test_rewrite_uniq_to_count.a as alias_of_a FROM test_rewrite_uniq_to_count) t",
3,
)
# test select expression alias
check_by_old_analyzer(
check(
"SELECT uniq(alias_of_a) FROM (SELECT DISTINCT a as alias_of_a FROM test_rewrite_uniq_to_count) t",
3,
)
@ -109,19 +109,19 @@ def test_rewrite_group_by(started_cluster):
)
# test select expression alias
check_by_old_analyzer(
check(
"SELECT uniq(t.alias_of_a) FROM (SELECT a as alias_of_a, sum(b) FROM test_rewrite_uniq_to_count GROUP BY a) t",
3,
)
# test select expression alias
check_by_old_analyzer(
"SELECT uniq(t.a) FROM (SELECT a as alias_of_a, sum(b) FROM test_rewrite_uniq_to_count GROUP BY alias_of_a) t",
check(
"SELECT uniq(t.alias_of_a) FROM (SELECT a as alias_of_a, sum(b) FROM test_rewrite_uniq_to_count GROUP BY alias_of_a) t",
3,
)
# test select expression alias
check_by_old_analyzer(
"SELECT uniq(t.alias_of_a) FROM (SELECT a as alias_of_a, sum(b) FROM test_rewrite_uniq_to_count GROUP BY alias_of_a) t",
check(
"SELECT uniq(t.alias_of_a) FROM (SELECT a as alias_of_a, sum(b) FROM test_rewrite_uniq_to_count GROUP BY a) t",
3,
)