From 6baccb963d3a69919f0a17ab7f8f4abe88675ab2 Mon Sep 17 00:00:00 2001 From: Vitaly Baranov Date: Tue, 3 Dec 2019 21:19:11 +0300 Subject: [PATCH] Add functions currentRowPolicies() and system table 'system.row_policies'. --- dbms/src/Functions/currentRowPolicies.cpp | 225 ++++++++++++++++++ dbms/src/Functions/registerFunctions.h | 1 + .../registerFunctionsMiscellaneous.cpp | 1 + .../System/StorageSystemRowPolicies.cpp | 59 +++++ .../System/StorageSystemRowPolicies.h | 26 ++ .../Storages/System/attachSystemTables.cpp | 2 + .../tests/integration/test_row_policy/test.py | 16 ++ 7 files changed, 330 insertions(+) create mode 100644 dbms/src/Functions/currentRowPolicies.cpp create mode 100644 dbms/src/Storages/System/StorageSystemRowPolicies.cpp create mode 100644 dbms/src/Storages/System/StorageSystemRowPolicies.h diff --git a/dbms/src/Functions/currentRowPolicies.cpp b/dbms/src/Functions/currentRowPolicies.cpp new file mode 100644 index 00000000000..4ffd40ed1b9 --- /dev/null +++ b/dbms/src/Functions/currentRowPolicies.cpp @@ -0,0 +1,225 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace DB +{ +namespace ErrorCodes +{ + extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; + extern const int ILLEGAL_TYPE_OF_ARGUMENT; +} + + +/// The currentRowPolicies() function can be called with 0..2 arguments: +/// currentRowPolicies() returns array of tuples (database, table_name, row_policy_name) for all the row policies applied for the current user; +/// currentRowPolicies(table_name) is equivalent to currentRowPolicies(currentDatabase(), table_name); +/// currentRowPolicies(database, table_name) returns array of names of the row policies applied to a specific table and for the current user. +class FunctionCurrentRowPolicies : public IFunction +{ +public: + static constexpr auto name = "currentRowPolicies"; + + static FunctionPtr create(const Context & context_) { return std::make_shared(context_); } + explicit FunctionCurrentRowPolicies(const Context & context_) : context(context_) {} + + String getName() const override { return name; } + size_t getNumberOfArguments() const override { return 0; } + bool isVariadic() const override { return true; } + + void checkNumberOfArgumentsIfVariadic(size_t number_of_arguments) const override + { + if (number_of_arguments > 2) + throw Exception("Number of arguments for function " + String(name) + " doesn't match: passed " + + toString(number_of_arguments) + ", should be 0..2", + ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + } + + DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override + { + if (arguments.empty()) + return std::make_shared(std::make_shared( + DataTypes{std::make_shared(), std::make_shared(), std::make_shared()})); + else + return std::make_shared(std::make_shared()); + } + + bool isDeterministic() const override { return false; } + + void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result_pos, size_t input_rows_count) override + { + if (arguments.empty()) + { + auto database_column = ColumnString::create(); + auto table_name_column = ColumnString::create(); + auto policy_name_column = ColumnString::create(); + for (const auto & policy_id : context.getRowPolicy()->getCurrentPolicyIDs()) + { + const auto policy = context.getAccessControlManager().tryRead(policy_id); + if (policy) + { + const String database = policy->getDatabase(); + const String table_name = policy->getTableName(); + const String policy_name = policy->getName(); + database_column->insertData(database.data(), database.length()); + table_name_column->insertData(table_name.data(), table_name.length()); + policy_name_column->insertData(policy_name.data(), policy_name.length()); + } + } + auto offset_column = ColumnArray::ColumnOffsets::create(); + offset_column->insertValue(policy_name_column->size()); + block.getByPosition(result_pos).column = ColumnConst::create( + ColumnArray::create( + ColumnTuple::create(Columns{std::move(database_column), std::move(table_name_column), std::move(policy_name_column)}), + std::move(offset_column)), + input_rows_count); + return; + } + + const IColumn * database_column = nullptr; + if (arguments.size() == 2) + { + const auto & database_column_with_type = block.getByPosition(arguments[0]); + if (!isStringOrFixedString(database_column_with_type.type)) + throw Exception{"The first argument of function " + String(name) + + " should be a string containing database name, illegal type: " + + database_column_with_type.type->getName(), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; + database_column = database_column_with_type.column.get(); + } + + const auto & table_name_column_with_type = block.getByPosition(arguments[arguments.size() - 1]); + if (!isStringOrFixedString(table_name_column_with_type.type)) + throw Exception{"The" + String(database_column ? " last" : "") + " argument of function " + String(name) + + " should be a string containing table name, illegal type: " + table_name_column_with_type.type->getName(), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; + const IColumn * table_name_column = table_name_column_with_type.column.get(); + + auto policy_name_column = ColumnString::create(); + auto offset_column = ColumnArray::ColumnOffsets::create(); + for (const auto i : ext::range(0, input_rows_count)) + { + String database = database_column ? database_column->getDataAt(i).toString() : context.getCurrentDatabase(); + String table_name = table_name_column->getDataAt(i).toString(); + for (const auto & policy_id : context.getRowPolicy()->getCurrentPolicyIDs(database, table_name)) + { + const auto policy = context.getAccessControlManager().tryRead(policy_id); + if (policy) + { + const String policy_name = policy->getName(); + policy_name_column->insertData(policy_name.data(), policy_name.length()); + } + } + offset_column->insertValue(policy_name_column->size()); + } + + block.getByPosition(result_pos).column = ColumnArray::create(std::move(policy_name_column), std::move(offset_column)); + } + +private: + const Context & context; +}; + + +/// The currentRowPolicyIDs() function can be called with 0..2 arguments: +/// currentRowPolicyIDs() returns array of IDs of all the row policies applied for the current user; +/// currentRowPolicyIDs(table_name) is equivalent to currentRowPolicyIDs(currentDatabase(), table_name); +/// currentRowPolicyIDs(database, table_name) returns array of IDs of the row policies applied to a specific table and for the current user. +class FunctionCurrentRowPolicyIDs : public IFunction +{ +public: + static constexpr auto name = "currentRowPolicyIDs"; + + static FunctionPtr create(const Context & context_) { return std::make_shared(context_); } + explicit FunctionCurrentRowPolicyIDs(const Context & context_) : context(context_) {} + + String getName() const override { return name; } + size_t getNumberOfArguments() const override { return 0; } + bool isVariadic() const override { return true; } + + void checkNumberOfArgumentsIfVariadic(size_t number_of_arguments) const override + { + if (number_of_arguments > 2) + throw Exception("Number of arguments for function " + String(name) + " doesn't match: passed " + + toString(number_of_arguments) + ", should be 0..2", + ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + } + + DataTypePtr getReturnTypeImpl(const DataTypes & /* arguments */) const override + { + return std::make_shared(std::make_shared()); + } + + bool isDeterministic() const override { return false; } + + void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result_pos, size_t input_rows_count) override + { + if (arguments.empty()) + { + auto policy_id_column = ColumnVector::create(); + for (const auto & policy_id : context.getRowPolicy()->getCurrentPolicyIDs()) + policy_id_column->insertValue(policy_id); + auto offset_column = ColumnArray::ColumnOffsets::create(); + offset_column->insertValue(policy_id_column->size()); + block.getByPosition(result_pos).column + = ColumnConst::create(ColumnArray::create(std::move(policy_id_column), std::move(offset_column)), input_rows_count); + return; + } + + const IColumn * database_column = nullptr; + if (arguments.size() == 2) + { + const auto & database_column_with_type = block.getByPosition(arguments[0]); + if (!isStringOrFixedString(database_column_with_type.type)) + throw Exception{"The first argument of function " + String(name) + + " should be a string containing database name, illegal type: " + + database_column_with_type.type->getName(), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; + database_column = database_column_with_type.column.get(); + } + + const auto & table_name_column_with_type = block.getByPosition(arguments[arguments.size() - 1]); + if (!isStringOrFixedString(table_name_column_with_type.type)) + throw Exception{"The" + String(database_column ? " last" : "") + " argument of function " + String(name) + + " should be a string containing table name, illegal type: " + table_name_column_with_type.type->getName(), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; + const IColumn * table_name_column = table_name_column_with_type.column.get(); + + auto policy_id_column = ColumnVector::create(); + auto offset_column = ColumnArray::ColumnOffsets::create(); + for (const auto i : ext::range(0, input_rows_count)) + { + String database = database_column ? database_column->getDataAt(i).toString() : context.getCurrentDatabase(); + String table_name = table_name_column->getDataAt(i).toString(); + for (const auto & policy_id : context.getRowPolicy()->getCurrentPolicyIDs(database, table_name)) + policy_id_column->insertValue(policy_id); + offset_column->insertValue(policy_id_column->size()); + } + + block.getByPosition(result_pos).column = ColumnArray::create(std::move(policy_id_column), std::move(offset_column)); + } + +private: + const Context & context; +}; + + +void registerFunctionCurrentRowPolicies(FunctionFactory & factory) +{ + factory.registerFunction(); + factory.registerFunction(); +} + +} diff --git a/dbms/src/Functions/registerFunctions.h b/dbms/src/Functions/registerFunctions.h index 087fd6b7e2b..5827ae5894c 100644 --- a/dbms/src/Functions/registerFunctions.h +++ b/dbms/src/Functions/registerFunctions.h @@ -9,6 +9,7 @@ class FunctionFactory; void registerFunctionCurrentDatabase(FunctionFactory &); void registerFunctionCurrentUser(FunctionFactory &); void registerFunctionCurrentQuota(FunctionFactory &); +void registerFunctionCurrentRowPolicies(FunctionFactory &); void registerFunctionHostName(FunctionFactory &); void registerFunctionFQDN(FunctionFactory &); void registerFunctionVisibleWidth(FunctionFactory &); diff --git a/dbms/src/Functions/registerFunctionsMiscellaneous.cpp b/dbms/src/Functions/registerFunctionsMiscellaneous.cpp index c45ccf57f64..98c749189d4 100644 --- a/dbms/src/Functions/registerFunctionsMiscellaneous.cpp +++ b/dbms/src/Functions/registerFunctionsMiscellaneous.cpp @@ -8,6 +8,7 @@ void registerFunctionsMiscellaneous(FunctionFactory & factory) registerFunctionCurrentDatabase(factory); registerFunctionCurrentUser(factory); registerFunctionCurrentQuota(factory); + registerFunctionCurrentRowPolicies(factory); registerFunctionHostName(factory); registerFunctionFQDN(factory); registerFunctionVisibleWidth(factory); diff --git a/dbms/src/Storages/System/StorageSystemRowPolicies.cpp b/dbms/src/Storages/System/StorageSystemRowPolicies.cpp new file mode 100644 index 00000000000..8ac4ac1b755 --- /dev/null +++ b/dbms/src/Storages/System/StorageSystemRowPolicies.cpp @@ -0,0 +1,59 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace DB +{ +NamesAndTypesList StorageSystemRowPolicies::getNamesAndTypes() +{ + NamesAndTypesList names_and_types{ + {"database", std::make_shared()}, + {"table", std::make_shared()}, + {"name", std::make_shared()}, + {"full_name", std::make_shared()}, + {"id", std::make_shared()}, + {"source", std::make_shared()}, + {"restrictive", std::make_shared()}, + }; + + for (auto index : ext::range_with_static_cast(RowPolicy::MAX_CONDITION_INDEX)) + names_and_types.push_back({RowPolicy::conditionIndexToColumnName(index), std::make_shared()}); + + return names_and_types; +} + + +void StorageSystemRowPolicies::fillData(MutableColumns & res_columns, const Context & context, const SelectQueryInfo &) const +{ + const auto & access_control = context.getAccessControlManager(); + std::vector ids = access_control.findAll(); + + for (const auto & id : ids) + { + auto policy = access_control.tryRead(id); + if (!policy) + continue; + const auto * storage = access_control.findStorage(id); + + size_t i = 0; + res_columns[i++]->insert(policy->getDatabase()); + res_columns[i++]->insert(policy->getTableName()); + res_columns[i++]->insert(policy->getName()); + res_columns[i++]->insert(policy->getFullName()); + res_columns[i++]->insert(id); + res_columns[i++]->insert(storage ? storage->getStorageName() : ""); + res_columns[i++]->insert(policy->isRestrictive()); + + for (auto index : ext::range(RowPolicy::MAX_CONDITION_INDEX)) + res_columns[i++]->insert(policy->conditions[index]); + } +} +} diff --git a/dbms/src/Storages/System/StorageSystemRowPolicies.h b/dbms/src/Storages/System/StorageSystemRowPolicies.h new file mode 100644 index 00000000000..c28342eb18c --- /dev/null +++ b/dbms/src/Storages/System/StorageSystemRowPolicies.h @@ -0,0 +1,26 @@ +#pragma once + +#include +#include + + +namespace DB +{ + +class Context; + + +/// Implements `row_policies` system table, which allows you to get information about row policies. +class StorageSystemRowPolicies : public ext::shared_ptr_helper, public IStorageSystemOneBlock +{ +public: + std::string getName() const override { return "SystemRowPolicies"; } + static NamesAndTypesList getNamesAndTypes(); + +protected: + friend struct ext::shared_ptr_helper; + using IStorageSystemOneBlock::IStorageSystemOneBlock; + void fillData(MutableColumns & res_columns, const Context & context, const SelectQueryInfo &) const override; +}; + +} diff --git a/dbms/src/Storages/System/attachSystemTables.cpp b/dbms/src/Storages/System/attachSystemTables.cpp index 2b8e630cbed..e8e265ca1e8 100644 --- a/dbms/src/Storages/System/attachSystemTables.cpp +++ b/dbms/src/Storages/System/attachSystemTables.cpp @@ -29,6 +29,7 @@ #include #include #include +#include #include #include #include @@ -56,6 +57,7 @@ void attachSystemTablesLocal(IDatabase & system_database) system_database.attachTable("settings", StorageSystemSettings::create("settings")); system_database.attachTable("quotas", StorageSystemQuotas::create("quotas")); system_database.attachTable("quota_usage", StorageSystemQuotaUsage::create("quota_usage")); + system_database.attachTable("row_policies", StorageSystemRowPolicies::create("row_policies")); system_database.attachTable("merge_tree_settings", SystemMergeTreeSettings::create("merge_tree_settings")); system_database.attachTable("build_options", StorageSystemBuildOptions::create("build_options")); system_database.attachTable("formats", StorageSystemFormats::create("formats")); diff --git a/dbms/tests/integration/test_row_policy/test.py b/dbms/tests/integration/test_row_policy/test.py index bc5061e6704..3de63d56fa0 100644 --- a/dbms/tests/integration/test_row_policy/test.py +++ b/dbms/tests/integration/test_row_policy/test.py @@ -137,3 +137,19 @@ def test_reload_users_xml_by_timer(): assert_eq_with_retry(instance, "SELECT * FROM mydb.filtered_table1", "1\t0\n1\t1") assert_eq_with_retry(instance, "SELECT * FROM mydb.filtered_table2", "0\t0\t0\t0\n0\t0\t6\t0") assert_eq_with_retry(instance, "SELECT * FROM mydb.filtered_table3", "0\t1\n1\t0") + + +def test_introspection(): + assert instance.query("SELECT currentRowPolicies('mydb', 'filtered_table1')") == "['default']\n" + assert instance.query("SELECT currentRowPolicies('mydb', 'filtered_table2')") == "['default']\n" + assert instance.query("SELECT currentRowPolicies('mydb', 'filtered_table3')") == "['default']\n" + assert instance.query("SELECT arraySort(currentRowPolicies())") == "[('mydb','filtered_table1','default'),('mydb','filtered_table2','default'),('mydb','filtered_table3','default')]\n" + + policy1 = "mydb\tfiltered_table1\tdefault\tdefault ON mydb.filtered_table1\t9e8a8f62-4965-2b5e-8599-57c7b99b3549\tusers.xml\t0\ta = 1\t\t\t\t\n" + policy2 = "mydb\tfiltered_table2\tdefault\tdefault ON mydb.filtered_table2\tcffae79d-b9bf-a2ef-b798-019c18470b25\tusers.xml\t0\ta + b < 1 or c - d > 5\t\t\t\t\n" + policy3 = "mydb\tfiltered_table3\tdefault\tdefault ON mydb.filtered_table3\t12fc5cef-e3da-3940-ec79-d8be3911f42b\tusers.xml\t0\tc = 1\t\t\t\t\n" + assert instance.query("SELECT * from system.row_policies WHERE has(currentRowPolicyIDs('mydb', 'filtered_table1'), id) ORDER BY table, name") == policy1 + assert instance.query("SELECT * from system.row_policies WHERE has(currentRowPolicyIDs('mydb', 'filtered_table2'), id) ORDER BY table, name") == policy2 + assert instance.query("SELECT * from system.row_policies WHERE has(currentRowPolicyIDs('mydb', 'filtered_table3'), id) ORDER BY table, name") == policy3 + assert instance.query("SELECT * from system.row_policies ORDER BY table, name") == policy1 + policy2 + policy3 + assert instance.query("SELECT * from system.row_policies WHERE has(currentRowPolicyIDs(), id) ORDER BY table, name") == policy1 + policy2 + policy3