use source and joined columns to detect JOIN ON right and left keys

This commit is contained in:
chertus 2019-07-24 18:37:37 +03:00
parent 9da1b0089c
commit b3123df58e
6 changed files with 172 additions and 31 deletions

View File

@ -4,14 +4,18 @@
#include <Parsers/queryToString.h>
#include <Interpreters/InDepthNodeVisitor.h>
#include <Interpreters/Aliases.h>
#include <Interpreters/SyntaxAnalyzer.h>
namespace DB
{
namespace ErrorCodes
{
extern const int INVALID_JOIN_ON_EXPRESSION;
extern const int AMBIGUOUS_COLUMN_NAME;
extern const int LOGICAL_ERROR;
}
@ -23,6 +27,9 @@ public:
/// State shared by the visitor while it walks a JOIN ON expression.
/// All inputs are held by reference; the fields are filled positionally by the caller.
struct Data
{
/// Receives every detected (left key, right key) pair through addOnKeys().
AnalyzedJoin & analyzed_join;
/// Column names checked to decide that an identifier without known membership belongs to the left table.
const NameSet & source_columns;
/// Column names checked to decide that an identifier without known membership belongs to the right table.
const NameSet & joined_columns;
/// Alias map used to unroll an identifier to the column it refers to before the membership checks.
const Aliases & aliases;
/// Becomes true once at least one pair of JOIN keys has been added.
bool has_some = false;
};
@ -50,7 +57,7 @@ private:
{
ASTPtr left = func.arguments->children.at(0)->clone();
ASTPtr right = func.arguments->children.at(1)->clone();
addColumnsFromEqualsExpr(ast, left, right, data);
addJoinKeys(ast, left, right, data);
return;
}
@ -70,7 +77,7 @@ private:
getIdentifiers(child, out);
}
static void addColumnsFromEqualsExpr(const ASTPtr & expr, ASTPtr left_ast, ASTPtr right_ast, Data & data)
static void addJoinKeys(const ASTPtr & expr, ASTPtr left_ast, ASTPtr right_ast, Data & data)
{
std::vector<const ASTIdentifier *> left_identifiers;
std::vector<const ASTIdentifier *> right_identifiers;
@ -78,8 +85,8 @@ private:
getIdentifiers(left_ast, left_identifiers);
getIdentifiers(right_ast, right_identifiers);
size_t left_idents_table = checkSameTable(left_identifiers);
size_t right_idents_table = checkSameTable(right_identifiers);
size_t left_idents_table = getTableForIdentifiers(left_identifiers, data);
size_t right_idents_table = getTableForIdentifiers(right_identifiers, data);
if (left_idents_table && left_idents_table == right_idents_table)
{
@ -95,40 +102,79 @@ private:
else if (left_idents_table == 2 || right_idents_table == 1)
data.analyzed_join.addOnKeys(right_ast, left_ast);
else
{
/// Default variant when all identifiers may be from any table.
data.analyzed_join.addOnKeys(left_ast, right_ast); /// FIXME
}
throw Exception("Cannot detect left and right JOIN keys. JOIN ON section is ambiguous.",
ErrorCodes::AMBIGUOUS_COLUMN_NAME);
data.has_some = true;
}
static size_t checkSameTable(std::vector<const ASTIdentifier *> & identifiers)
/// Follows the chain of aliases starting at `identifier` and returns the identifier it finally refers to.
/// Returns nullptr when the chain ends on an expression that is not an identifier.
/// Throws LOGICAL_ERROR if the chain does not terminate within a bounded number of steps.
static const ASTIdentifier * unrollAliases(const ASTIdentifier * identifier, const Aliases & aliases)
{
    UInt32 attempts_left = 100;
    auto alias_it = aliases.find(identifier->name);

    while (alias_it != aliases.end())
    {
        const ASTIdentifier * previous = identifier;
        identifier = alias_it->second->as<ASTIdentifier>();

        /// Stop unrolling when:
        ///  - the alias target is not a column alias (identifier is null);
        ///  - the alias points to itself under the same name: 'a as a';
        ///  - the target is a compound name, i.e. not an alias. This also prevents
        ///    a cycle through short names: 'a as b, t1.b as a'.
        if (!identifier || identifier == previous || identifier->compound())
            break;

        alias_it = aliases.find(identifier->name);

        /// Guard against alias cycles that the checks above did not catch.
        if (attempts_left-- == 0)
            throw Exception("Cannot unroll aliases for '" + identifier->name + "'", ErrorCodes::LOGICAL_ERROR);
    }

    return identifier;
}
/// @returns 1 if the identifiers belong to the left table, 2 if they belong to the right table and 0 if unknown. Throws on table mix.
/// Places the detected identifier into identifiers[0], if any.
static size_t getTableForIdentifiers(std::vector<const ASTIdentifier *> & identifiers, const Data & data)
{
size_t table_number = 0;
const ASTIdentifier * detected = nullptr;
for (const auto & identifier : identifiers)
for (auto & ident : identifiers)
{
/// It's set in TranslateQualifiedNamesVisitor
const ASTIdentifier * identifier = unrollAliases(ident, data.aliases);
if (!identifier)
continue;
/// Column name could be cropped to a short form in TranslateQualifiedNamesVisitor.
/// In this case it saves membership in IdentifierSemantic.
size_t membership = IdentifierSemantic::getMembership(*identifier);
if (!membership)
{
const String & name = identifier->name;
bool in_left_table = data.source_columns.count(name);
bool in_right_table = data.joined_columns.count(name);
if (in_left_table && in_right_table)
throw Exception("Column '" + name + "' is ambiguous", ErrorCodes::AMBIGUOUS_COLUMN_NAME);
if (in_left_table)
membership = 1;
if (in_right_table)
membership = 2;
}
if (membership && table_number == 0)
{
table_number = membership;
detected = identifier;
std::swap(ident, identifiers[0]); /// move first detected identifier to the first position
}
if (membership && membership != table_number)
{
throw Exception("Invalid columns in JOIN ON section. Columns "
+ detected->getAliasOrColumnName() + " and " + identifier->getAliasOrColumnName()
+ identifiers[0]->getAliasOrColumnName() + " and " + ident->getAliasOrColumnName()
+ " are from different tables.", ErrorCodes::INVALID_JOIN_ON_EXPRESSION);
}
}
identifiers.clear();
if (detected)
identifiers.push_back(detected);
return table_number;
}

View File

@ -482,8 +482,8 @@ void getArrayJoinedColumns(ASTPtr & query, SyntaxAnalyzerResult & result, const
}
/// Find the columns that are obtained by JOIN.
void collectJoinedColumns(AnalyzedJoin & analyzed_join, const ASTSelectQuery & select_query,
const NameSet & source_columns, const String & current_database, bool join_use_nulls)
void collectJoinedColumns(AnalyzedJoin & analyzed_join, const ASTSelectQuery & select_query, const NameSet & source_columns,
const Aliases & aliases, const String & current_database, bool join_use_nulls)
{
const ASTTablesInSelectQueryElement * node = select_query.join();
if (!node)
@ -505,7 +505,11 @@ void collectJoinedColumns(AnalyzedJoin & analyzed_join, const ASTSelectQuery & s
}
else if (table_join.on_expression)
{
CollectJoinOnKeysVisitor::Data data{analyzed_join};
NameSet joined_columns;
for (const auto & col : analyzed_join.columns_from_joined_table)
joined_columns.insert(col.original_name);
CollectJoinOnKeysVisitor::Data data{analyzed_join, source_columns, joined_columns, aliases};
CollectJoinOnKeysVisitor(data).visit(table_join.on_expression);
if (!data.has_some)
throw Exception("Cannot get JOIN keys from JOIN ON section: " + queryToString(table_join.on_expression),
@ -662,7 +666,8 @@ SyntaxAnalyzerResultPtr SyntaxAnalyzer::analyze(
/// Push the predicate expression down to the subqueries.
result.rewrite_subqueries = PredicateExpressionsOptimizer(select_query, settings, context).optimize();
collectJoinedColumns(result.analyzed_join, *select_query, source_columns_set, context.getCurrentDatabase(), settings.join_use_nulls);
collectJoinedColumns(result.analyzed_join, *select_query, source_columns_set, result.aliases,
context.getCurrentDatabase(), settings.join_use_nulls);
}
return std::make_shared<const SyntaxAnalyzerResult>(result);

View File

@ -24,7 +24,7 @@ public:
struct Data
{
NameSet source_columns;
const NameSet source_columns;
const std::vector<TableWithColumnNames> & tables;
std::unordered_set<String> join_using_columns;
bool has_columns;

View File

@ -8,15 +8,16 @@ select * from (select toLowCardinality(toNullable(dummy)) as val from system.one
select * from (select toLowCardinality(dummy) as val from system.one) any left join (select toLowCardinality(toNullable(dummy)) as val from system.one) using val;
select * from (select toLowCardinality(toNullable(dummy)) as val from system.one) any left join (select toLowCardinality(toNullable(dummy)) as val from system.one) using val;
select '-';
select * from (select dummy as val from system.one) any left join (select dummy as val from system.one) on val + 0 = val * 1;
select * from (select toLowCardinality(dummy) as val from system.one) any left join (select dummy as val from system.one) on val + 0 = val * 1;
select * from (select dummy as val from system.one) any left join (select toLowCardinality(dummy) as val from system.one) on val + 0 = val * 1;
select * from (select toLowCardinality(dummy) as val from system.one) any left join (select toLowCardinality(dummy) as val from system.one) on val + 0 = val * 1;
select * from (select toLowCardinality(toNullable(dummy)) as val from system.one) any left join (select dummy as val from system.one) on val + 0 = val * 1;
select * from (select dummy as val from system.one) any left join (select toLowCardinality(toNullable(dummy)) as val from system.one) on val + 0 = val * 1;
select * from (select toLowCardinality(toNullable(dummy)) as val from system.one) any left join (select toLowCardinality(dummy) as val from system.one) on val + 0 = val * 1;
select * from (select toLowCardinality(dummy) as val from system.one) any left join (select toLowCardinality(toNullable(dummy)) as val from system.one) on val + 0 = val * 1;
select * from (select toLowCardinality(toNullable(dummy)) as val from system.one) any left join (select toLowCardinality(toNullable(dummy)) as val from system.one) on val + 0 = val * 1;
select * from (select dummy as val from system.one) any left join (select dummy as val from system.one) on val + 0 = val * 1; -- { serverError 352 }
select * from (select dummy as val from system.one) any left join (select dummy as rval from system.one) on val + 0 = rval * 1;
select * from (select toLowCardinality(dummy) as val from system.one) any left join (select dummy as rval from system.one) on val + 0 = rval * 1;
select * from (select dummy as val from system.one) any left join (select toLowCardinality(dummy) as rval from system.one) on val + 0 = rval * 1;
select * from (select toLowCardinality(dummy) as val from system.one) any left join (select toLowCardinality(dummy) as rval from system.one) on val + 0 = rval * 1;
select * from (select toLowCardinality(toNullable(dummy)) as val from system.one) any left join (select dummy as rval from system.one) on val + 0 = rval * 1;
select * from (select dummy as val from system.one) any left join (select toLowCardinality(toNullable(dummy)) as rval from system.one) on val + 0 = rval * 1;
select * from (select toLowCardinality(toNullable(dummy)) as val from system.one) any left join (select toLowCardinality(dummy) as rval from system.one) on val + 0 = rval * 1;
select * from (select toLowCardinality(dummy) as val from system.one) any left join (select toLowCardinality(toNullable(dummy)) as rval from system.one) on val + 0 = rval * 1;
select * from (select toLowCardinality(toNullable(dummy)) as val from system.one) any left join (select toLowCardinality(toNullable(dummy)) as rval from system.one) on val + 0 = rval * 1;
select '-';
select * from (select number as l from system.numbers limit 3) any left join (select number as r from system.numbers limit 3) on l + 1 = r * 1;
select * from (select toLowCardinality(number) as l from system.numbers limit 3) any left join (select number as r from system.numbers limit 3) on l + 1 = r * 1;

View File

@ -0,0 +1,28 @@
2 y 2 w
2 y 2 w
2 2
2 2
y w
y w
2 2
2 2
y w
y w
y y
y y
y y
y y
2 y 2 w
2 y 2 w
2 2
2 2
y w
y w
2 2
2 2
y w
y w
y y
y y
y y
y y

View File

@ -0,0 +1,61 @@
-- Checks detection of left and right JOIN ON keys when the key columns are written without table qualifiers.
use test;
drop table if exists t1;
drop table if exists t2;
-- The two tables share no column names: t1 has (a, b), t2 has (c, d).
create table t1 (a UInt32, b String) engine = Memory;
create table t2 (c UInt32, d String) engine = Memory;
insert into t1 values (1, 'x'), (2, 'y'), (3, 'z');
insert into t2 values (2, 'w'), (4, 'y');
-- First pass: predicate expression optimization disabled.
set enable_optimize_predicate_expression = 0;
-- Each query is run with both orders of the ON operands (a = c and c = a).
select * from t1 join t2 on a = c;
select * from t1 join t2 on c = a;
select t1.a, t2.c from t1 join t2 on a = c;
select t1.a, t2.c from t1 join t2 on c = a;
select t1.b, t2.d from t1 join t2 on a = c;
select t1.b, t2.d from t1 join t2 on c = a;
select a, c from t1 join t2 on a = c;
select a, c from t1 join t2 on c = a;
select b, d from t1 join t2 on a = c;
select b, d from t1 join t2 on c = a;
-- SELECT aliases that reuse the names used in the ON section.
select b as a, d as c from t1 join t2 on a = c;
select b as a, d as c from t1 join t2 on c = a;
select b as c, d as a from t1 join t2 on a = c;
select b as c, d as a from t1 join t2 on c = a;
-- TODO
-- select t1.a as a, t2.c as c from t1 join t2 on a = c;
-- select t1.a as a, t2.c as c from t1 join t2 on c = a;
-- select t1.a as c, t2.c as a from t1 join t2 on a = c;
-- select t1.a as c, t2.c as a from t1 join t2 on c = a;
--
-- select t1.a as c, t2.c as a from t1 join t2 on t1.a = t2.c;
-- select t1.a as c, t2.c as a from t1 join t2 on t2.c = t1.a;
-- Second pass: the same queries with predicate expression optimization enabled.
set enable_optimize_predicate_expression = 1;
select * from t1 join t2 on a = c;
select * from t1 join t2 on c = a;
select t1.a, t2.c from t1 join t2 on a = c;
select t1.a, t2.c from t1 join t2 on c = a;
select t1.b, t2.d from t1 join t2 on a = c;
select t1.b, t2.d from t1 join t2 on c = a;
select a, c from t1 join t2 on a = c;
select a, c from t1 join t2 on c = a;
select b, d from t1 join t2 on a = c;
select b, d from t1 join t2 on c = a;
select b as a, d as c from t1 join t2 on a = c;
select b as a, d as c from t1 join t2 on c = a;
select b as c, d as a from t1 join t2 on a = c;
select b as c, d as a from t1 join t2 on c = a;
-- Clean up the tables created by this test.
drop table t1;
drop table t2;