Update cross to inner rewrite

This commit is contained in:
vdimir 2021-02-15 15:00:08 +03:00
parent 09c4bd91ee
commit ed9d49abc3
No known key found for this signature in database
GPG Key ID: F57B3E10A21DBB31
3 changed files with 142 additions and 146 deletions

View File

@ -81,156 +81,143 @@ private:
ASTTableJoin * join = nullptr;
};
/// Tells whether `name` is one of the comparison function names exposed by
/// NameEquals, NameNotEquals, NameLess, NameGreater, NameLessOrEquals and NameGreaterOrEquals.
bool isComparison(const String & name)
{
    if (name == NameEquals::name)
        return true;
    if (name == NameNotEquals::name)
        return true;
    if (name == NameLess::name)
        return true;
    if (name == NameGreater::name)
        return true;
    if (name == NameLessOrEquals::name)
        return true;
    return name == NameGreaterOrEquals::name;
}
/// It checks if where expression could be moved to JOIN ON expression partially or entirely.
class CheckExpressionVisitorData
/// Collect all identifiers from ast
class IdentifiersCollector
{
public:
using TypeToVisit = const ASTFunction;
CheckExpressionVisitorData(const std::vector<JoinedElement> & tables_,
const std::vector<TableWithColumnNamesAndTypes> & tables_with_columns,
const Aliases & aliases_)
: joined_tables(tables_)
, tables(tables_with_columns)
, aliases(aliases_)
, is_complex(false)
{}
void visit(const ASTFunction & node, const ASTPtr & ast)
using ASTIdentPtr = const ASTIdentifier *;
using ASTIdentifiers = std::vector<ASTIdentPtr>;
struct Data
{
if (is_complex)
ASTIdentifiers idents;
};
static void visit(const ASTPtr & node, Data & data)
{
if (const auto * ident = node->as<ASTIdentifier>())
data.idents.push_back(ident);
}
/// Always answers true, regardless of the node or the child passed in.
/// NOTE(review): presumably this makes the in-depth visitor descend into every child - confirm against the visitor implementation.
static bool needChildVisit(const ASTPtr &, const ASTPtr &)
{
return true;
}
/// Run a ConstInDepthNodeVisitor with this collector over `node` and
/// return the identifiers gathered in Data::idents.
static ASTIdentifiers collect(const ASTPtr & node)
{
    IdentifiersCollector::Data ident_data;
    ConstInDepthNodeVisitor<IdentifiersCollector, true> ident_visitor(ident_data);
    ident_visitor.visit(node);

    /// `idents` is a member of a local object, so the implicit move on return does not apply to it.
    /// Move it out explicitly instead of copying the whole vector.
    return std::move(ident_data.idents);
}
};
/// Split expression `expr_1 AND expr_2 AND ... AND expr_n` into vector `[expr_1, expr_2, ..., expr_n]`.
/// Nested ANDs are flattened recursively; any other node is appended to `members` unchanged.
/// An AND function node that has no argument list is appended as is rather than being dereferenced.
void collectConjunctions(const ASTPtr & node, std::vector<ASTPtr> & members)
{
    const auto * func = node->as<ASTFunction>();

    /// Only descend when there is an argument list to walk.
    if (func && func->name == NameAnd::name && func->arguments)
    {
        for (const auto & child : func->arguments->children)
            collectConjunctions(child, members);
        return;
    }

    members.push_back(node);
}
if (node.name == NameAnd::name)
{
if (!node.arguments || node.arguments->children.empty())
throw Exception("Logical error: function requires argument", ErrorCodes::LOGICAL_ERROR);
/// Resolve the table position of `ident`.
/// The value from IdentifierSemantic::getMembership is used when it is set;
/// otherwise the result of IdentifierSemantic::chooseTableColumnMatch against `tables` is returned.
std::optional<size_t> getIdentMembership(const ASTIdentifier & ident, const std::vector<TableWithColumnNamesAndTypes> & tables)
{
    if (std::optional<size_t> known_pos = IdentifierSemantic::getMembership(ident))
        return known_pos;

    return IdentifierSemantic::chooseTableColumnMatch(ident, tables);
}
for (auto & child : node.arguments->children)
{
if (const auto * func = child->as<ASTFunction>())
{
visit(*func, child);
}
else
{
bool is_literal_or_ident = !child->as<ASTLiteral>() && !child->as<ASTIdentifier>();
is_complex = is_complex || !is_literal_or_ident;
}
}
}
else if (node.name == NameEquals::name)
{
if (size_t min_table = canMoveEqualsToJoinOn(node))
asts_to_join_on[min_table].push_back(ast);
}
else if (isComparison(node.name))
{
/// leave other comparisons as is
}
else if (functionIsLikeOperator(node.name) || /// LIKE, NOT LIKE, ILIKE, NOT ILIKE
functionIsInOperator(node.name)) /// IN, NOT IN
{
/// Leave as is. It's not possible to make push down here cause of unknown aliases and not implemented JOIN predicates.
/// select a as b from t1, t2 where t1.x = t2.x and b in(42)
/// select a as b from t1 inner join t2 on t1.x = t2.x and b in(42)
}
else if (node.name == NameOr::name)
{
std::optional<size_t> getIdentsMembership(const ASTPtr ast,
const std::vector<TableWithColumnNamesAndTypes> & tables,
const Aliases & aliases)
{
auto idents = IdentifiersCollector::collect(ast);
}
else
std::optional<size_t> result;
for (const auto * ident : idents)
{
is_complex = true;
asts_to_join_on.clear();
}
}
bool complex() const { return is_complex; }
bool matchAny(size_t t) const { return asts_to_join_on.count(t); }
ASTPtr makeOnExpression(size_t table_pos)
{
if (!asts_to_join_on.count(table_pos))
/// Moving expressions that use column aliases is not supported.
if (ident->isShort() && aliases.count(ident->shortName()))
return {};
const auto pos = getIdentMembership(*ident, tables);
if (!pos)
return {};
if (result && *pos != *result)
return {};
result = pos;
}
return result;
}
std::vector<ASTPtr> & expressions = asts_to_join_on[table_pos];
/// Check whether a single WHERE conjunct permits the cross-to-inner rewrite.
/// A function node is allowed when none of the identifiers collected from it
/// is a short name that is present in `aliases`.
/// A non-function node is allowed only if it is an identifier or a literal.
bool isAllowedToRewriteCrossJoin(const ASTPtr & node, const Aliases & aliases)
{
    /// Only the node kind matters here, so the cast result is not bound to a name.
    if (node->as<ASTFunction>())
    {
        auto idents = IdentifiersCollector::collect(node);
        for (const auto * ident : idents)
        {
            if (ident->isShort() && aliases.count(ident->shortName()))
                return false;
        }
        return true;
    }
    return node->as<ASTIdentifier>() || node->as<ASTLiteral>();
}
bool canMoveExpressionToJoinOn(const ASTPtr & ast,
const std::vector<JoinedElement> & joined_tables,
const std::vector<TableWithColumnNamesAndTypes> & tables,
const Aliases & aliases,
std::map<size_t, std::vector<ASTPtr>> & asts_to_join_on)
{
std::vector<ASTPtr> conjuncts;
collectConjunctions(ast, conjuncts);
for (const auto & node : conjuncts)
{
if (const auto * func = node->as<ASTFunction>(); func && func->name == NameEquals::name)
{
if (!func->arguments || func->arguments->children.size() != 2)
return false;
/// Check if the identifiers are from different joined tables.
/// If it's a self joint, tables should have aliases.
auto left_table_pos = getIdentsMembership(func->arguments->children[0], tables, aliases);
auto right_table_pos = getIdentsMembership(func->arguments->children[1], tables, aliases);
/// Identifiers from different table move to JOIN ON
if (left_table_pos && right_table_pos && *left_table_pos != *right_table_pos)
{
size_t table_pos = std::max(*left_table_pos, *right_table_pos);
if (joined_tables[table_pos].canAttachOnExpression())
asts_to_join_on[table_pos].push_back(node);
else
return false;
}
}
if (!isAllowedToRewriteCrossJoin(node, aliases))
return false;
}
return true;
}
/// Build one expression out of `expressions`.
/// A single expression is returned as its clone; otherwise every expression is cloned
/// and the clones become the arguments of an AND function (NameAnd).
/// NOTE(review): callers are presumably expected to pass a non-empty vector - confirm.
ASTPtr makeOnExpression(const std::vector<ASTPtr> & expressions)
{
    if (expressions.size() == 1)
        return expressions[0]->clone();

    std::vector<ASTPtr> arguments;
    arguments.reserve(expressions.size());

    /// Exactly one pass: every input expression contributes one cloned argument.
    for (const auto & ast : expressions)
        arguments.emplace_back(ast->clone());

    return makeASTFunction(NameAnd::name, std::move(arguments));
}
private:
const std::vector<JoinedElement> & joined_tables;
const std::vector<TableWithColumnNamesAndTypes> & tables;
std::map<size_t, std::vector<ASTPtr>> asts_to_join_on;
const Aliases & aliases;
bool is_complex;
/// Decide whether an `equals` node can be moved to JOIN ON.
/// Both arguments must be plain identifiers and neither may be a short name present in `aliases`.
/// @return the result of checkIdentifiers() for the two identifiers, or 0 when the node does not qualify.
/// Throws an Exception with LOGICAL_ERROR when the node has no argument list.
size_t canMoveEqualsToJoinOn(const ASTFunction & node)
{
    if (!node.arguments)
        throw Exception("Logical error: function requires arguments", ErrorCodes::LOGICAL_ERROR);
    if (node.arguments->children.size() != 2)
        return 0;

    const auto * left = node.arguments->children[0]->as<ASTIdentifier>();
    const auto * right = node.arguments->children[1]->as<ASTIdentifier>();
    if (!left || !right)
        return 0;

    /// Moving expressions that use column aliases is not supported.
    if (left->isShort() && aliases.count(left->shortName()))
        return 0;
    if (right->isShort() && aliases.count(right->shortName()))
        return 0;

    return checkIdentifiers(*left, *right);
}
/// Check if the identifiers are from different joined tables. If it's a self joint, tables should have aliases.
/// select * from t1 a cross join t2 b where a.x = b.x
/// @return table position to attach expression to or 0.
size_t checkIdentifiers(const ASTIdentifier & left, const ASTIdentifier & right)
{
std::optional<size_t> left_table_pos = IdentifierSemantic::getMembership(left);
if (!left_table_pos)
left_table_pos = IdentifierSemantic::chooseTableColumnMatch(left, tables);
std::optional<size_t> right_table_pos = IdentifierSemantic::getMembership(right);
if (!right_table_pos)
right_table_pos = IdentifierSemantic::chooseTableColumnMatch(right, tables);
if (left_table_pos && right_table_pos && (*left_table_pos != *right_table_pos))
{
size_t table_pos = std::max(*left_table_pos, *right_table_pos);
if (joined_tables[table_pos].canAttachOnExpression())
return table_pos;
}
return 0;
}
};
using CheckExpressionMatcher = ConstOneTypeMatcher<CheckExpressionVisitorData, NeedChild::none>;
using CheckExpressionVisitor = ConstInDepthNodeVisitor<CheckExpressionMatcher, true>;
}
bool getTables(ASTSelectQuery & select, std::vector<JoinedElement> & joined_tables, size_t & num_comma)
{
@ -342,20 +329,21 @@ void CrossToInnerJoinMatcher::visit(ASTSelectQuery & select, ASTPtr &, Data & da
if (!select.where())
return;
CheckExpressionVisitor::Data visitor_data{joined_tables, data.tables_with_columns, data.aliases};
CheckExpressionVisitor(visitor_data).visit(select.where());
if (visitor_data.complex())
return;
std::map<size_t, std::vector<ASTPtr>> asts_to_join_on;
bool can_move_where = canMoveExpressionToJoinOn(
select.where(), joined_tables, data.tables_with_columns, data.aliases, asts_to_join_on);
if (can_move_where)
{
for (size_t i = 1; i < joined_tables.size(); ++i)
{
if (visitor_data.matchAny(i))
const auto & expr_it = asts_to_join_on.find(i);
if (expr_it != asts_to_join_on.end())
{
if (joined_tables[i].rewriteCrossToInner(visitor_data.makeOnExpression(i)))
if (joined_tables[i].rewriteCrossToInner(makeOnExpression(expr_it->second)))
data.done = true;
}
}
}
}
}

View File

@ -19,3 +19,10 @@ SELECT
FROM n
ALL INNER JOIN r ON k = r.k
WHERE (k = r.k) AND (name NOT LIKE \'A%\')
SELECT
k,
r.k,
name
FROM n
ALL INNER JOIN r ON (k + 1) = (r.k + 1)
WHERE ((k + 1) = (r.k + 1)) AND ((name = \'A\') OR (name = \'AA\'))

View File

@ -9,6 +9,7 @@ SET enable_optimize_predicate_expression = 0;
EXPLAIN SYNTAX SELECT * FROM n, r WHERE n.k = r.k AND r.name = 'A';
EXPLAIN SYNTAX SELECT * FROM n, r WHERE n.k = r.k AND r.name LIKE 'A%';
EXPLAIN SYNTAX SELECT * FROM n, r WHERE n.k = r.k AND r.name NOT LIKE 'A%';
EXPLAIN SYNTAX SELECT * FROM n, r WHERE n.k + 1 = r.k + 1 AND (r.name = 'A' OR r.name = 'AA');
DROP TABLE n;
DROP TABLE r;