diff --git a/src/Access/EnabledRolesInfo.h b/src/Access/EnabledRolesInfo.h index f06b7478daf..091e1b64002 100644 --- a/src/Access/EnabledRolesInfo.h +++ b/src/Access/EnabledRolesInfo.h @@ -10,7 +10,7 @@ namespace DB { -/// Information about a role. +/// Information about roles enabled for a user at some specific time. struct EnabledRolesInfo { boost::container::flat_set current_roles; diff --git a/src/Functions/currentRoles.cpp b/src/Functions/currentRoles.cpp new file mode 100644 index 00000000000..0a4e23308d8 --- /dev/null +++ b/src/Functions/currentRoles.cpp @@ -0,0 +1,88 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace DB +{ + +namespace +{ + enum class Kind + { + CURRENT_ROLES, + ENABLED_ROLES, + DEFAULT_ROLES, + }; + + template + class FunctionCurrentRoles : public IFunction + { + public: + static constexpr auto name = (kind == Kind::CURRENT_ROLES) ? "currentRoles" : ((kind == Kind::ENABLED_ROLES) ? "enabledRoles" : "defaultRoles"); + static FunctionPtr create(const ContextPtr & context) { return std::make_shared(context); } + + String getName() const override { return name; } + + explicit FunctionCurrentRoles(const ContextPtr & context) + { + if constexpr (kind == Kind::CURRENT_ROLES) + { + role_names = context->getRolesInfo()->getCurrentRolesNames(); + } + else if constexpr (kind == Kind::ENABLED_ROLES) + { + role_names = context->getRolesInfo()->getEnabledRolesNames(); + } + else + { + static_assert(kind == Kind::DEFAULT_ROLES); + const auto & manager = context->getAccessControlManager(); + if (auto user = context->getUser()) + role_names = manager.tryReadNames(user->granted_roles.findGranted(user->default_roles)); + } + + /// We sort the names because the result of the function should not depend on the order of UUIDs. + std::sort(role_names.begin(), role_names.end()); + } + + size_t getNumberOfArguments() const override { return 0; } + bool isDeterministic() const override { return false; } + + DataTypePtr getReturnTypeImpl(const DataTypes & /*arguments*/) const override + { + return std::make_shared(std::make_shared()); + } + + ColumnPtr executeImpl(const ColumnsWithTypeAndName &, const DataTypePtr &, size_t input_rows_count) const override + { + auto col_res = ColumnArray::create(ColumnString::create()); + ColumnString & res_strings = typeid_cast(col_res->getData()); + ColumnArray::Offsets & res_offsets = col_res->getOffsets(); + for (const String & role_name : role_names) + res_strings.insertData(role_name.data(), role_name.length()); + res_offsets.push_back(res_strings.size()); + return ColumnConst::create(std::move(col_res), input_rows_count); + } + + private: + Strings role_names; + }; +} + +void registerFunctionCurrentRoles(FunctionFactory & factory) +{ + factory.registerFunction>(); + factory.registerFunction>(); + factory.registerFunction>(); +} + +} diff --git a/src/Functions/registerFunctionsMiscellaneous.cpp b/src/Functions/registerFunctionsMiscellaneous.cpp index a34572ecd38..12c54aeeefd 100644 --- a/src/Functions/registerFunctionsMiscellaneous.cpp +++ b/src/Functions/registerFunctionsMiscellaneous.cpp @@ -10,6 +10,7 @@ class FunctionFactory; void registerFunctionCurrentDatabase(FunctionFactory &); void registerFunctionCurrentUser(FunctionFactory &); void registerFunctionCurrentProfiles(FunctionFactory &); +void registerFunctionCurrentRoles(FunctionFactory &); void registerFunctionHostName(FunctionFactory &); void registerFunctionFQDN(FunctionFactory &); void registerFunctionVisibleWidth(FunctionFactory &); @@ -87,6 +88,7 @@ void registerFunctionsMiscellaneous(FunctionFactory & factory) registerFunctionCurrentDatabase(factory); registerFunctionCurrentUser(factory); registerFunctionCurrentProfiles(factory); + registerFunctionCurrentRoles(factory); registerFunctionHostName(factory); registerFunctionFQDN(factory); registerFunctionVisibleWidth(factory); diff --git a/tests/integration/test_role/test.py b/tests/integration/test_role/test.py index fd10db78c2e..1e253a93737 100644 --- a/tests/integration/test_role/test.py +++ b/tests/integration/test_role/test.py @@ -33,7 +33,7 @@ def cleanup_after_test(): yield finally: instance.query("DROP USER IF EXISTS A, B") - instance.query("DROP ROLE IF EXISTS R1, R2") + instance.query("DROP ROLE IF EXISTS R1, R2, R3, R4") def test_create_role(): @@ -240,3 +240,37 @@ def test_introspection(): assert instance.query("SELECT * from system.current_roles ORDER BY role_name", user='B') == TSV([["R2", 1, 1]]) assert instance.query("SELECT * from system.enabled_roles ORDER BY role_name", user='A') == TSV([["R1", 0, 1, 1]]) assert instance.query("SELECT * from system.enabled_roles ORDER BY role_name", user='B') == TSV([["R2", 1, 1, 1]]) + + +def test_function_current_roles(): + instance.query("CREATE USER A") + instance.query('CREATE ROLE R1, R2, R3, R4') + instance.query('GRANT R4 TO R2') + instance.query('GRANT R1,R2,R3 TO A') + + session_id = new_session_id() + assert instance.http_query('SELECT defaultRoles(), currentRoles(), enabledRoles()', user='A', params={'session_id':session_id}) == "['R1','R2','R3']\t['R1','R2','R3']\t['R1','R2','R3','R4']\n" + + instance.http_query('SET ROLE R1', user='A', params={'session_id':session_id}) + assert instance.http_query('SELECT defaultRoles(), currentRoles(), enabledRoles()', user='A', params={'session_id':session_id}) == "['R1','R2','R3']\t['R1']\t['R1']\n" + + instance.http_query('SET ROLE R2', user='A', params={'session_id':session_id}) + assert instance.http_query('SELECT defaultRoles(), currentRoles(), enabledRoles()', user='A', params={'session_id':session_id}) == "['R1','R2','R3']\t['R2']\t['R2','R4']\n" + + instance.http_query('SET ROLE NONE', user='A', params={'session_id':session_id}) + assert instance.http_query('SELECT defaultRoles(), currentRoles(), enabledRoles()', user='A', params={'session_id':session_id}) == "['R1','R2','R3']\t[]\t[]\n" + + instance.http_query('SET ROLE DEFAULT', user='A', params={'session_id':session_id}) + assert instance.http_query('SELECT defaultRoles(), currentRoles(), enabledRoles()', user='A', params={'session_id':session_id}) == "['R1','R2','R3']\t['R1','R2','R3']\t['R1','R2','R3','R4']\n" + + instance.query('SET DEFAULT ROLE R2 TO A') + assert instance.http_query('SELECT defaultRoles(), currentRoles(), enabledRoles()', user='A', params={'session_id':session_id}) == "['R2']\t['R1','R2','R3']\t['R1','R2','R3','R4']\n" + + instance.query('REVOKE R3 FROM A') + assert instance.http_query('SELECT defaultRoles(), currentRoles(), enabledRoles()', user='A', params={'session_id':session_id}) == "['R2']\t['R1','R2']\t['R1','R2','R4']\n" + + instance.query('REVOKE R2 FROM A') + assert instance.http_query('SELECT defaultRoles(), currentRoles(), enabledRoles()', user='A', params={'session_id':session_id}) == "[]\t['R1']\t['R1']\n" + + instance.query('SET DEFAULT ROLE ALL TO A') + assert instance.http_query('SELECT defaultRoles(), currentRoles(), enabledRoles()', user='A', params={'session_id':session_id}) == "['R1']\t['R1']\t['R1']\n"