Merge pull request #71855 from ClickHouse/vdimir/grouping_sets_aliases

Fix GROUPING function error when input is ALIAS on distributed table
This commit is contained in:
Vladimir Cherkasov 2024-11-19 13:23:05 +00:00 committed by GitHub
commit 6865d1e383
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 110 additions and 59 deletions

View File

@ -12,9 +12,13 @@
#include <Analyzer/HashUtils.h>
#include <Analyzer/FunctionNode.h>
#include <Analyzer/ColumnNode.h>
#include <Analyzer/ValidationUtils.h>
#include <ranges>
namespace DB
{
namespace Setting
{
extern const SettingsBool force_grouping_standard_compatibility;
@ -27,6 +31,26 @@ namespace ErrorCodes
extern const int LOGICAL_ERROR;
}
/** Wraps a query tree node so that it can serve as a hash map key for GROUP BY keys.
  * The tree hash (aliases ignored, types compared) is computed once at construction.
  * Two keys are equal when their hashes match and compareGroupByKeys accepts the pair of nodes.
  */
struct GroupByKeyComparator
{
    /// Not `explicit` on purpose (hence NOLINT).
    /// `hash` is computed from `node`, so `node` must stay declared before `hash`.
    GroupByKeyComparator(QueryTreeNodePtr node_) /// NOLINT
        : node(std::move(node_))
        , hash(node->getTreeHash({.compare_aliases = false, .compare_types = true}))
    {
    }

    bool operator==(const GroupByKeyComparator & other) const
    {
        /// Hash comparison first; the node comparison runs only when the hashes match.
        if (!(hash == other.hash))
            return false;
        return compareGroupByKeys(node, other.node);
    }

    bool operator!=(const GroupByKeyComparator & other) const
    {
        return !operator==(other);
    }

    /// Hash functor for unordered containers: returns the `low64` part of the precomputed tree hash.
    struct Hasher
    {
        size_t operator()(const GroupByKeyComparator & key) const
        {
            return key.hash.low64;
        }
    };

    QueryTreeNodePtr node = nullptr;
    CityHash_v1_0_2::uint128 hash;
};
/// Hash map keyed by GROUP BY key nodes wrapped in GroupByKeyComparator.
/// Hashing goes through GroupByKeyComparator::Hasher, key equality through GroupByKeyComparator::operator==.
template <typename Value>
using AggredationKeyNodeMap = std::unordered_map<GroupByKeyComparator, Value, GroupByKeyComparator::Hasher>;
namespace
{
@ -42,7 +66,7 @@ class GroupingFunctionResolveVisitor : public InDepthQueryTreeVisitorWithContext
{
public:
GroupingFunctionResolveVisitor(GroupByKind group_by_kind_,
QueryTreeNodePtrWithHashMap<size_t> aggregation_key_to_index_,
AggredationKeyNodeMap<size_t> aggregation_key_to_index_,
ColumnNumbersList grouping_sets_keys_indices_,
ContextPtr context_)
: InDepthQueryTreeVisitorWithContext(std::move(context_))
@ -67,9 +91,12 @@ public:
{
auto it = aggregation_key_to_index.find(argument);
if (it == aggregation_key_to_index.end())
{
throw Exception(ErrorCodes::BAD_ARGUMENTS,
"Argument {} of GROUPING function is not a part of GROUP BY clause",
argument->formatASTForErrorMessage());
"Argument {} of GROUPING function is not a part of GROUP BY clause [{}]",
argument->formatASTForErrorMessage(),
fmt::join(aggregation_key_to_index | std::views::transform([](const auto & e) { return e.first.node->formatASTForErrorMessage(); }), ", "));
}
arguments_indexes.push_back(it->second);
}
@ -133,7 +160,7 @@ public:
private:
GroupByKind group_by_kind;
QueryTreeNodePtrWithHashMap<size_t> aggregation_key_to_index;
AggredationKeyNodeMap<size_t> aggregation_key_to_index;
ColumnNumbersList grouping_sets_keys_indexes;
};
@ -142,7 +169,7 @@ void resolveGroupingFunctions(QueryTreeNodePtr & query_node, ContextPtr context)
auto & query_node_typed = query_node->as<QueryNode &>();
size_t aggregation_node_index = 0;
QueryTreeNodePtrWithHashMap<size_t> aggregation_key_to_index;
AggredationKeyNodeMap<size_t> aggregation_key_to_index;
std::vector<QueryTreeNodes> grouping_sets_used_aggregation_keys_list;

View File

@ -80,6 +80,65 @@ void validateFilters(const QueryTreeNodePtr & query_node)
validateFilter(query_node_typed.getQualify(), "QUALIFY", query_node);
}
bool areColumnSourcesEqual(const QueryTreeNodePtr & lhs, const QueryTreeNodePtr & rhs)
{
using NodePair = std::pair<const IQueryTreeNode *, const IQueryTreeNode *>;
std::vector<NodePair> nodes_to_process;
nodes_to_process.emplace_back(lhs.get(), rhs.get());
while (!nodes_to_process.empty())
{
const auto [lhs_node, rhs_node] = nodes_to_process.back();
nodes_to_process.pop_back();
if (lhs_node->getNodeType() != rhs_node->getNodeType())
return false;
if (lhs_node->getNodeType() == QueryTreeNodeType::COLUMN)
{
const auto * lhs_column_node = lhs_node->as<ColumnNode>();
const auto * rhs_column_node = rhs_node->as<ColumnNode>();
if (!lhs_column_node->getColumnSource()->isEqual(*rhs_column_node->getColumnSource()))
return false;
}
const auto & lhs_children = lhs_node->getChildren();
const auto & rhs_children = rhs_node->getChildren();
if (lhs_children.size() != rhs_children.size())
return false;
for (size_t i = 0; i < lhs_children.size(); ++i)
{
const auto & lhs_child = lhs_children[i];
const auto & rhs_child = rhs_children[i];
if (!lhs_child && !rhs_child)
continue;
if (lhs_child && !rhs_child)
return false;
if (!lhs_child && rhs_child)
return false;
nodes_to_process.emplace_back(lhs_child.get(), rhs_child.get());
}
}
return true;
}
/// Returns true when `node` matches `group_by_key_node`: the nodes are equal with aliases ignored
/// and all corresponding columns come from equal column sources.
bool compareGroupByKeys(const QueryTreeNodePtr & node, const QueryTreeNodePtr & group_by_key_node)
{
    if (!node->isEqual(*group_by_key_node, {.compare_aliases = false}))
        return false;

    /** Column sources should be compared with aliases for correct GROUP BY keys validation,
      * otherwise t2.x and t1.x will be considered as the same column:
      * SELECT t2.x FROM t1 JOIN t1 as t2 ON t1.x = t2.x GROUP BY t1.x;
      */
    return areColumnSourcesEqual(node, group_by_key_node);
}
namespace
{
@ -154,51 +213,6 @@ public:
private:
/// Compares two query trees node by node and checks that corresponding COLUMN nodes have equal column sources.
/// The traversal is iterative: pairs of nodes still to be compared are kept on an explicit stack.
/// Returns false on the first mismatch, true if no mismatch is found.
static bool areColumnSourcesEqual(const QueryTreeNodePtr & lhs, const QueryTreeNodePtr & rhs)
{
using NodePair = std::pair<const IQueryTreeNode *, const IQueryTreeNode *>;
std::vector<NodePair> nodes_to_process;
nodes_to_process.emplace_back(lhs.get(), rhs.get());
while (!nodes_to_process.empty())
{
const auto [lhs_node, rhs_node] = nodes_to_process.back();
nodes_to_process.pop_back();
/// Nodes of different kinds can never match.
if (lhs_node->getNodeType() != rhs_node->getNodeType())
return false;
/// For column nodes the sources themselves must be equal.
if (lhs_node->getNodeType() == QueryTreeNodeType::COLUMN)
{
const auto * lhs_column_node = lhs_node->as<ColumnNode>();
const auto * rhs_column_node = rhs_node->as<ColumnNode>();
if (!lhs_column_node->getColumnSource()->isEqual(*rhs_column_node->getColumnSource()))
return false;
}
const auto & lhs_children = lhs_node->getChildren();
const auto & rhs_children = rhs_node->getChildren();
if (lhs_children.size() != rhs_children.size())
return false;
/// Children are paired by position; a null child only matches another null child.
for (size_t i = 0; i < lhs_children.size(); ++i)
{
const auto & lhs_child = lhs_children[i];
const auto & rhs_child = rhs_children[i];
if (!lhs_child && !rhs_child)
continue;
if (lhs_child && !rhs_child)
return false;
if (!lhs_child && rhs_child)
return false;
nodes_to_process.emplace_back(lhs_child.get(), rhs_child.get());
}
}
return true;
}
bool nodeIsAggregateFunctionOrInGroupByKeys(const QueryTreeNodePtr & node) const
{
if (auto * function_node = node->as<FunctionNode>())
@ -207,16 +221,9 @@ private:
for (const auto & group_by_key_node : group_by_keys_nodes)
{
if (node->isEqual(*group_by_key_node, {.compare_aliases = false}))
{
/** Column sources should be compared with aliases for correct GROUP BY keys validation,
* otherwise t2.x and t1.x will be considered as the same column:
* SELECT t2.x FROM t1 JOIN t1 as t2 ON t1.x = t2.x GROUP BY t1.x;
*/
if (areColumnSourcesEqual(node, group_by_key_node))
if (compareGroupByKeys(node, group_by_key_node))
return true;
}
}
return false;
}

View File

@ -41,4 +41,10 @@ void validateTreeSize(const QueryTreeNodePtr & node,
size_t max_size,
std::unordered_map<QueryTreeNodePtr, size_t> & node_to_tree_size);
/** Compare node with group by key node.
  * Such comparison does not take into account aliases, but checks types and column sources.
  *
  * Returns true when `node` can be treated as the same key as `group_by_key_node`.
  * Columns with the same name but different sources are not equal, e.g. t1.x and t2.x in
  * SELECT t2.x FROM t1 JOIN t1 as t2 ON t1.x = t2.x GROUP BY t1.x;
  */
bool compareGroupByKeys(const QueryTreeNodePtr & node, const QueryTreeNodePtr & group_by_key_node);
}

View File

@ -0,0 +1,2 @@
LOW 2 0
HIGH 1 0

View File

@ -0,0 +1,9 @@
-- Regression test: GROUPING() applied to an alias of an ALIAS column, queried through remote().
DROP TABLE IF EXISTS users;
-- user_level is an ALIAS column computed from score via multiIf.
CREATE TABLE users (name String, score UInt8, user_level String ALIAS multiIf(score <= 3, 'LOW', score <= 6, 'MEDIUM', 'HIGH') ) ENGINE=MergeTree ORDER BY name;
INSERT INTO users VALUES ('a',1),('b',2),('c', 50);
-- level_alias names the ALIAS column; it is used both as the GROUPING SETS key and as the GROUPING() argument.
SELECT user_level as level_alias, uniq(name) as name_alias, grouping(level_alias) as _totals
FROM remote('127.0.0.{1,2}', currentDatabase(), users)
GROUP BY GROUPING SETS ((level_alias))
ORDER BY name_alias DESC;