This commit is contained in:
kssenii 2021-08-16 15:03:55 +03:00
parent c01be0354b
commit 844e04e341
9 changed files with 83 additions and 17 deletions

View File

@ -31,10 +31,12 @@ namespace ErrorCodes
}
InterpreterSelectWithUnionQuery::InterpreterSelectWithUnionQuery(
const ASTPtr & query_ptr_, ContextPtr context_, const SelectQueryOptions & options_, const Names & required_result_column_names)
const ASTPtr & query_ptr_, ContextPtr context_,
const SelectQueryOptions & options_, const Names & required_result_column_names)
: IInterpreterUnionOrSelectQuery(query_ptr_, context_, options_)
{
ASTSelectWithUnionQuery * ast = query_ptr->as<ASTSelectWithUnionQuery>();
bool require_full_header = ast->hasNonDefaultUnionMode();
const Settings & settings = context->getSettingsRef();
if (options.subquery_depth == 0 && (settings.limit > 0 || settings.offset > 0))
@ -51,10 +53,7 @@ InterpreterSelectWithUnionQuery::InterpreterSelectWithUnionQuery(
nested_interpreters.reserve(num_children);
std::vector<Names> required_result_column_names_for_other_selects(num_children);
/// If it is UNION DISTINCT, do not filter by required_result_columns.
bool is_union_distinct = ast->union_mode == ASTSelectWithUnionQuery::Mode::DISTINCT;
if (!required_result_column_names.empty() && num_children > 1 && !is_union_distinct)
if (!require_full_header && !required_result_column_names.empty() && num_children > 1)
{
/// Result header if there are no filtering by 'required_result_column_names'.
/// We use it to determine positions of 'required_result_column_names' in SELECT clause.
@ -133,10 +132,10 @@ InterpreterSelectWithUnionQuery::InterpreterSelectWithUnionQuery(
for (size_t query_num = 0; query_num < num_children; ++query_num)
{
const Names & current_required_result_column_names
= !is_union_distinct && query_num == 0 ? required_result_column_names : required_result_column_names_for_other_selects[query_num];
= query_num == 0 ? required_result_column_names : required_result_column_names_for_other_selects[query_num];
nested_interpreters.emplace_back(
buildCurrentChildInterpreter(ast->list_of_selects->children.at(query_num), current_required_result_column_names));
buildCurrentChildInterpreter(ast->list_of_selects->children.at(query_num), require_full_header ? Names() : current_required_result_column_names));
}
/// Determine structure of the result.

View File

@ -48,7 +48,6 @@ private:
std::unique_ptr<IInterpreterUnionOrSelectQuery>
buildCurrentChildInterpreter(const ASTPtr & ast_ptr_, const Names & current_required_result_column_names);
};
}

View File

@ -11,11 +11,11 @@ namespace ErrorCodes
extern const int EXPECTED_ALL_OR_DISTINCT;
}
void NormalizeSelectWithUnionQueryMatcher::getSelectsFromUnionListNode(ASTPtr & ast_select, ASTs & selects)
void NormalizeSelectWithUnionQueryMatcher::getSelectsFromUnionListNode(ASTPtr ast_select, ASTs & selects)
{
if (auto * inner_union = ast_select->as<ASTSelectWithUnionQuery>())
if (const auto * inner_union = ast_select->as<ASTSelectWithUnionQuery>())
{
for (auto & child : inner_union->list_of_selects->children)
for (const auto & child : inner_union->list_of_selects->children)
getSelectsFromUnionListNode(child, selects);
return;
@ -34,11 +34,28 @@ void NormalizeSelectWithUnionQueryMatcher::visit(ASTSelectWithUnionQuery & ast,
{
auto & union_modes = ast.list_of_modes;
ASTs selects;
auto & select_list = ast.list_of_selects->children;
const auto & select_list = ast.list_of_selects->children;
if (select_list.empty())
throw Exception(ErrorCodes::LOGICAL_ERROR, "Got empty list of selects for ASTSelectWithUnionQuery");
/// Since nodes are traversed from bottom to top, we can also collect union modes from chidlren up to parents.
ASTSelectWithUnionQuery::UnionModesSet current_set_of_modes;
bool distinct_found = false;
int i;
for (i = union_modes.size() - 1; i >= 0; --i)
{
current_set_of_modes.insert(union_modes[i]);
if (const auto * union_ast = typeid_cast<const ASTSelectWithUnionQuery *>(select_list[i + 1].get()))
{
const auto & current_select_modes = union_ast->set_of_modes;
current_set_of_modes.insert(current_select_modes.begin(), current_select_modes.end());
}
if (distinct_found)
continue;
/// Rewrite UNION Mode
if (union_modes[i] == ASTSelectWithUnionQuery::Mode::Unspecified)
{
@ -80,12 +97,18 @@ void NormalizeSelectWithUnionQueryMatcher::visit(ASTSelectWithUnionQuery & ast,
distinct_list->union_mode = ASTSelectWithUnionQuery::Mode::DISTINCT;
distinct_list->is_normalized = true;
selects.push_back(std::move(distinct_list));
break;
distinct_found = true;
}
}
if (const auto * union_ast = typeid_cast<const ASTSelectWithUnionQuery *>(select_list[0].get()))
{
const auto & current_select_modes = union_ast->set_of_modes;
current_set_of_modes.insert(current_select_modes.begin(), current_select_modes.end());
}
/// No UNION DISTINCT or only one child in select_list
if (i == -1)
if (!distinct_found)
{
if (auto * inner_union = select_list[0]->as<ASTSelectWithUnionQuery>();
inner_union && inner_union->union_mode == ASTSelectWithUnionQuery::Mode::ALL)
@ -103,6 +126,7 @@ void NormalizeSelectWithUnionQueryMatcher::visit(ASTSelectWithUnionQuery & ast,
if (selects.size() == 1 && selects[0]->as<ASTSelectWithUnionQuery>())
{
ast = *(selects[0]->as<ASTSelectWithUnionQuery>());
ast.set_of_modes = std::move(current_set_of_modes);
return;
}
@ -111,6 +135,7 @@ void NormalizeSelectWithUnionQueryMatcher::visit(ASTSelectWithUnionQuery & ast,
ast.is_normalized = true;
ast.union_mode = ASTSelectWithUnionQuery::Mode::ALL;
ast.set_of_modes = std::move(current_set_of_modes);
ast.list_of_selects->children = std::move(selects);
}

View File

@ -21,7 +21,7 @@ public:
const UnionMode & union_default_mode;
};
static void getSelectsFromUnionListNode(ASTPtr & ast_select, ASTs & selects);
static void getSelectsFromUnionListNode(ASTPtr ast_select, ASTs & selects);
static void visit(ASTPtr & ast, Data &);
static void visit(ASTSelectWithUnionQuery &, Data &);

View File

@ -20,6 +20,7 @@ ASTPtr ASTSelectWithUnionQuery::clone() const
res->union_mode = union_mode;
res->list_of_modes = list_of_modes;
res->set_of_modes = set_of_modes;
cloneOutputOptions(*res);
return res;
@ -71,4 +72,10 @@ void ASTSelectWithUnionQuery::formatQueryImpl(const FormatSettings & settings, F
}
}
bool ASTSelectWithUnionQuery::hasNonDefaultUnionMode() const
{
return set_of_modes.contains(Mode::DISTINCT) || set_of_modes.contains(Mode::INTERSECT) || set_of_modes.contains(Mode::EXCEPT);
}
}

View File

@ -14,6 +14,7 @@ public:
String getID(char) const override { return "SelectWithUnionQuery"; }
ASTPtr clone() const override;
void formatQueryImpl(const FormatSettings & settings, FormatState & state, FormatStateStacked frame) const override;
const char * getQueryKindString() const override { return "Select"; }
@ -28,6 +29,7 @@ public:
};
using UnionModes = std::vector<Mode>;
using UnionModesSet = std::unordered_set<Mode>;
Mode union_mode;
@ -36,6 +38,11 @@ public:
bool is_normalized = false;
ASTPtr list_of_selects;
UnionModesSet set_of_modes;
/// Consider any mode other than ALL as non-default.
bool hasNonDefaultUnionMode() const;
};
}

View File

@ -1,9 +1,26 @@
-- { echo }
select count() from (select * from test union distinct select * from test);
5
select count() from (select * from test union distinct select * from test union all select * from test);
10
select count() from (select * from test union distinct select * from test except select * from test where name = '3');
4
select count() from (select * from test intersect (select * from test where toUInt8(name) < 4) union distinct (select * from test where name = '5' or name = '1') except select * from test where name = '3');
3
with (select count() from (select * from test union distinct select * from test except select * from test where toUInt8(name) > 3)) as max
select count() from (select * from test union all select * from test where toUInt8(name) < max);
7
with (select count() from (select * from test union distinct select * from test except select * from test where toUInt8(name) > 3)) as max
select count() from (select * from test except select * from test where toUInt8(name) < max);
3
select uuid from test union distinct select uuid from test;
00000000-0000-0000-0000-000000000000
select uuid from test union distinct select uuid from test union all select uuid from test where name = '1';
00000000-0000-0000-0000-000000000000
00000000-0000-0000-0000-000000000000
select uuid from (select * from test union distinct select * from test);
00000000-0000-0000-0000-000000000000
00000000-0000-0000-0000-000000000000
00000000-0000-0000-0000-000000000000
00000000-0000-0000-0000-000000000000
00000000-0000-0000-0000-000000000000

View File

@ -3,8 +3,21 @@ create table test (name String, uuid UUID) engine=Memory();
insert into test select '1', '00000000-0000-0000-0000-000000000000';
insert into test select '2', '00000000-0000-0000-0000-000000000000';
insert into test select '3', '00000000-0000-0000-0000-000000000000';
insert into test select '4', '00000000-0000-0000-0000-000000000000';
insert into test select '5', '00000000-0000-0000-0000-000000000000';
-- { echo }
select count() from (select * from test union distinct select * from test);
select count() from (select * from test union distinct select * from test union all select * from test);
select count() from (select * from test union distinct select * from test except select * from test where name = '3');
select count() from (select * from test intersect (select * from test where toUInt8(name) < 4) union distinct (select * from test where name = '5' or name = '1') except select * from test where name = '3');
with (select count() from (select * from test union distinct select * from test except select * from test where toUInt8(name) > 3)) as max
select count() from (select * from test union all select * from test where toUInt8(name) < max);
with (select count() from (select * from test union distinct select * from test except select * from test where toUInt8(name) > 3)) as max
select count() from (select * from test except select * from test where toUInt8(name) < max);
select uuid from test union distinct select uuid from test;
select uuid from test union distinct select uuid from test union all select uuid from test where name = '1';
select uuid from (select * from test union distinct select * from test);

View File

@ -507,7 +507,6 @@
"01532_execute_merges_on_single_replica", /// static zk path
"01530_drop_database_atomic_sync", /// creates database
"02001_add_default_database_to_system_users", ///create user
"02002_row_level_filter_bug", ///create user
"02008_test_union_distinct_in_subquery" /// create database
"02002_row_level_filter_bug" ///create user
]
}