mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-25 09:02:00 +00:00
Update cross to inner rewrite
This commit is contained in:
parent
09c4bd91ee
commit
ed9d49abc3
@ -81,156 +81,143 @@ private:
|
||||
ASTTableJoin * join = nullptr;
|
||||
};
|
||||
|
||||
bool isComparison(const String & name)
|
||||
{
|
||||
return name == NameEquals::name ||
|
||||
name == NameNotEquals::name ||
|
||||
name == NameLess::name ||
|
||||
name == NameGreater::name ||
|
||||
name == NameLessOrEquals::name ||
|
||||
name == NameGreaterOrEquals::name;
|
||||
}
|
||||
|
||||
/// It checks if where expression could be moved to JOIN ON expression partially or entirely.
|
||||
class CheckExpressionVisitorData
|
||||
/// Collect all identifiers from ast
|
||||
class IdentifiersCollector
|
||||
{
|
||||
public:
|
||||
using TypeToVisit = const ASTFunction;
|
||||
|
||||
CheckExpressionVisitorData(const std::vector<JoinedElement> & tables_,
|
||||
const std::vector<TableWithColumnNamesAndTypes> & tables_with_columns,
|
||||
const Aliases & aliases_)
|
||||
: joined_tables(tables_)
|
||||
, tables(tables_with_columns)
|
||||
, aliases(aliases_)
|
||||
, is_complex(false)
|
||||
{}
|
||||
|
||||
void visit(const ASTFunction & node, const ASTPtr & ast)
|
||||
using ASTIdentPtr = const ASTIdentifier *;
|
||||
using ASTIdentifiers = std::vector<ASTIdentPtr>;
|
||||
struct Data
|
||||
{
|
||||
if (is_complex)
|
||||
return;
|
||||
ASTIdentifiers idents;
|
||||
};
|
||||
|
||||
if (node.name == NameAnd::name)
|
||||
{
|
||||
if (!node.arguments || node.arguments->children.empty())
|
||||
throw Exception("Logical error: function requires argument", ErrorCodes::LOGICAL_ERROR);
|
||||
|
||||
for (auto & child : node.arguments->children)
|
||||
{
|
||||
if (const auto * func = child->as<ASTFunction>())
|
||||
{
|
||||
visit(*func, child);
|
||||
}
|
||||
else
|
||||
{
|
||||
bool is_literal_or_ident = !child->as<ASTLiteral>() && !child->as<ASTIdentifier>();
|
||||
is_complex = is_complex || !is_literal_or_ident;
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (node.name == NameEquals::name)
|
||||
{
|
||||
if (size_t min_table = canMoveEqualsToJoinOn(node))
|
||||
asts_to_join_on[min_table].push_back(ast);
|
||||
}
|
||||
else if (isComparison(node.name))
|
||||
{
|
||||
/// leave other comparisons as is
|
||||
}
|
||||
else if (functionIsLikeOperator(node.name) || /// LIKE, NOT LIKE, ILIKE, NOT ILIKE
|
||||
functionIsInOperator(node.name)) /// IN, NOT IN
|
||||
{
|
||||
/// Leave as is. It's not possible to make push down here cause of unknown aliases and not implemented JOIN predicates.
|
||||
/// select a as b from t1, t2 where t1.x = t2.x and b in(42)
|
||||
/// select a as b from t1 inner join t2 on t1.x = t2.x and b in(42)
|
||||
}
|
||||
else if (node.name == NameOr::name)
|
||||
{
|
||||
|
||||
}
|
||||
else
|
||||
{
|
||||
is_complex = true;
|
||||
asts_to_join_on.clear();
|
||||
}
|
||||
static void visit(const ASTPtr & node, Data & data)
|
||||
{
|
||||
if (const auto * ident = node->as<ASTIdentifier>())
|
||||
data.idents.push_back(ident);
|
||||
}
|
||||
|
||||
bool complex() const { return is_complex; }
|
||||
bool matchAny(size_t t) const { return asts_to_join_on.count(t); }
|
||||
|
||||
ASTPtr makeOnExpression(size_t table_pos)
|
||||
static bool needChildVisit(const ASTPtr &, const ASTPtr &)
|
||||
{
|
||||
if (!asts_to_join_on.count(table_pos))
|
||||
return {};
|
||||
|
||||
std::vector<ASTPtr> & expressions = asts_to_join_on[table_pos];
|
||||
|
||||
if (expressions.size() == 1)
|
||||
return expressions[0]->clone();
|
||||
|
||||
std::vector<ASTPtr> arguments;
|
||||
arguments.reserve(expressions.size());
|
||||
for (auto & ast : expressions)
|
||||
arguments.emplace_back(ast->clone());
|
||||
|
||||
return makeASTFunction(NameAnd::name, std::move(arguments));
|
||||
return true;
|
||||
}
|
||||
|
||||
private:
|
||||
const std::vector<JoinedElement> & joined_tables;
|
||||
const std::vector<TableWithColumnNamesAndTypes> & tables;
|
||||
std::map<size_t, std::vector<ASTPtr>> asts_to_join_on;
|
||||
const Aliases & aliases;
|
||||
bool is_complex;
|
||||
|
||||
size_t canMoveEqualsToJoinOn(const ASTFunction & node)
|
||||
static ASTIdentifiers collect(const ASTPtr & node)
|
||||
{
|
||||
if (!node.arguments)
|
||||
throw Exception("Logical error: function requires arguments", ErrorCodes::LOGICAL_ERROR);
|
||||
if (node.arguments->children.size() != 2)
|
||||
return false;
|
||||
|
||||
const auto * left = node.arguments->children[0]->as<ASTIdentifier>();
|
||||
const auto * right = node.arguments->children[1]->as<ASTIdentifier>();
|
||||
if (!left || !right)
|
||||
return false;
|
||||
|
||||
/// Moving expressions that use column aliases is not supported.
|
||||
if (left->isShort() && aliases.count(left->shortName()))
|
||||
return false;
|
||||
if (right->isShort() && aliases.count(right->shortName()))
|
||||
return false;
|
||||
|
||||
return checkIdentifiers(*left, *right);
|
||||
}
|
||||
|
||||
/// Check if the identifiers are from different joined tables. If it's a self joint, tables should have aliases.
|
||||
/// select * from t1 a cross join t2 b where a.x = b.x
|
||||
/// @return table position to attach expression to or 0.
|
||||
size_t checkIdentifiers(const ASTIdentifier & left, const ASTIdentifier & right)
|
||||
{
|
||||
std::optional<size_t> left_table_pos = IdentifierSemantic::getMembership(left);
|
||||
if (!left_table_pos)
|
||||
left_table_pos = IdentifierSemantic::chooseTableColumnMatch(left, tables);
|
||||
|
||||
std::optional<size_t> right_table_pos = IdentifierSemantic::getMembership(right);
|
||||
if (!right_table_pos)
|
||||
right_table_pos = IdentifierSemantic::chooseTableColumnMatch(right, tables);
|
||||
|
||||
if (left_table_pos && right_table_pos && (*left_table_pos != *right_table_pos))
|
||||
{
|
||||
size_t table_pos = std::max(*left_table_pos, *right_table_pos);
|
||||
if (joined_tables[table_pos].canAttachOnExpression())
|
||||
return table_pos;
|
||||
}
|
||||
return 0;
|
||||
IdentifiersCollector::Data ident_data;
|
||||
ConstInDepthNodeVisitor<IdentifiersCollector, true> ident_visitor(ident_data);
|
||||
ident_visitor.visit(node);
|
||||
return ident_data.idents;
|
||||
}
|
||||
};
|
||||
|
||||
using CheckExpressionMatcher = ConstOneTypeMatcher<CheckExpressionVisitorData, NeedChild::none>;
|
||||
using CheckExpressionVisitor = ConstInDepthNodeVisitor<CheckExpressionMatcher, true>;
|
||||
/// Split expression `expr_1 AND expr_2 AND ... AND expr_n` into vector `[expr_1, expr_2, ..., expr_n]`
|
||||
void collectConjunctions(const ASTPtr & node, std::vector<ASTPtr> & members)
|
||||
{
|
||||
if (const auto * func = node->as<ASTFunction>(); func && func->name == NameAnd::name)
|
||||
{
|
||||
for (const auto & child : func->arguments->children)
|
||||
collectConjunctions(child, members);
|
||||
return;
|
||||
}
|
||||
members.push_back(node);
|
||||
}
|
||||
|
||||
std::optional<size_t> getIdentMembership(const ASTIdentifier & ident, const std::vector<TableWithColumnNamesAndTypes> & tables)
|
||||
{
|
||||
std::optional<size_t> table_pos = IdentifierSemantic::getMembership(ident);
|
||||
if (table_pos)
|
||||
return table_pos;
|
||||
return IdentifierSemantic::chooseTableColumnMatch(ident, tables);
|
||||
}
|
||||
|
||||
std::optional<size_t> getIdentsMembership(const ASTPtr ast,
|
||||
const std::vector<TableWithColumnNamesAndTypes> & tables,
|
||||
const Aliases & aliases)
|
||||
{
|
||||
auto idents = IdentifiersCollector::collect(ast);
|
||||
|
||||
std::optional<size_t> result;
|
||||
for (const auto * ident : idents)
|
||||
{
|
||||
/// Moving expressions that use column aliases is not supported.
|
||||
if (ident->isShort() && aliases.count(ident->shortName()))
|
||||
return {};
|
||||
const auto pos = getIdentMembership(*ident, tables);
|
||||
if (!pos)
|
||||
return {};
|
||||
if (result && *pos != *result)
|
||||
return {};
|
||||
result = pos;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
bool isAllowedToRewriteCrossJoin(const ASTPtr & node, const Aliases & aliases)
|
||||
{
|
||||
if (const auto * func = node->as<ASTFunction>())
|
||||
{
|
||||
auto idents = IdentifiersCollector::collect(node);
|
||||
for (const auto * ident : idents)
|
||||
{
|
||||
if (ident->isShort() && aliases.count(ident->shortName()))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return node->as<ASTIdentifier>() || node->as<ASTLiteral>();
|
||||
}
|
||||
|
||||
bool canMoveExpressionToJoinOn(const ASTPtr & ast,
|
||||
const std::vector<JoinedElement> & joined_tables,
|
||||
const std::vector<TableWithColumnNamesAndTypes> & tables,
|
||||
const Aliases & aliases,
|
||||
std::map<size_t, std::vector<ASTPtr>> & asts_to_join_on)
|
||||
{
|
||||
std::vector<ASTPtr> conjuncts;
|
||||
collectConjunctions(ast, conjuncts);
|
||||
for (const auto & node : conjuncts)
|
||||
{
|
||||
if (const auto * func = node->as<ASTFunction>(); func && func->name == NameEquals::name)
|
||||
{
|
||||
if (!func->arguments || func->arguments->children.size() != 2)
|
||||
return false;
|
||||
|
||||
/// Check if the identifiers are from different joined tables.
|
||||
/// If it's a self joint, tables should have aliases.
|
||||
auto left_table_pos = getIdentsMembership(func->arguments->children[0], tables, aliases);
|
||||
auto right_table_pos = getIdentsMembership(func->arguments->children[1], tables, aliases);
|
||||
|
||||
/// Identifiers from different table move to JOIN ON
|
||||
if (left_table_pos && right_table_pos && *left_table_pos != *right_table_pos)
|
||||
{
|
||||
size_t table_pos = std::max(*left_table_pos, *right_table_pos);
|
||||
if (joined_tables[table_pos].canAttachOnExpression())
|
||||
asts_to_join_on[table_pos].push_back(node);
|
||||
else
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (!isAllowedToRewriteCrossJoin(node, aliases))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
ASTPtr makeOnExpression(const std::vector<ASTPtr> & expressions)
|
||||
{
|
||||
if (expressions.size() == 1)
|
||||
return expressions[0]->clone();
|
||||
|
||||
std::vector<ASTPtr> arguments;
|
||||
arguments.reserve(expressions.size());
|
||||
for (const auto & ast : expressions)
|
||||
arguments.emplace_back(ast->clone());
|
||||
|
||||
return makeASTFunction(NameAnd::name, std::move(arguments));
|
||||
}
|
||||
|
||||
bool getTables(ASTSelectQuery & select, std::vector<JoinedElement> & joined_tables, size_t & num_comma)
|
||||
{
|
||||
@ -342,18 +329,19 @@ void CrossToInnerJoinMatcher::visit(ASTSelectQuery & select, ASTPtr &, Data & da
|
||||
if (!select.where())
|
||||
return;
|
||||
|
||||
CheckExpressionVisitor::Data visitor_data{joined_tables, data.tables_with_columns, data.aliases};
|
||||
CheckExpressionVisitor(visitor_data).visit(select.where());
|
||||
|
||||
if (visitor_data.complex())
|
||||
return;
|
||||
|
||||
for (size_t i = 1; i < joined_tables.size(); ++i)
|
||||
std::map<size_t, std::vector<ASTPtr>> asts_to_join_on;
|
||||
bool can_move_where = canMoveExpressionToJoinOn(
|
||||
select.where(), joined_tables, data.tables_with_columns, data.aliases, asts_to_join_on);
|
||||
if (can_move_where)
|
||||
{
|
||||
if (visitor_data.matchAny(i))
|
||||
for (size_t i = 1; i < joined_tables.size(); ++i)
|
||||
{
|
||||
if (joined_tables[i].rewriteCrossToInner(visitor_data.makeOnExpression(i)))
|
||||
data.done = true;
|
||||
const auto & expr_it = asts_to_join_on.find(i);
|
||||
if (expr_it != asts_to_join_on.end())
|
||||
{
|
||||
if (joined_tables[i].rewriteCrossToInner(makeOnExpression(expr_it->second)))
|
||||
data.done = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -19,3 +19,10 @@ SELECT
|
||||
FROM n
|
||||
ALL INNER JOIN r ON k = r.k
|
||||
WHERE (k = r.k) AND (name NOT LIKE \'A%\')
|
||||
SELECT
|
||||
k,
|
||||
r.k,
|
||||
name
|
||||
FROM n
|
||||
ALL INNER JOIN r ON (k + 1) = (r.k + 1)
|
||||
WHERE ((k + 1) = (r.k + 1)) AND ((name = \'A\') OR (name = \'AA\'))
|
||||
|
@ -9,6 +9,7 @@ SET enable_optimize_predicate_expression = 0;
|
||||
EXPLAIN SYNTAX SELECT * FROM n, r WHERE n.k = r.k AND r.name = 'A';
|
||||
EXPLAIN SYNTAX SELECT * FROM n, r WHERE n.k = r.k AND r.name LIKE 'A%';
|
||||
EXPLAIN SYNTAX SELECT * FROM n, r WHERE n.k = r.k AND r.name NOT LIKE 'A%';
|
||||
EXPLAIN SYNTAX SELECT * FROM n, r WHERE n.k + 1 = r.k + 1 AND (r.name = 'A' OR r.name = 'AA');
|
||||
|
||||
DROP TABLE n;
|
||||
DROP TABLE r;
|
||||
|
Loading…
Reference in New Issue
Block a user