ClickHouse/dbms/src/Interpreters/ExecuteScalarSubqueriesVisitor.cpp
Alexey Milovidov c2d4c11cb8 Fixes for #5331
2019-06-14 20:15:30 +03:00

148 lines
4.7 KiB
C++

#include <Parsers/ASTFunction.h>
#include <Parsers/ASTLiteral.h>
#include <Parsers/ASTSubquery.h>
#include <Parsers/ASTSelectQuery.h>
#include <Parsers/ASTTablesInSelectQuery.h>
#include <Parsers/ASTExpressionList.h>
#include <Interpreters/Context.h>
#include <Interpreters/QueryNormalizer.h>
#include <Interpreters/InterpreterSelectWithUnionQuery.h>
#include <Interpreters/ExecuteScalarSubqueriesVisitor.h>
#include <Interpreters/addTypeConversionToAST.h>
#include <DataStreams/IBlockInputStream.h>
#include <DataTypes/DataTypeAggregateFunction.h>
namespace DB
{
namespace ErrorCodes
{
extern const int INCORRECT_RESULT_OF_SCALAR_SUBQUERY;
extern const int TOO_MANY_ROWS;
}
bool ExecuteScalarSubqueriesMatcher::needChildVisit(ASTPtr & node, const ASTPtr & child)
{
/// Processed
if (node->as<ASTSubquery>() || node->as<ASTFunction>())
return false;
/// Don't descend into subqueries in FROM section
if (node->as<ASTTableExpression>())
return false;
if (node->as<ASTSelectQuery>())
{
/// Do not go to FROM, JOIN, UNION.
if (child->as<ASTTableExpression>() || child->as<ASTSelectQuery>())
return false;
}
return true;
}
void ExecuteScalarSubqueriesMatcher::visit(ASTPtr & ast, Data & data)
{
if (const auto * t = ast->as<ASTSubquery>())
visit(*t, ast, data);
if (const auto * t = ast->as<ASTFunction>())
visit(*t, ast, data);
}
void ExecuteScalarSubqueriesMatcher::visit(const ASTSubquery & subquery, ASTPtr & ast, Data & data)
{
Context subquery_context = data.context;
Settings subquery_settings = data.context.getSettings();
subquery_settings.max_result_rows = 1;
subquery_settings.extremes = 0;
subquery_context.setSettings(subquery_settings);
ASTPtr subquery_select = subquery.children.at(0);
BlockIO res = InterpreterSelectWithUnionQuery(
subquery_select, subquery_context, SelectQueryOptions(QueryProcessingStage::Complete, data.subquery_depth + 1)).execute();
Block block;
try
{
block = res.in->read();
if (!block)
{
/// Interpret subquery with empty result as Null literal
auto ast_new = std::make_unique<ASTLiteral>(Null());
ast_new->setAlias(ast->tryGetAlias());
ast = std::move(ast_new);
return;
}
if (block.rows() != 1 || res.in->read())
throw Exception("Scalar subquery returned more than one row", ErrorCodes::INCORRECT_RESULT_OF_SCALAR_SUBQUERY);
}
catch (const Exception & e)
{
if (e.code() == ErrorCodes::TOO_MANY_ROWS)
throw Exception("Scalar subquery returned more than one row", ErrorCodes::INCORRECT_RESULT_OF_SCALAR_SUBQUERY);
else
throw;
}
size_t columns = block.columns();
if (columns == 1)
{
auto lit = std::make_unique<ASTLiteral>((*block.safeGetByPosition(0).column)[0]);
lit->alias = subquery.alias;
lit->prefer_alias_to_column_name = subquery.prefer_alias_to_column_name;
ast = addTypeConversionToAST(std::move(lit), block.safeGetByPosition(0).type->getName());
}
else
{
auto tuple = std::make_shared<ASTFunction>();
tuple->alias = subquery.alias;
ast = tuple;
tuple->name = "tuple";
auto exp_list = std::make_shared<ASTExpressionList>();
tuple->arguments = exp_list;
tuple->children.push_back(tuple->arguments);
exp_list->children.resize(columns);
for (size_t i = 0; i < columns; ++i)
{
exp_list->children[i] = addTypeConversionToAST(
std::make_unique<ASTLiteral>((*block.safeGetByPosition(i).column)[0]),
block.safeGetByPosition(i).type->getName());
}
}
}
void ExecuteScalarSubqueriesMatcher::visit(const ASTFunction & func, ASTPtr & ast, Data & data)
{
/// Don't descend into subqueries in arguments of IN operator.
/// But if an argument is not subquery, than deeper may be scalar subqueries and we need to descend in them.
std::vector<ASTPtr *> out;
if (functionIsInOrGlobalInOperator(func.name))
{
for (auto & child : ast->children)
{
if (child != func.arguments)
out.push_back(&child);
else
for (size_t i = 0, size = func.arguments->children.size(); i < size; ++i)
if (i != 1 || !func.arguments->children[i]->as<ASTSubquery>())
out.push_back(&func.arguments->children[i]);
}
}
else
for (auto & child : ast->children)
out.push_back(&child);
for (ASTPtr * add_node : out)
Visitor(data).visit(*add_node);
}
}