Analyzers: added support for lambda expressions when searching for columns [#METR-23947].

This commit is contained in:
Alexey Milovidov 2017-01-12 05:06:50 +03:00
parent 916552f175
commit 522e96d4bf
5 changed files with 136 additions and 18 deletions

View File

@ -1,3 +1,4 @@
#include <vector>
#include <DB/Analyzers/AnalyzeColumns.h>
#include <DB/Analyzers/CollectAliases.h>
#include <DB/Parsers/formatAST.h>
@ -23,11 +24,15 @@ namespace ErrorCodes
extern const int AMBIGUOUS_COLUMN_NAME;
extern const int UNKNOWN_TABLE;
extern const int THERE_IS_NO_COLUMN;
extern const int BAD_LAMBDA;
}
namespace
{
/// Find by fully qualified name, like db.table.column
static const CollectTables::TableInfo * findTableByDatabaseAndTableName(
const CollectTables::TableInfo * findTableByDatabaseAndTableName(
const CollectTables & tables, const String & database_name, const String & table_name)
{
for (const auto & table : tables.tables)
@ -56,7 +61,7 @@ static const CollectTables::TableInfo * findTableByDatabaseAndTableName(
* If there is no primary matches and many secondary matches - ambiguity.
* If there is no any matches - not found.
*/
static const CollectTables::TableInfo * findTableByNameOrAlias(
const CollectTables::TableInfo * findTableByNameOrAlias(
const CollectTables & tables, const String & name)
{
const CollectTables::TableInfo * primary_match = nullptr;
@ -92,7 +97,7 @@ static const CollectTables::TableInfo * findTableByNameOrAlias(
* Select a table, where specified column exists.
* If more than one such table - ambiguity.
*/
static const CollectTables::TableInfo * findTableWithUnqualifiedName(const CollectTables & tables, const String & column_name)
const CollectTables::TableInfo * findTableWithUnqualifiedName(const CollectTables & tables, const String & column_name)
{
const CollectTables::TableInfo * res = nullptr;
@ -126,7 +131,7 @@ static const CollectTables::TableInfo * findTableWithUnqualifiedName(const Colle
/// Create maximum-qualified identifier for column in table.
static ASTPtr createASTIdentifierForColumnInTable(const String & column, const CollectTables::TableInfo & table)
ASTPtr createASTIdentifierForColumnInTable(const String & column, const CollectTables::TableInfo & table)
{
ASTPtr database_name_identifier_node;
if (!table.database_name.empty())
@ -166,7 +171,7 @@ static ASTPtr createASTIdentifierForColumnInTable(const String & column, const C
}
static void createASTsForAllColumnsInTable(const CollectTables::TableInfo & table, ASTs & res)
void createASTsForAllColumnsInTable(const CollectTables::TableInfo & table, ASTs & res)
{
if (table.storage)
for (const auto & name : table.storage->getColumnNamesList())
@ -177,7 +182,7 @@ static void createASTsForAllColumnsInTable(const CollectTables::TableInfo & tabl
}
static ASTs expandUnqualifiedAsterisk(
ASTs expandUnqualifiedAsterisk(
AnalyzeColumns::Columns & columns, const CollectAliases & aliases, const CollectTables & tables)
{
ASTs res;
@ -187,7 +192,7 @@ static ASTs expandUnqualifiedAsterisk(
}
static ASTs expandQualifiedAsterisk(
ASTs expandQualifiedAsterisk(
const IAST & ast, AnalyzeColumns::Columns & columns, const CollectAliases & aliases, const CollectTables & tables)
{
if (ast.children.size() != 1)
@ -217,8 +222,17 @@ static ASTs expandQualifiedAsterisk(
}
static void processIdentifier(
const ASTPtr & ast, AnalyzeColumns::Columns & columns, const CollectAliases & aliases, const CollectTables & tables)
/// Parameters of lambda expressions.
using LambdaParameters = std::vector<String>;
/// Currently visible parameters in all scopes of lambda expressions.
/// Lambda expressions could be nested: arrayMap(x -> arrayMap(y -> x[y], x), [[1], [2, 3]])
using LambdaScopes = std::vector<LambdaParameters>;
void processIdentifier(
const ASTPtr & ast, AnalyzeColumns::Columns & columns, const CollectAliases & aliases, const CollectTables & tables,
const LambdaScopes & lambda_scopes)
{
const ASTIdentifier & identifier = static_cast<const ASTIdentifier &>(*ast);
@ -233,6 +247,15 @@ static void processIdentifier(
if (identifier.children.empty())
{
/** Lambda parameters are not columns from table. Just skip them.
* If identifier name are known as lambda parameter in any currently visible scope of lambda expressions.
*/
if (lambda_scopes.end() != std::find_if(lambda_scopes.begin(), lambda_scopes.end(),
[&identifier] (const LambdaParameters & names) { return names.end() != std::find(names.begin(), names.end(), identifier.name); }))
{
return;
}
table = findTableWithUnqualifiedName(tables, identifier.name);
if (table)
column_name = identifier.name;
@ -315,7 +338,67 @@ static void processIdentifier(
}
static void processImpl(ASTPtr & ast, AnalyzeColumns::Columns & columns, const CollectAliases & aliases, const CollectTables & tables)
LambdaParameters extractLambdaParameters(ASTPtr & ast)
{
/// Lambda parameters could be specified in AST in two forms:
/// - just as single parameter: x -> x + 1
/// - parameters in tuple: (x, y) -> x + 1
#define LAMBDA_ERROR_MESSAGE " There are two valid forms of lambda expressions: x -> ... and (x, y...) -> ..."
if (!ast->tryGetAlias().empty())
throw Exception("Lambda parameters cannot have aliases."
LAMBDA_ERROR_MESSAGE, ErrorCodes::BAD_LAMBDA);
if (const ASTIdentifier * identifier = typeid_cast<const ASTIdentifier *>(ast.get()))
{
return { identifier->name };
}
else if (const ASTFunction * function = typeid_cast<const ASTFunction *>(ast.get()))
{
if (function->name != "tuple")
throw Exception("Left hand side of '->' or first argument of 'lambda' is a function, but this function is not tuple."
LAMBDA_ERROR_MESSAGE " Found function '" + function->name + "' instead.", ErrorCodes::BAD_LAMBDA);
if (!function->arguments || function->arguments->children.empty())
throw Exception("Left hand side of '->' or first argument of 'lambda' is empty tuple."
LAMBDA_ERROR_MESSAGE, ErrorCodes::BAD_LAMBDA);
LambdaParameters res;
res.reserve(function->arguments->children.size());
for (const ASTPtr & arg : function->arguments->children)
{
const ASTIdentifier * arg_identifier = typeid_cast<const ASTIdentifier *>(arg.get());
if (!arg_identifier)
throw Exception("Left hand side of '->' or first argument of 'lambda' contains something that is not just identifier."
LAMBDA_ERROR_MESSAGE, ErrorCodes::BAD_LAMBDA);
if (!arg_identifier->children.empty())
throw Exception("Left hand side of '->' or first argument of 'lambda' contains compound identifier."
LAMBDA_ERROR_MESSAGE, ErrorCodes::BAD_LAMBDA);
if (!arg_identifier->alias.empty())
throw Exception("Lambda parameters cannot have aliases."
LAMBDA_ERROR_MESSAGE, ErrorCodes::BAD_LAMBDA);
res.emplace_back(arg_identifier->name);
}
return res;
}
else
throw Exception("Unexpected left hand side of '->' or first argument of 'lambda'."
LAMBDA_ERROR_MESSAGE, ErrorCodes::BAD_LAMBDA);
#undef LAMBDA_ERROR_MESSAGE
}
void processImpl(ASTPtr & ast, AnalyzeColumns::Columns & columns, const CollectAliases & aliases, const CollectTables & tables,
LambdaScopes & lambda_scopes)
{
/// Don't go into subqueries and table-like expressions.
if (typeid_cast<const ASTSelectQuery *>(ast.get())
@ -333,6 +416,25 @@ static void processImpl(ASTPtr & ast, AnalyzeColumns::Columns & columns, const C
{
func->arguments->children.clear();
}
/** Special case for lambda functions, like (x, y) -> x + y + column.
* We must memoize parameters from left hand side (x, y)
* and then analyze right hand side, skipping that parameters.
* In example, from right hand side "x + y + column", only "column" should be searched in tables,
* because x and y are just lambda parameters.
*/
if (func->name == "lambda")
{
auto num_arguments = func->arguments->children.size();
if (num_arguments != 2)
throw Exception("Lambda expression ('->' or 'lambda' function) must have exactly two arguments."
" Found " + toString(num_arguments) + " instead.", ErrorCodes::BAD_LAMBDA);
lambda_scopes.emplace_back(extractLambdaParameters(func->arguments->children[0]));
processImpl(func->arguments->children[1], columns, aliases, tables, lambda_scopes);
lambda_scopes.pop_back();
return;
}
}
else if (typeid_cast<ASTExpressionList *>(ast.get()))
{
@ -356,21 +458,22 @@ static void processImpl(ASTPtr & ast, AnalyzeColumns::Columns & columns, const C
}
else if (typeid_cast<const ASTIdentifier *>(ast.get()))
{
processIdentifier(ast, columns, aliases, tables);
processIdentifier(ast, columns, aliases, tables, lambda_scopes);
return;
}
/// TODO Skip parameters of lambda functions.
for (auto & child : ast->children)
processImpl(child, columns, aliases, tables);
processImpl(child, columns, aliases, tables, lambda_scopes);
}
}
void AnalyzeColumns::process(ASTPtr & ast, const CollectAliases & aliases, const CollectTables & tables)
{
LambdaScopes lambda_scopes;
for (auto & child : ast->children)
processImpl(child, columns, aliases, tables);
processImpl(child, columns, aliases, tables, lambda_scopes);
}

View File

@ -50,9 +50,9 @@ try
analyze_columns.dump(out);
out.next();
std::cerr << "\n";
std::cout << "\n";
formatAST(*ast, std::cout, 0, false, true);
std::cerr << "\n";
std::cout << "\n";
return 0;
}

View File

@ -5,4 +5,14 @@ one.dummy -> dummy UInt8. Database name: system. Table name: one. Alias: (none).
system.numbers.number -> number UInt64. Database name: system. Table name: numbers. Alias: t. Storage: SystemNumbers. AST: system.numbers.number
system.one.dummy -> dummy UInt8. Database name: system. Table name: one. Alias: (none). Storage: SystemOne. AST: system.one.dummy
t.number -> number UInt64. Database name: system. Table name: numbers. Alias: t. Storage: SystemNumbers. AST: t.number
SELECT dummy, number, one.dummy, numbers.number, system.one.dummy, system.numbers.number, system.one.dummy, system.numbers.number, system.one.dummy, system.numbers.number, system.one.dummy, system.numbers.number, system.numbers.number, t.number FROM system.one , system.numbers AS t
SELECT dummy, number, one.dummy, numbers.number, system.one.dummy, system.numbers.number, system.one.dummy, system.numbers.number, system.one.dummy, system.numbers.number, system.one.dummy, system.numbers.number, system.numbers.number, t.number FROM system.one , system.numbers AS t
c -> c UInt8. Database name: (none). Table name: (none). Alias: (none). Storage: (none). AST: c
SELECT arrayMap((x, y) -> arrayMap((y, z) -> x[y], x, c), [[1], [2, 3]]) FROM (SELECT 1 AS c, 2 AS d)
c -> c UInt8. Database name: (none). Table name: (none). Alias: (none). Storage: (none). AST: c
x -> x UInt8. Database name: (none). Table name: (none). Alias: (none). Storage: (none). AST: x
SELECT x, arrayMap((x, y) -> (x + y), x, c) FROM (SELECT 1 AS x, 2 AS c)

View File

@ -1,3 +1,7 @@
#!/bin/sh
echo "SELECT dummy, number, one.dummy, numbers.number, system.one.dummy, system.numbers.number, one.*, numbers.*, system.one.*, system.numbers.*, *, t.*, t.number FROM system.one, system.numbers AS t" | ./analyze_columns
echo
echo "SELECT arrayMap((x, y) -> arrayMap((y, z) -> x[y], x, c), [[1], [2, 3]]) FROM (SELECT 1 AS c, 2 AS d)" | ./analyze_columns
echo
echo "SELECT x, arrayMap((x, y) -> x + y, x, c) FROM (SELECT 1 AS x, 2 AS c)" | ./analyze_columns

View File

@ -358,6 +358,7 @@ namespace ErrorCodes
extern const int INDEX_OF_POSITIONAL_ARGUMENT_IS_OUT_OF_RANGE = 353;
extern const int ZLIB_INFLATE_FAILED = 354;
extern const int ZLIB_DEFLATE_FAILED = 355;
extern const int BAD_LAMBDA = 356;
extern const int KEEPER_EXCEPTION = 999;
extern const int POCO_EXCEPTION = 1000;