fix count(*) with multiple_joins_rewriter_version = 1

This commit is contained in:
Artem Zuikov 2020-04-15 20:47:08 +03:00
parent e8cd92bba3
commit f08cdfcc4c
3 changed files with 56 additions and 2 deletions

View File

@ -9,6 +9,7 @@
#include <Parsers/ASTTablesInSelectQuery.h>
#include <Parsers/ASTIdentifier.h>
#include <Parsers/ASTExpressionList.h>
#include <Parsers/ASTFunction.h>
#include <Parsers/ParserTablesInSelectQuery.h>
#include <Parsers/ExpressionListParsers.h>
#include <Parsers/parseQuery.h>
@ -129,6 +130,8 @@ private:
/// Make aliases maps (alias -> column_name, column_name -> alias)
struct ColumnAliasesMatcher
{
using Visitor = ConstInDepthNodeVisitor<ColumnAliasesMatcher, true>;
struct Data
{
const std::vector<DatabaseAndTableWithAlias> tables;
@ -137,6 +140,7 @@ struct ColumnAliasesMatcher
std::unordered_map<String, String> aliases; /// alias -> long_name
std::vector<std::pair<ASTIdentifier *, bool>> compound_identifiers;
std::set<String> allowed_long_names; /// original names allowed as aliases '--t.x as t.x' (select expressions only).
bool inside_function = false;
explicit Data(const std::vector<DatabaseAndTableWithAlias> && tables_)
: tables(tables_)
@ -192,6 +196,10 @@ struct ColumnAliasesMatcher
static bool needChildVisit(const ASTPtr & node, const ASTPtr &)
{
/// Do not go into subqueries. Function visits children itself.
if (node->as<ASTSubquery>() ||
node->as<ASTFunction>())
return false;
return !node->as<ASTQualifiedAsterisk>();
}
@ -199,11 +207,24 @@ struct ColumnAliasesMatcher
{
if (auto * t = ast->as<ASTIdentifier>())
visit(*t, ast, data);
else if (auto * f = ast->as<ASTFunction>())
visit(*f, ast, data);
if (ast->as<ASTAsterisk>() || ast->as<ASTQualifiedAsterisk>())
/// Do not allow asterisks but ignore them inside functions. I.e. allow 'count(*)'.
if (!data.inside_function && (ast->as<ASTAsterisk>() || ast->as<ASTQualifiedAsterisk>()))
throw Exception("Multiple JOIN do not support asterisks for complex queries yet", ErrorCodes::NOT_IMPLEMENTED);
}
static void visit(const ASTFunction &, const ASTPtr & ast, Data & data)
{
/// Grandchild case: Function -> (ExpressionList) -> Asterisk
data.inside_function = true;
Visitor visitor(data);
for (auto & child : ast->children)
visitor.visit(child);
data.inside_function = false;
}
static void visit(const ASTIdentifier & const_node, const ASTPtr &, Data & data)
{
ASTIdentifier & node = const_cast<ASTIdentifier &>(const_node); /// we know it's not const
@ -348,7 +369,7 @@ bool needRewrite(ASTSelectQuery & select, std::vector<const ASTTableExpression *
using RewriteMatcher = OneTypeMatcher<RewriteTablesVisitorData>;
using RewriteVisitor = InDepthNodeVisitor<RewriteMatcher, true>;
using ExtractAsterisksVisitor = ConstInDepthNodeVisitor<ExtractAsterisksMatcher, true>;
using ColumnAliasesVisitor = ConstInDepthNodeVisitor<ColumnAliasesMatcher, true>;
using ColumnAliasesVisitor = ColumnAliasesMatcher::Visitor;
using AppendSemanticMatcher = OneTypeMatcher<AppendSemanticVisitorData>;
using AppendSemanticVisitor = InDepthNodeVisitor<AppendSemanticMatcher, true>;

View File

@ -0,0 +1,4 @@
2
1
2
1

View File

@ -0,0 +1,29 @@
SET multiple_joins_rewriter_version = 2;
SELECT count(*)
FROM numbers(2) AS n1, numbers(3) AS n2, numbers(4) AS n3
WHERE (n1.number = n2.number) AND (n2.number = n3.number);
SELECT count(*) c FROM (
SELECT count(*), count(*) as c
FROM numbers(2) AS n1, numbers(3) AS n2, numbers(4) AS n3
WHERE (n1.number = n2.number) AND (n2.number = n3.number)
AND (SELECT count(*) FROM numbers(1)) = 1
)
WHERE (SELECT count(*) FROM numbers(2)) = 2
HAVING c IN(SELECT count(*) c FROM numbers(1));
SET multiple_joins_rewriter_version = 1;
SELECT count(*)
FROM numbers(2) AS n1, numbers(3) AS n2, numbers(4) AS n3
WHERE (n1.number = n2.number) AND (n2.number = n3.number);
SELECT count(*) c FROM (
SELECT count(*), count(*) as c
FROM numbers(2) AS n1, numbers(3) AS n2, numbers(4) AS n3
WHERE (n1.number = n2.number) AND (n2.number = n3.number)
AND (SELECT count(*) FROM numbers(1)) = 1
)
WHERE (SELECT count(*) FROM numbers(2)) = 2
HAVING c IN(SELECT count(*) c FROM numbers(1));