Add new table function viewIfPermitted().

This commit is contained in:
Vitaly Baranov 2022-07-09 17:51:59 +02:00
parent c85b2b5732
commit fbb2e14d54
11 changed files with 281 additions and 46 deletions

View File

@ -106,7 +106,9 @@ void Client::processError(const String & query) const
std::vector<String> Client::loadWarningMessages()
{
std::vector<String> messages;
connection->sendQuery(connection_parameters.timeouts, "SELECT message FROM system.warnings", "" /* query_id */,
connection->sendQuery(connection_parameters.timeouts,
"SELECT * FROM viewIfPermitted(SELECT message FROM system.warnings ELSE null('message String'))",
"" /* query_id */,
QueryProcessingStage::Complete,
&global_context->getSettingsRef(),
&global_context->getClientInfo(), false, {});

View File

@ -50,52 +50,58 @@ static String getLoadSuggestionQuery(Int32 suggestion_limit, bool basic_suggesti
{
/// NOTE: Once you will update the completion list,
/// do not forget to update 01676_clickhouse_client_autocomplete.sh
WriteBufferFromOwnString query;
query << "SELECT DISTINCT arrayJoin(extractAll(name, '[\\\\w_]{2,}')) AS res FROM ("
"SELECT name FROM system.functions"
" UNION ALL "
"SELECT name FROM system.table_engines"
" UNION ALL "
"SELECT name FROM system.formats"
" UNION ALL "
"SELECT name FROM system.table_functions"
" UNION ALL "
"SELECT name FROM system.data_type_families"
" UNION ALL "
"SELECT name FROM system.merge_tree_settings"
" UNION ALL "
"SELECT name FROM system.settings"
" UNION ALL ";
String query;
auto add_subquery = [&](std::string_view select, std::string_view result_column_name)
{
if (!query.empty())
query += " UNION ALL ";
query += fmt::format("SELECT * FROM viewIfPermitted({} ELSE null('{} String'))", select, result_column_name);
};
auto add_column = [&](std::string_view column_name, std::string_view table_name, bool distinct, std::optional<Int64> limit)
{
add_subquery(
fmt::format(
"SELECT {}{} FROM system.{}{}",
(distinct ? "DISTINCT " : ""),
column_name,
table_name,
(limit ? (" LIMIT " + std::to_string(*limit)) : "")),
column_name);
};
add_column("name", "functions", false, {});
add_column("name", "table_engines", false, {});
add_column("name", "formats", false, {});
add_column("name", "table_functions", false, {});
add_column("name", "data_type_families", false, {});
add_column("name", "merge_tree_settings", false, {});
add_column("name", "settings", false, {});
if (!basic_suggestion)
{
query << "SELECT cluster FROM system.clusters"
" UNION ALL "
"SELECT macro FROM system.macros"
" UNION ALL "
"SELECT policy_name FROM system.storage_policies"
" UNION ALL ";
add_column("cluster", "clusters", false, {});
add_column("macro", "macros", false, {});
add_column("policy_name", "storage_policies", false, {});
}
query << "SELECT concat(func.name, comb.name) FROM system.functions AS func CROSS JOIN system.aggregate_function_combinators AS comb WHERE is_aggregate";
add_subquery("SELECT concat(func.name, comb.name) AS x FROM system.functions AS func CROSS JOIN system.aggregate_function_combinators AS comb WHERE is_aggregate", "x");
/// The user may disable loading of databases, tables, columns by setting suggestion_limit to zero.
if (suggestion_limit > 0)
{
String limit_str = toString(suggestion_limit);
query << " UNION ALL "
"SELECT name FROM system.databases LIMIT " << limit_str
<< " UNION ALL "
"SELECT DISTINCT name FROM system.tables LIMIT " << limit_str
<< " UNION ALL ";
add_column("name", "databases", false, suggestion_limit);
add_column("name", "tables", true, suggestion_limit);
if (!basic_suggestion)
{
query << "SELECT DISTINCT name FROM system.dictionaries LIMIT " << limit_str
<< " UNION ALL ";
add_column("name", "dictionaries", true, suggestion_limit);
}
query << "SELECT DISTINCT name FROM system.columns LIMIT " << limit_str;
add_column("name", "columns", true, suggestion_limit);
}
query << ") WHERE notEmpty(res)";
return query.str();
query = "SELECT DISTINCT arrayJoin(extractAll(name, '[\\\\w_]{2,}')) AS res FROM (" + query + ") WHERE notEmpty(res)";
return query;
}
template <typename ConnectionType>

View File

@ -509,6 +509,25 @@ void ASTFunction::formatImplWithoutAlias(const FormatSettings & settings, Format
settings.ostr << ')';
written = true;
}
if (!written && name == "viewIfPermitted"sv)
{
/// viewIfPermitted() needs special formatting: ELSE instead of comma between arguments, and better indents too.
const auto * nl_or_nothing = settings.one_line ? "" : "\n";
auto indent0 = settings.one_line ? "" : String(4u * frame.indent, ' ');
auto indent1 = settings.one_line ? "" : String(4u * (frame.indent + 1), ' ');
auto indent2 = settings.one_line ? "" : String(4u * (frame.indent + 2), ' ');
settings.ostr << (settings.hilite ? hilite_function : "") << name << "(" << (settings.hilite ? hilite_none : "") << nl_or_nothing;
FormatStateStacked frame_nested = frame;
frame_nested.need_parens = false;
frame_nested.indent += 2;
arguments->children[0]->formatImpl(settings, state, frame_nested);
settings.ostr << nl_or_nothing << indent1 << (settings.hilite ? hilite_keyword : "") << (settings.one_line ? " " : "")
<< "ELSE " << (settings.hilite ? hilite_none : "") << nl_or_nothing << indent2;
arguments->children[1]->formatImpl(settings, state, frame_nested);
settings.ostr << nl_or_nothing << indent0 << ")";
return;
}
}
if (!written && arguments->children.size() >= 2)

View File

@ -1068,13 +1068,16 @@ bool ParserFunction::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
bool ParserTableFunctionView::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
{
ParserIdentifier id_parser;
ParserKeyword view("VIEW");
ParserSelectWithUnionQuery select;
ASTPtr identifier;
ASTPtr query;
if (!view.ignore(pos, expected))
bool if_permitted = false;
if (ParserKeyword{"VIEWIFPERMITTED"}.ignore(pos, expected))
if_permitted = true;
else if (!ParserKeyword{"VIEW"}.ignore(pos, expected))
return false;
if (pos->type != TokenType::OpeningRoundBracket)
@ -1094,15 +1097,30 @@ bool ParserTableFunctionView::parseImpl(Pos & pos, ASTPtr & node, Expected & exp
return false;
}
ASTPtr else_ast;
if (if_permitted)
{
if (!ParserKeyword{"ELSE"}.ignore(pos, expected))
return false;
if (!ParserWithOptionalAlias{std::make_unique<ParserFunction>(true, true), true}.parse(pos, else_ast, expected))
return false;
}
if (pos->type != TokenType::ClosingRoundBracket)
return false;
++pos;
auto expr_list = std::make_shared<ASTExpressionList>();
expr_list->children.push_back(query);
if (if_permitted)
expr_list->children.push_back(else_ast);
auto function_node = std::make_shared<ASTFunction>();
tryGetIdentifierNameInto(identifier, function_node->name);
auto expr_list_with_single_query = std::make_shared<ASTExpressionList>();
expr_list_with_single_query->children.push_back(query);
function_node->name = "view";
function_node->arguments = expr_list_with_single_query;
function_node->name = if_permitted ? "viewIfPermitted" : "view";
function_node->arguments = expr_list;
function_node->children.push_back(function_node->arguments);
node = function_node;
return true;
@ -1971,6 +1989,7 @@ const char * ParserAlias::restricted_keywords[] =
"WITH",
"INTERSECT",
"EXCEPT",
"ELSE",
nullptr
};

View File

@ -162,7 +162,7 @@ protected:
bool is_table_function;
};
// A special function parser for view table function.
// A special function parser for view and viewIfPermitted table functions.
// It parses an SELECT query as its argument and doesn't support getColumnName().
class ParserTableFunctionView : public IParserBase
{

View File

@ -180,9 +180,13 @@ void StorageView::replaceWithSubquery(ASTSelectQuery & outer_query, ASTPtr view_
if (!table_expression->database_and_table_name)
{
// If it's a view table function, add a fake db.table name.
if (table_expression->table_function && table_expression->table_function->as<ASTFunction>()->name == "view")
if (table_expression->table_function)
{
auto table_function_name = table_expression->table_function->as<ASTFunction>()->name;
if ((table_function_name == "view") || (table_function_name == "viewIfPermitted"))
table_expression->database_and_table_name = std::make_shared<ASTTableIdentifier>("__view");
else
}
if (!table_expression->database_and_table_name)
throw Exception("Logical error: incorrect table expression", ErrorCodes::LOGICAL_ERROR);
}

View File

@ -0,0 +1,113 @@
#include <Interpreters/InterpreterSelectWithUnionQuery.h>
#include <Parsers/ASTFunction.h>
#include <Parsers/ASTLiteral.h>
#include <Parsers/ASTSelectWithUnionQuery.h>
#include <Storages/StorageNull.h>
#include <Storages/StorageView.h>
#include <Storages/checkAndGetLiteralArgument.h>
#include <TableFunctions/ITableFunction.h>
#include <TableFunctions/TableFunctionFactory.h>
#include <TableFunctions/TableFunctionViewIfPermitted.h>
#include <TableFunctions/parseColumnsListForTableFunction.h>
#include <Interpreters/evaluateConstantExpression.h>
#include "registerTableFunctions.h"
namespace DB
{
namespace ErrorCodes
{
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
extern const int BAD_ARGUMENTS;
extern const int ACCESS_DENIED;
}
const ASTSelectWithUnionQuery & TableFunctionViewIfPermitted::getSelectQuery() const
{
return *create.select;
}
void TableFunctionViewIfPermitted::parseArguments(const ASTPtr & ast_function, ContextPtr context)
{
const auto * function = ast_function->as<ASTFunction>();
if (!function || !function->arguments || (function->arguments->children.size() != 2))
throw Exception(
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Table function '{}' requires two arguments: a SELECT query and a table function",
getName());
const auto & arguments = function->arguments->children;
auto * select = arguments[0]->as<ASTSelectWithUnionQuery>();
if (!select)
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Table function '{}' requires a SELECT query as its first argument", getName());
create.set(create.select, select->clone());
else_ast = arguments[1];
if (!else_ast->as<ASTFunction>())
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Table function '{}' requires a table function as its second argument", getName());
else_table_function = TableFunctionFactory::instance().get(else_ast, context);
}
ColumnsDescription TableFunctionViewIfPermitted::getActualTableStructure(ContextPtr context) const
{
return else_table_function->getActualTableStructure(context);
}
StoragePtr TableFunctionViewIfPermitted::executeImpl(
const ASTPtr & /* ast_function */, ContextPtr context, const std::string & table_name, ColumnsDescription /* cached_columns */) const
{
StoragePtr storage;
auto columns = getActualTableStructure(context);
if (isPermitted(context, columns))
{
storage = std::make_shared<StorageView>(StorageID(getDatabaseName(), table_name), create, columns, "");
}
else
{
storage = else_table_function->execute(else_ast, context, table_name);
}
storage->startup();
return storage;
}
bool TableFunctionViewIfPermitted::isPermitted(const ContextPtr & context, const ColumnsDescription & else_columns) const
{
Block sample_block;
try
{
/// Will throw ACCESS_DENIED if the current user is not allowed to execute the SELECT query.
sample_block = InterpreterSelectWithUnionQuery::getSampleBlock(create.children[0], context);
}
catch (Exception & e)
{
if (e.code() == ErrorCodes::ACCESS_DENIED)
return {};
throw;
}
/// We check that columns match only if permitted (otherwise we could reveal the structure to an user who must not know it).
ColumnsDescription columns{sample_block.getNamesAndTypesList()};
if (columns != else_columns)
{
throw Exception(
ErrorCodes::BAD_ARGUMENTS,
"Table function '{}' requires a SELECT query with the result columns matching a table function after 'ELSE'. "
"Currently the result columns of the SELECT query are {}, and the table function after 'ELSE' gives {}",
getName(),
columns.toString(),
else_columns.toString());
}
return true;
}
void registerTableFunctionViewIfPermitted(TableFunctionFactory & factory)
{
factory.registerFunction<TableFunctionViewIfPermitted>();
}
}

View File

@ -0,0 +1,35 @@
#pragma once
#include <TableFunctions/ITableFunction.h>
#include <Parsers/ASTCreateQuery.h>
#include <base/types.h>
namespace DB
{
/* viewIfPermitted(query ELSE null('structure'))
* Works as "view(query)" if the current user has the permissions required to execute "query"; works as "null('structure')" otherwise.
*/
class TableFunctionViewIfPermitted : public ITableFunction
{
public:
static constexpr auto name = "viewIfPermitted";
std::string getName() const override { return name; }
const ASTSelectWithUnionQuery & getSelectQuery() const;
private:
StoragePtr executeImpl(const ASTPtr & ast_function, ContextPtr context, const String & table_name, ColumnsDescription cached_columns) const override;
const char * getStorageTypeName() const override { return "ViewIfPermitted"; }
void parseArguments(const ASTPtr & ast_function, ContextPtr context) override;
ColumnsDescription getActualTableStructure(ContextPtr context) const override;
bool isPermitted(const ContextPtr & context, const ColumnsDescription & else_columns) const;
ASTCreateQuery create;
ASTPtr else_ast;
TableFunctionPtr else_table_function;
};
}

View File

@ -42,6 +42,7 @@ void registerTableFunctions()
registerTableFunctionJDBC(factory);
registerTableFunctionView(factory);
registerTableFunctionViewIfPermitted(factory);
#if USE_MYSQL
registerTableFunctionMySQL(factory);

View File

@ -40,6 +40,7 @@ void registerTableFunctionODBC(TableFunctionFactory & factory);
void registerTableFunctionJDBC(TableFunctionFactory & factory);
void registerTableFunctionView(TableFunctionFactory & factory);
void registerTableFunctionViewIfPermitted(TableFunctionFactory & factory);
#if USE_MYSQL
void registerTableFunctionMySQL(TableFunctionFactory & factory);

View File

@ -65,3 +65,38 @@ def test_merge():
"it's necessary to have grant SELECT ON default.table2"
in instance.query_and_get_error(select_query, user="A")
)
def test_view_if_permitted():
assert (
instance.query(
"SELECT * FROM viewIfPermitted(SELECT * FROM table1 ELSE null('x UInt32'))"
)
== "1\n"
)
expected_error = "requires a SELECT query with the result columns matching a table function after 'ELSE'"
assert expected_error in instance.query_and_get_error(
"SELECT * FROM viewIfPermitted(SELECT * FROM table1 ELSE null('x Int32'))"
)
assert expected_error in instance.query_and_get_error(
"SELECT * FROM viewIfPermitted(SELECT * FROM table1 ELSE null('y UInt32'))"
)
instance.query("CREATE USER A")
assert (
instance.query(
"SELECT * FROM viewIfPermitted(SELECT * FROM table1 ELSE null('x UInt32'))",
user="A",
)
== ""
)
instance.query("GRANT SELECT ON table1 TO A")
assert (
instance.query(
"SELECT * FROM viewIfPermitted(SELECT * FROM table1 ELSE null('x UInt32'))",
user="A",
)
== "1\n"
)