This commit is contained in:
kssenii 2021-08-16 15:03:55 +03:00
parent c01be0354b
commit 844e04e341
9 changed files with 83 additions and 17 deletions

View File

@ -31,10 +31,12 @@ namespace ErrorCodes
} }
InterpreterSelectWithUnionQuery::InterpreterSelectWithUnionQuery( InterpreterSelectWithUnionQuery::InterpreterSelectWithUnionQuery(
const ASTPtr & query_ptr_, ContextPtr context_, const SelectQueryOptions & options_, const Names & required_result_column_names) const ASTPtr & query_ptr_, ContextPtr context_,
const SelectQueryOptions & options_, const Names & required_result_column_names)
: IInterpreterUnionOrSelectQuery(query_ptr_, context_, options_) : IInterpreterUnionOrSelectQuery(query_ptr_, context_, options_)
{ {
ASTSelectWithUnionQuery * ast = query_ptr->as<ASTSelectWithUnionQuery>(); ASTSelectWithUnionQuery * ast = query_ptr->as<ASTSelectWithUnionQuery>();
bool require_full_header = ast->hasNonDefaultUnionMode();
const Settings & settings = context->getSettingsRef(); const Settings & settings = context->getSettingsRef();
if (options.subquery_depth == 0 && (settings.limit > 0 || settings.offset > 0)) if (options.subquery_depth == 0 && (settings.limit > 0 || settings.offset > 0))
@ -51,10 +53,7 @@ InterpreterSelectWithUnionQuery::InterpreterSelectWithUnionQuery(
nested_interpreters.reserve(num_children); nested_interpreters.reserve(num_children);
std::vector<Names> required_result_column_names_for_other_selects(num_children); std::vector<Names> required_result_column_names_for_other_selects(num_children);
/// If it is UNION DISTINCT, do not filter by required_result_columns. if (!require_full_header && !required_result_column_names.empty() && num_children > 1)
bool is_union_distinct = ast->union_mode == ASTSelectWithUnionQuery::Mode::DISTINCT;
if (!required_result_column_names.empty() && num_children > 1 && !is_union_distinct)
{ {
/// Result header if there are no filtering by 'required_result_column_names'. /// Result header if there are no filtering by 'required_result_column_names'.
/// We use it to determine positions of 'required_result_column_names' in SELECT clause. /// We use it to determine positions of 'required_result_column_names' in SELECT clause.
@ -133,10 +132,10 @@ InterpreterSelectWithUnionQuery::InterpreterSelectWithUnionQuery(
for (size_t query_num = 0; query_num < num_children; ++query_num) for (size_t query_num = 0; query_num < num_children; ++query_num)
{ {
const Names & current_required_result_column_names const Names & current_required_result_column_names
= !is_union_distinct && query_num == 0 ? required_result_column_names : required_result_column_names_for_other_selects[query_num]; = query_num == 0 ? required_result_column_names : required_result_column_names_for_other_selects[query_num];
nested_interpreters.emplace_back( nested_interpreters.emplace_back(
buildCurrentChildInterpreter(ast->list_of_selects->children.at(query_num), current_required_result_column_names)); buildCurrentChildInterpreter(ast->list_of_selects->children.at(query_num), require_full_header ? Names() : current_required_result_column_names));
} }
/// Determine structure of the result. /// Determine structure of the result.

View File

@ -48,7 +48,6 @@ private:
std::unique_ptr<IInterpreterUnionOrSelectQuery> std::unique_ptr<IInterpreterUnionOrSelectQuery>
buildCurrentChildInterpreter(const ASTPtr & ast_ptr_, const Names & current_required_result_column_names); buildCurrentChildInterpreter(const ASTPtr & ast_ptr_, const Names & current_required_result_column_names);
}; };
} }

View File

@ -11,11 +11,11 @@ namespace ErrorCodes
extern const int EXPECTED_ALL_OR_DISTINCT; extern const int EXPECTED_ALL_OR_DISTINCT;
} }
void NormalizeSelectWithUnionQueryMatcher::getSelectsFromUnionListNode(ASTPtr & ast_select, ASTs & selects) void NormalizeSelectWithUnionQueryMatcher::getSelectsFromUnionListNode(ASTPtr ast_select, ASTs & selects)
{ {
if (auto * inner_union = ast_select->as<ASTSelectWithUnionQuery>()) if (const auto * inner_union = ast_select->as<ASTSelectWithUnionQuery>())
{ {
for (auto & child : inner_union->list_of_selects->children) for (const auto & child : inner_union->list_of_selects->children)
getSelectsFromUnionListNode(child, selects); getSelectsFromUnionListNode(child, selects);
return; return;
@ -34,11 +34,28 @@ void NormalizeSelectWithUnionQueryMatcher::visit(ASTSelectWithUnionQuery & ast,
{ {
auto & union_modes = ast.list_of_modes; auto & union_modes = ast.list_of_modes;
ASTs selects; ASTs selects;
auto & select_list = ast.list_of_selects->children; const auto & select_list = ast.list_of_selects->children;
if (select_list.empty())
throw Exception(ErrorCodes::LOGICAL_ERROR, "Got empty list of selects for ASTSelectWithUnionQuery");
/// Since nodes are traversed from bottom to top, we can also collect union modes from children up to parents.
ASTSelectWithUnionQuery::UnionModesSet current_set_of_modes;
bool distinct_found = false;
int i; int i;
for (i = union_modes.size() - 1; i >= 0; --i) for (i = union_modes.size() - 1; i >= 0; --i)
{ {
current_set_of_modes.insert(union_modes[i]);
if (const auto * union_ast = typeid_cast<const ASTSelectWithUnionQuery *>(select_list[i + 1].get()))
{
const auto & current_select_modes = union_ast->set_of_modes;
current_set_of_modes.insert(current_select_modes.begin(), current_select_modes.end());
}
if (distinct_found)
continue;
/// Rewrite UNION Mode /// Rewrite UNION Mode
if (union_modes[i] == ASTSelectWithUnionQuery::Mode::Unspecified) if (union_modes[i] == ASTSelectWithUnionQuery::Mode::Unspecified)
{ {
@ -80,12 +97,18 @@ void NormalizeSelectWithUnionQueryMatcher::visit(ASTSelectWithUnionQuery & ast,
distinct_list->union_mode = ASTSelectWithUnionQuery::Mode::DISTINCT; distinct_list->union_mode = ASTSelectWithUnionQuery::Mode::DISTINCT;
distinct_list->is_normalized = true; distinct_list->is_normalized = true;
selects.push_back(std::move(distinct_list)); selects.push_back(std::move(distinct_list));
break; distinct_found = true;
} }
} }
if (const auto * union_ast = typeid_cast<const ASTSelectWithUnionQuery *>(select_list[0].get()))
{
const auto & current_select_modes = union_ast->set_of_modes;
current_set_of_modes.insert(current_select_modes.begin(), current_select_modes.end());
}
/// No UNION DISTINCT or only one child in select_list /// No UNION DISTINCT or only one child in select_list
if (i == -1) if (!distinct_found)
{ {
if (auto * inner_union = select_list[0]->as<ASTSelectWithUnionQuery>(); if (auto * inner_union = select_list[0]->as<ASTSelectWithUnionQuery>();
inner_union && inner_union->union_mode == ASTSelectWithUnionQuery::Mode::ALL) inner_union && inner_union->union_mode == ASTSelectWithUnionQuery::Mode::ALL)
@ -103,6 +126,7 @@ void NormalizeSelectWithUnionQueryMatcher::visit(ASTSelectWithUnionQuery & ast,
if (selects.size() == 1 && selects[0]->as<ASTSelectWithUnionQuery>()) if (selects.size() == 1 && selects[0]->as<ASTSelectWithUnionQuery>())
{ {
ast = *(selects[0]->as<ASTSelectWithUnionQuery>()); ast = *(selects[0]->as<ASTSelectWithUnionQuery>());
ast.set_of_modes = std::move(current_set_of_modes);
return; return;
} }
@ -111,6 +135,7 @@ void NormalizeSelectWithUnionQueryMatcher::visit(ASTSelectWithUnionQuery & ast,
ast.is_normalized = true; ast.is_normalized = true;
ast.union_mode = ASTSelectWithUnionQuery::Mode::ALL; ast.union_mode = ASTSelectWithUnionQuery::Mode::ALL;
ast.set_of_modes = std::move(current_set_of_modes);
ast.list_of_selects->children = std::move(selects); ast.list_of_selects->children = std::move(selects);
} }

View File

@ -21,7 +21,7 @@ public:
const UnionMode & union_default_mode; const UnionMode & union_default_mode;
}; };
static void getSelectsFromUnionListNode(ASTPtr & ast_select, ASTs & selects); static void getSelectsFromUnionListNode(ASTPtr ast_select, ASTs & selects);
static void visit(ASTPtr & ast, Data &); static void visit(ASTPtr & ast, Data &);
static void visit(ASTSelectWithUnionQuery &, Data &); static void visit(ASTSelectWithUnionQuery &, Data &);

View File

@ -20,6 +20,7 @@ ASTPtr ASTSelectWithUnionQuery::clone() const
res->union_mode = union_mode; res->union_mode = union_mode;
res->list_of_modes = list_of_modes; res->list_of_modes = list_of_modes;
res->set_of_modes = set_of_modes;
cloneOutputOptions(*res); cloneOutputOptions(*res);
return res; return res;
@ -71,4 +72,10 @@ void ASTSelectWithUnionQuery::formatQueryImpl(const FormatSettings & settings, F
} }
} }
/// Reports whether `set_of_modes` holds at least one of DISTINCT, INTERSECT or EXCEPT.
/// The modes are probed in that order and the first hit decides the answer.
bool ASTSelectWithUnionQuery::hasNonDefaultUnionMode() const
{
    if (set_of_modes.contains(Mode::DISTINCT))
        return true;

    if (set_of_modes.contains(Mode::INTERSECT))
        return true;

    return set_of_modes.contains(Mode::EXCEPT);
}
} }

View File

@ -14,6 +14,7 @@ public:
String getID(char) const override { return "SelectWithUnionQuery"; } String getID(char) const override { return "SelectWithUnionQuery"; }
ASTPtr clone() const override; ASTPtr clone() const override;
void formatQueryImpl(const FormatSettings & settings, FormatState & state, FormatStateStacked frame) const override; void formatQueryImpl(const FormatSettings & settings, FormatState & state, FormatStateStacked frame) const override;
const char * getQueryKindString() const override { return "Select"; } const char * getQueryKindString() const override { return "Select"; }
@ -28,6 +29,7 @@ public:
}; };
using UnionModes = std::vector<Mode>; using UnionModes = std::vector<Mode>;
using UnionModesSet = std::unordered_set<Mode>;
Mode union_mode; Mode union_mode;
@ -36,6 +38,11 @@ public:
bool is_normalized = false; bool is_normalized = false;
ASTPtr list_of_selects; ASTPtr list_of_selects;
UnionModesSet set_of_modes;
/// Consider any mode other than ALL as non-default.
bool hasNonDefaultUnionMode() const;
}; };
} }

View File

@ -1,9 +1,26 @@
-- { echo } -- { echo }
select count() from (select * from test union distinct select * from test); select count() from (select * from test union distinct select * from test);
5
select count() from (select * from test union distinct select * from test union all select * from test);
10
select count() from (select * from test union distinct select * from test except select * from test where name = '3');
4
select count() from (select * from test intersect (select * from test where toUInt8(name) < 4) union distinct (select * from test where name = '5' or name = '1') except select * from test where name = '3');
3
with (select count() from (select * from test union distinct select * from test except select * from test where toUInt8(name) > 3)) as max
select count() from (select * from test union all select * from test where toUInt8(name) < max);
7
with (select count() from (select * from test union distinct select * from test except select * from test where toUInt8(name) > 3)) as max
select count() from (select * from test except select * from test where toUInt8(name) < max);
3 3
select uuid from test union distinct select uuid from test; select uuid from test union distinct select uuid from test;
00000000-0000-0000-0000-000000000000 00000000-0000-0000-0000-000000000000
select uuid from test union distinct select uuid from test union all select uuid from test where name = '1';
00000000-0000-0000-0000-000000000000
00000000-0000-0000-0000-000000000000
select uuid from (select * from test union distinct select * from test); select uuid from (select * from test union distinct select * from test);
00000000-0000-0000-0000-000000000000 00000000-0000-0000-0000-000000000000
00000000-0000-0000-0000-000000000000 00000000-0000-0000-0000-000000000000
00000000-0000-0000-0000-000000000000 00000000-0000-0000-0000-000000000000
00000000-0000-0000-0000-000000000000
00000000-0000-0000-0000-000000000000

View File

@ -3,8 +3,21 @@ create table test (name String, uuid UUID) engine=Memory();
insert into test select '1', '00000000-0000-0000-0000-000000000000'; insert into test select '1', '00000000-0000-0000-0000-000000000000';
insert into test select '2', '00000000-0000-0000-0000-000000000000'; insert into test select '2', '00000000-0000-0000-0000-000000000000';
insert into test select '3', '00000000-0000-0000-0000-000000000000'; insert into test select '3', '00000000-0000-0000-0000-000000000000';
insert into test select '4', '00000000-0000-0000-0000-000000000000';
insert into test select '5', '00000000-0000-0000-0000-000000000000';
-- { echo } -- { echo }
select count() from (select * from test union distinct select * from test); select count() from (select * from test union distinct select * from test);
select count() from (select * from test union distinct select * from test union all select * from test);
select count() from (select * from test union distinct select * from test except select * from test where name = '3');
select count() from (select * from test intersect (select * from test where toUInt8(name) < 4) union distinct (select * from test where name = '5' or name = '1') except select * from test where name = '3');
with (select count() from (select * from test union distinct select * from test except select * from test where toUInt8(name) > 3)) as max
select count() from (select * from test union all select * from test where toUInt8(name) < max);
with (select count() from (select * from test union distinct select * from test except select * from test where toUInt8(name) > 3)) as max
select count() from (select * from test except select * from test where toUInt8(name) < max);
select uuid from test union distinct select uuid from test; select uuid from test union distinct select uuid from test;
select uuid from test union distinct select uuid from test union all select uuid from test where name = '1';
select uuid from (select * from test union distinct select * from test); select uuid from (select * from test union distinct select * from test);

View File

@ -507,7 +507,6 @@
"01532_execute_merges_on_single_replica", /// static zk path "01532_execute_merges_on_single_replica", /// static zk path
"01530_drop_database_atomic_sync", /// creates database "01530_drop_database_atomic_sync", /// creates database
"02001_add_default_database_to_system_users", ///create user "02001_add_default_database_to_system_users", ///create user
"02002_row_level_filter_bug", ///create user "02002_row_level_filter_bug" ///create user
"02008_test_union_distinct_in_subquery" /// create database
] ]
} }