From 8b4fabe60ce6aa3c5e62c2bb799ff76a36a71181 Mon Sep 17 00:00:00 2001 From: Vitaly Baranov Date: Fri, 2 Jul 2021 14:20:41 +0300 Subject: [PATCH] Fix crash on call dictGet() with bad arguments. --- .../MarkTableIdentifiersVisitor.cpp | 43 +++++++++++-------- .../MarkTableIdentifiersVisitor.h | 2 +- ...arts_identifiers_in_wrong_places.reference | 1 + ...hree_parts_identifiers_in_wrong_places.sql | 7 +++ 4 files changed, 33 insertions(+), 20 deletions(-) create mode 100644 tests/queries/0_stateless/01936_three_parts_identifiers_in_wrong_places.reference create mode 100644 tests/queries/0_stateless/01936_three_parts_identifiers_in_wrong_places.sql diff --git a/src/Interpreters/MarkTableIdentifiersVisitor.cpp b/src/Interpreters/MarkTableIdentifiersVisitor.cpp index 52f180aa199..1f418e759e7 100644 --- a/src/Interpreters/MarkTableIdentifiersVisitor.cpp +++ b/src/Interpreters/MarkTableIdentifiersVisitor.cpp @@ -11,6 +11,26 @@ namespace DB { +namespace +{ + void replaceArgumentWithTableIdentifierIfNotAlias(ASTFunction & func, size_t argument_pos, const Aliases & aliases) + { + if (!func.arguments || (func.arguments->children.size() <= argument_pos)) + return; + auto arg = func.arguments->children[argument_pos]; + auto identifier = arg->as(); + if (!identifier) + return; + if (aliases.contains(identifier->name())) + return; + auto table_identifier = identifier->createTable(); + if (!table_identifier) + return; + func.arguments->children[argument_pos] = table_identifier; + } +} + + bool MarkTableIdentifiersMatcher::needChildVisit(ASTPtr & node, const ASTPtr & child) { if (child->as()) @@ -23,37 +43,22 @@ bool MarkTableIdentifiersMatcher::needChildVisit(ASTPtr & node, const ASTPtr & c void MarkTableIdentifiersMatcher::visit(ASTPtr & ast, Data & data) { if (auto * node_func = ast->as()) - visit(*node_func, ast, data); + visit(*node_func, data); } -void MarkTableIdentifiersMatcher::visit(const ASTFunction & func, ASTPtr & ptr, Data & data) +void MarkTableIdentifiersMatcher::visit(ASTFunction & func, const Data & data) { /// `IN t` can be specified, where t is a table, which is equivalent to `IN (SELECT * FROM t)`. if (checkFunctionIsInOrGlobalInOperator(func)) { - auto ast = func.arguments->children.at(1); - auto opt_name = tryGetIdentifierName(ast); - if (opt_name && !data.aliases.count(*opt_name) && ast->as()) - { - ptr->as()->arguments->children[1] = ast->as()->createTable(); - assert(ptr->as()->arguments->children[1]); - } + replaceArgumentWithTableIdentifierIfNotAlias(func, 1, data.aliases); } // First argument of joinGet can be a table name, perhaps with a database. // First argument of dictGet can be a dictionary name, perhaps with a database. else if (functionIsJoinGet(func.name) || functionIsDictGet(func.name)) { - if (!func.arguments || func.arguments->children.empty()) - return; - - auto ast = func.arguments->children.at(0); - auto opt_name = tryGetIdentifierName(ast); - if (opt_name && !data.aliases.count(*opt_name) && ast->as()) - { - ptr->as()->arguments->children[0] = ast->as()->createTable(); - assert(ptr->as()->arguments->children[0]); - } + replaceArgumentWithTableIdentifierIfNotAlias(func, 0, data.aliases); } } diff --git a/src/Interpreters/MarkTableIdentifiersVisitor.h b/src/Interpreters/MarkTableIdentifiersVisitor.h index 0d80b865e53..d05c067397b 100644 --- a/src/Interpreters/MarkTableIdentifiersVisitor.h +++ b/src/Interpreters/MarkTableIdentifiersVisitor.h @@ -24,7 +24,7 @@ public: static void visit(ASTPtr & ast, Data & data); private: - static void visit(const ASTFunction & func, ASTPtr &, Data &); + static void visit(ASTFunction & func, const Data & data); }; using MarkTableIdentifiersVisitor = MarkTableIdentifiersMatcher::Visitor; diff --git a/tests/queries/0_stateless/01936_three_parts_identifiers_in_wrong_places.reference b/tests/queries/0_stateless/01936_three_parts_identifiers_in_wrong_places.reference new file mode 100644 index 00000000000..bbf76e61257 --- /dev/null +++ b/tests/queries/0_stateless/01936_three_parts_identifiers_in_wrong_places.reference @@ -0,0 +1 @@ +still alive diff --git a/tests/queries/0_stateless/01936_three_parts_identifiers_in_wrong_places.sql b/tests/queries/0_stateless/01936_three_parts_identifiers_in_wrong_places.sql new file mode 100644 index 00000000000..d2ca771edc5 --- /dev/null +++ b/tests/queries/0_stateless/01936_three_parts_identifiers_in_wrong_places.sql @@ -0,0 +1,7 @@ +SELECT dictGet(t.nest.a, concat(currentDatabase(), '.dict.dict'), 's', number) FROM numbers(5); -- { serverError 47 } + +SELECT dictGetFloat64(t.b.s, 'database_for_dict.dict1', dictGetFloat64('Ta\0', toUInt64('databas\0_for_dict.dict1databas\0_for_dict.dict1', dictGetFloat64('', '', toUInt64(1048577), toDate(NULL)), NULL), toDate(dictGetFloat64(257, 'database_for_dict.dict1database_for_dict.dict1', '', toUInt64(NULL), 2, toDate(NULL)), '2019-05-2\0')), NULL, toUInt64(dictGetFloat64('', '', toUInt64(-9223372036854775808), toDate(NULL)), NULL)); -- { serverError 47 } + +SELECT NULL AND (2147483648 AND NULL) AND -2147483647, toUUID(((1048576 AND NULL) AND (2147483647 AND 257 AND NULL AND -2147483649) AND NULL) IN (test_01103.t1_distr.id), '00000000-e1fe-11e\0-bb8f\0853d60c00749'), stringToH3('89184926cc3ffff89184926cc3ffff89184926cc3ffff89184926cc3ffff89184926cc3ffff89184926cc3ffff89184926cc3ffff89184926cc3ffff'); -- { serverError 47 } + +SELECT 'still alive';