Fix joined table access with Merge engine and aggregation

This commit is contained in:
vdimir 2021-03-02 15:28:09 +03:00
parent 8f23d39f26
commit 5476e68d6c
No known key found for this signature in database
GPG Key ID: F57B3E10A21DBB31
7 changed files with 183 additions and 36 deletions

View File

@ -249,4 +249,65 @@ void IdentifierSemantic::setColumnLongName(ASTIdentifier & identifier, const Dat
} }
} }
/// Traverse the AST rooted at `node` with ConstInDepthNodeVisitor and return
/// pointers to all ASTIdentifier nodes that the traversal recorded.
/// The returned pointers are non-owning: they refer to nodes of the passed tree.
IdentifiersCollector::ASTIdentifiers IdentifiersCollector::collect(const ASTPtr & node)
{
    using Visitor = ConstInDepthNodeVisitor<IdentifiersCollector, true>;

    Data collected;
    Visitor walker(collected);
    walker.visit(node);
    return collected.idents;
}
/// Visitor hook: unconditionally allow descending into children, whatever the arguments are.
bool IdentifiersCollector::needChildVisit(const ASTPtr &, const ASTPtr &)
{
    constexpr bool descend_into_children = true;
    return descend_into_children;
}
void IdentifiersCollector::visit(const ASTPtr & node, IdentifiersCollector::Data & data)
{
if (const auto * ident = node->as<ASTIdentifier>())
data.idents.push_back(ident);
}
/// Prepare the data needed to answer membership queries for `select`:
///  - `aliases` is filled by QueryAliasesNoSubqueriesVisitor from the WITH clause (when present)
///    and from the SELECT list;
///  - `tables` is filled with the result of getDatabaseAndTablesWithColumns for the query's table expressions.
IdentifierMembershipCollector::IdentifierMembershipCollector(const ASTSelectQuery & select, const Context & context)
{
    /// WITH is optional, so only visit it when the query actually has one.
    ASTPtr with_clause = select.with();
    if (with_clause)
        QueryAliasesNoSubqueriesVisitor(aliases).visit(with_clause);

    QueryAliasesNoSubqueriesVisitor(aliases).visit(select.select());

    const auto table_expressions = getTableExpressions(select);
    tables = getDatabaseAndTablesWithColumns(table_expressions, context);
}
/// Find the single table position shared by all identifiers inside `ast`.
///
/// Returns an empty optional when:
///  - a short identifier has the same name as an entry in `aliases` (ambiguous case);
///  - the table of some identifier cannot be determined;
///  - two identifiers resolve to different table positions;
///  - `ast` contains no identifiers at all.
std::optional<size_t> IdentifierMembershipCollector::getIdentsMembership(
    const ASTPtr ast, const std::vector<TableWithColumnNamesAndTypes> & tables, const Aliases & aliases)
{
    std::optional<size_t> common_table;

    for (const auto * identifier : IdentifiersCollector::collect(ast))
    {
        /// short name clashes with alias, ambiguous case
        const bool clashes_with_alias = identifier->isShort() && aliases.count(identifier->shortName()) != 0;
        if (clashes_with_alias)
            return std::nullopt;

        const auto table_pos = getIdentMembership(*identifier, tables);
        if (!table_pos.has_value())
            return std::nullopt;

        /// identifiers from different tables
        if (common_table.has_value() && *common_table != *table_pos)
            return std::nullopt;

        common_table = table_pos;
    }

    return common_table;
}
/// Table position for a single identifier.
/// The value reported by IdentifierSemantic::getMembership wins when it is set;
/// otherwise the answer of IdentifierSemantic::chooseTableColumnMatch over `tables` is returned.
std::optional<size_t>
IdentifierMembershipCollector::getIdentMembership(const ASTIdentifier & ident, const std::vector<TableWithColumnNamesAndTypes> & tables)
{
    if (std::optional<size_t> known_pos = IdentifierSemantic::getMembership(ident); known_pos.has_value())
        return known_pos;

    return IdentifierSemantic::chooseTableColumnMatch(ident, tables);
}
} }

View File

@ -2,8 +2,15 @@
#include <optional> #include <optional>
#include <Parsers/ASTIdentifier.h> #include <Interpreters/Aliases.h>
#include <Interpreters/DatabaseAndTableWithAlias.h> #include <Interpreters/DatabaseAndTableWithAlias.h>
#include <Interpreters/InDepthNodeVisitor.h>
#include <Interpreters/QueryAliasesVisitor.h>
#include <Interpreters/getHeaderForProcessingStage.h>
#include <Interpreters/getTableExpressions.h>
#include <Parsers/ASTIdentifier.h>
#include <Parsers/ASTSelectQuery.h>
namespace DB namespace DB
{ {
@ -64,4 +71,43 @@ private:
static bool doesIdentifierBelongTo(const ASTIdentifier & identifier, const String & table); static bool doesIdentifierBelongTo(const ASTIdentifier & identifier, const String & table);
}; };
/// Collect all identifiers from AST recursively.
/// `visit` and `needChildVisit` together with `Data` form the matcher interface
/// that `collect` plugs into ConstInDepthNodeVisitor.
class IdentifiersCollector
{
public:
    using ASTIdentPtr = const ASTIdentifier *;
    using ASTIdentifiers = std::vector<ASTIdentPtr>;

    /// Accumulator passed through the traversal.
    struct Data
    {
        /// Non-owning pointers to the identifiers found so far.
        ASTIdentifiers idents;
    };

    /// Run the traversal over `node` and return everything that was found.
    static ASTIdentifiers collect(const ASTPtr & node);

    /// Always returns true, i.e. every child is visited.
    static bool needChildVisit(const ASTPtr &, const ASTPtr &);

    /// Record `node` in `data` if it is an ASTIdentifier.
    static void visit(const ASTPtr & node, Data & data);
};
/// Collect identifier table membership considering aliases.
/// An instance keeps the tables (with their columns) and the aliases of one SELECT query,
/// so that several expressions of that query can be checked against the same data.
class IdentifierMembershipCollector
{
public:
    /// Gathers aliases and tables with columns for `select`.
    IdentifierMembershipCollector(const ASTSelectQuery & select, const Context & context);

    /// Same as the static overload below, using the tables and aliases stored in this object.
    /// The pointer is taken by const reference: this wrapper only forwards it,
    /// so copying the shared pointer here would be a needless reference-count update.
    std::optional<size_t> getIdentsMembership(const ASTPtr & ast) const
    {
        return getIdentsMembership(ast, tables, aliases);
    }

    /// Collect common table membership for identifiers in expression.
    /// If membership cannot be established or there are several identifiers from different tables, return empty optional.
    static std::optional<size_t>
    getIdentsMembership(const ASTPtr ast, const std::vector<TableWithColumnNamesAndTypes> & tables, const Aliases & aliases);

private:
    std::vector<TableWithColumnNamesAndTypes> tables;
    Aliases aliases;

    /// Table position of a single identifier, or empty optional if it cannot be determined.
    static std::optional<size_t> getIdentMembership(const ASTIdentifier & ident, const std::vector<TableWithColumnNamesAndTypes> & tables);
};
} }

View File

@ -1,8 +1,9 @@
#include <Interpreters/getHeaderForProcessingStage.h>
#include <Interpreters/InterpreterSelectQuery.h>
#include <Storages/IStorage.h>
#include <DataStreams/OneBlockInputStream.h> #include <DataStreams/OneBlockInputStream.h>
#include <Interpreters/IdentifierSemantic.h>
#include <Interpreters/InterpreterSelectQuery.h>
#include <Interpreters/getHeaderForProcessingStage.h>
#include <Parsers/ASTTablesInSelectQuery.h> #include <Parsers/ASTTablesInSelectQuery.h>
#include <Storages/IStorage.h>
namespace DB namespace DB
{ {
@ -13,7 +14,7 @@ namespace ErrorCodes
} }
/// Rewrite original query removing joined tables from it /// Rewrite original query removing joined tables from it
bool removeJoin(ASTSelectQuery & select) bool removeJoin(ASTSelectQuery & select, const IdentifierMembershipCollector & membership_collector)
{ {
const auto & tables = select.tables(); const auto & tables = select.tables();
if (!tables || tables->children.size() < 2) if (!tables || tables->children.size() < 2)
@ -23,8 +24,22 @@ bool removeJoin(ASTSelectQuery & select)
if (!joined_table.table_join) if (!joined_table.table_join)
return false; return false;
/// We need to remove joined columns and related functions (taking into account aliases if any).
auto * select_list = select.select()->as<ASTExpressionList>();
if (select_list)
{
ASTs new_children;
for (const auto & elem : select_list->children)
{
auto table_no = membership_collector.getIdentsMembership(elem);
if (!table_no.has_value() || *table_no < 1)
new_children.push_back(elem);
}
select_list->children = std::move(new_children);
}
/// The most simple temporary solution: leave only the first table in query. /// The most simple temporary solution: leave only the first table in query.
/// TODO: we also need to remove joined columns and related functions (taking in account aliases if any).
tables->children.resize(1); tables->children.resize(1);
return true; return true;
} }
@ -66,7 +81,8 @@ Block getHeaderForProcessingStage(
case QueryProcessingStage::MAX: case QueryProcessingStage::MAX:
{ {
auto query = query_info.query->clone(); auto query = query_info.query->clone();
removeJoin(*query->as<ASTSelectQuery>()); auto & select = *query->as<ASTSelectQuery>();
removeJoin(select, IdentifierMembershipCollector{select, context});
auto stream = std::make_shared<OneBlockInputStream>( auto stream = std::make_shared<OneBlockInputStream>(
metadata_snapshot->getSampleBlockForColumns(column_names, storage.getVirtuals(), storage.getStorageID())); metadata_snapshot->getSampleBlockForColumns(column_names, storage.getVirtuals(), storage.getStorageID()));

View File

@ -13,8 +13,9 @@ using StorageMetadataPtr = std::shared_ptr<const StorageInMemoryMetadata>;
struct SelectQueryInfo; struct SelectQueryInfo;
class Context; class Context;
class ASTSelectQuery; class ASTSelectQuery;
class IdentifierMembershipCollector;
bool removeJoin(ASTSelectQuery & select); bool removeJoin(ASTSelectQuery & select, const IdentifierMembershipCollector & membership_collector);
Block getHeaderForProcessingStage( Block getHeaderForProcessingStage(
const IStorage & storage, const IStorage & storage,

View File

@ -1,31 +1,34 @@
#include <DataStreams/narrowBlockInputStreams.h>
#include <DataStreams/OneBlockInputStream.h>
#include <Storages/StorageMerge.h>
#include <Storages/StorageFactory.h>
#include <Storages/VirtualColumnUtils.h>
#include <Storages/AlterCommands.h>
#include <Interpreters/Context.h>
#include <Interpreters/TreeRewriter.h>
#include <Interpreters/ExpressionActions.h>
#include <Interpreters/evaluateConstantExpression.h>
#include <Interpreters/InterpreterSelectQuery.h>
#include <Interpreters/getHeaderForProcessingStage.h>
#include <Parsers/ASTSelectQuery.h>
#include <Parsers/ASTLiteral.h>
#include <Parsers/ASTIdentifier.h>
#include <Parsers/ASTExpressionList.h>
#include <DataTypes/DataTypeString.h>
#include <Columns/ColumnString.h>
#include <Common/typeid_cast.h>
#include <Common/checkStackSize.h>
#include <Databases/IDatabase.h>
#include <ext/range.h>
#include <algorithm> #include <algorithm>
#include <ext/range.h>
#include <Columns/ColumnString.h>
#include <Common/checkStackSize.h>
#include <Common/typeid_cast.h>
#include <DataStreams/OneBlockInputStream.h>
#include <DataStreams/narrowBlockInputStreams.h>
#include <DataTypes/DataTypeString.h>
#include <Databases/IDatabase.h>
#include <Interpreters/Context.h>
#include <Interpreters/ExpressionActions.h>
#include <Interpreters/IdentifierSemantic.h>
#include <Interpreters/InterpreterSelectQuery.h>
#include <Interpreters/TreeRewriter.h>
#include <Interpreters/evaluateConstantExpression.h>
#include <Interpreters/getHeaderForProcessingStage.h>
#include <Parsers/ASTExpressionList.h>
#include <Parsers/ASTIdentifier.h>
#include <Parsers/ASTLiteral.h>
#include <Parsers/ASTSelectQuery.h>
#include <Parsers/queryToString.h> #include <Parsers/queryToString.h>
#include <Processors/Transforms/MaterializingTransform.h>
#include <Processors/ConcatProcessor.h> #include <Processors/ConcatProcessor.h>
#include <Processors/Transforms/ExpressionTransform.h> #include <Processors/Transforms/ExpressionTransform.h>
#include <Processors/Transforms/MaterializingTransform.h>
#include <Storages/AlterCommands.h>
#include <Storages/StorageFactory.h>
#include <Storages/StorageMerge.h>
#include <Storages/VirtualColumnUtils.h>
#include <common/logger_useful.h>
namespace DB namespace DB
{ {
@ -43,9 +46,12 @@ namespace ErrorCodes
namespace namespace
{ {
void modifySelect(ASTSelectQuery & select, const TreeRewriterResult & rewriter_result) TreeRewriterResult modifySelect(ASTSelectQuery & select, const TreeRewriterResult & rewriter_result, const Context & context)
{ {
if (removeJoin(select)) IdentifierMembershipCollector membership_collector{select, context};
TreeRewriterResult new_rewriter_result = rewriter_result;
if (removeJoin(select, membership_collector))
{ {
/// Also remove GROUP BY cause ExpressionAnalyzer would check if it has all aggregate columns but joined columns would be missed. /// Also remove GROUP BY cause ExpressionAnalyzer would check if it has all aggregate columns but joined columns would be missed.
select.setExpression(ASTSelectQuery::Expression::GROUP_BY, {}); select.setExpression(ASTSelectQuery::Expression::GROUP_BY, {});
@ -62,7 +68,17 @@ void modifySelect(ASTSelectQuery & select, const TreeRewriterResult & rewriter_r
select.setExpression(ASTSelectQuery::Expression::PREWHERE, {}); select.setExpression(ASTSelectQuery::Expression::PREWHERE, {});
select.setExpression(ASTSelectQuery::Expression::HAVING, {}); select.setExpression(ASTSelectQuery::Expression::HAVING, {});
select.setExpression(ASTSelectQuery::Expression::ORDER_BY, {}); select.setExpression(ASTSelectQuery::Expression::ORDER_BY, {});
new_rewriter_result.aggregates.clear();
for (const auto & agg : rewriter_result.aggregates)
{
auto table_no = membership_collector.getIdentsMembership(std::make_shared<ASTFunction>(*agg));
if (!table_no.has_value() || *table_no < 1)
new_rewriter_result.aggregates.push_back(agg);
}
} }
return new_rewriter_result;
} }
} }
@ -159,7 +175,7 @@ QueryProcessingStage::Enum StorageMerge::getQueryProcessingStage(const Context &
/// (see modifySelect()/removeJoin()) /// (see modifySelect()/removeJoin())
/// ///
/// And for this we need to return FetchColumns. /// And for this we need to return FetchColumns.
if (removeJoin(modified_select)) if (removeJoin(modified_select, IdentifierMembershipCollector{modified_select, context}))
return QueryProcessingStage::FetchColumns; return QueryProcessingStage::FetchColumns;
auto stage_in_source_tables = QueryProcessingStage::FetchColumns; auto stage_in_source_tables = QueryProcessingStage::FetchColumns;
@ -303,8 +319,9 @@ Pipe StorageMerge::createSources(
modified_query_info.query = query_info.query->clone(); modified_query_info.query = query_info.query->clone();
/// Original query could contain JOIN but we need only the first joined table and its columns. /// Original query could contain JOIN but we need only the first joined table and its columns.
auto & modified_select = modified_query_info.query->as<ASTSelectQuery &>(); auto & modified_select = modified_query_info.query->as<ASTSelectQuery &>();\
modifySelect(modified_select, *query_info.syntax_analyzer_result); auto new_analyzer_res = modifySelect(modified_select, *query_info.syntax_analyzer_result, *modified_context);
modified_query_info.syntax_analyzer_result = std::make_shared<TreeRewriterResult>(std::move(new_analyzer_res));
VirtualColumnUtils::rewriteEntityInAst(modified_query_info.query, "_table", table_name); VirtualColumnUtils::rewriteEntityInAst(modified_query_info.query, "_table", table_name);

View File

@ -5,3 +5,4 @@
1 1
0 1 0 1
0 1 0 1
1

View File

@ -17,6 +17,11 @@ SELECT ID FROM m INNER JOIN b USING(key) GROUP BY ID;
SELECT * FROM m INNER JOIN b USING(key) WHERE ID = 1 HAVING ID = 1 ORDER BY ID; SELECT * FROM m INNER JOIN b USING(key) WHERE ID = 1 HAVING ID = 1 ORDER BY ID;
SELECT * FROM m INNER JOIN b USING(key) WHERE ID = 1 GROUP BY ID, key HAVING ID = 1 ORDER BY ID; SELECT * FROM m INNER JOIN b USING(key) WHERE ID = 1 GROUP BY ID, key HAVING ID = 1 ORDER BY ID;
SELECT sum(b.ID) FROM m FULL JOIN b ON (m.key == b.key) GROUP BY key;
-- still not working because columns from different tables are used under the same aggregate function
SELECT sum(b.ID + m.key) FROM m FULL JOIN b ON (m.key == b.key) GROUP BY key; -- { serverError 47 }
DROP TABLE IF EXISTS a; DROP TABLE IF EXISTS a;
DROP TABLE IF EXISTS b; DROP TABLE IF EXISTS b;
DROP TABLE IF EXISTS m; DROP TABLE IF EXISTS m;