row policy template: initial, works, restrictive rules fix

This commit is contained in:
Ilya Golshtein 2023-03-03 11:00:12 +01:00
parent a2e7c77fe2
commit f0d21a9100
9 changed files with 183 additions and 26 deletions

View File

@ -28,6 +28,9 @@
#include <boost/range/algorithm/copy.hpp>
#include <boost/range/algorithm_ext/push_back.hpp>
#include <Common/logger_useful.h>
namespace DB
{
namespace ErrorCodes
@ -62,6 +65,7 @@ AccessEntityPtr deserializeAccessEntityImpl(const String & definition)
const char * end = begin + definition.size();
while (pos < end)
{
LOG_TRACE((&Poco::Logger::get("deserializeAccessEntityImpl")), "{}", std::string(pos, end));
queries.emplace_back(parseQueryAndMovePosition(parser, pos, end, "", true, 0, DBMS_DEFAULT_MAX_PARSER_DEPTH));
while (isWhitespaceASCII(*pos) || *pos == ';')
++pos;

View File

@ -3,6 +3,8 @@
#include <boost/range/adaptor/map.hpp>
#include <boost/range/algorithm/copy.hpp>
#include <Common/logger_useful.h>
namespace DB
{
@ -18,6 +20,12 @@ size_t EnabledRowPolicies::Hash::operator()(const MixedFiltersKey & key) const
return std::hash<std::string_view>{}(key.database) - std::hash<std::string_view>{}(key.table_name) + static_cast<size_t>(key.filter_type);
}
// size_t EnabledRowPolicies::Hash::operator()(const MixedFiltersKey & key) const
// {
// return std::hash<std::string_view>{}(key.database) + static_cast<size_t>(key.filter_type);
// }
EnabledRowPolicies::EnabledRowPolicies() : params()
{
}
@ -32,11 +40,37 @@ EnabledRowPolicies::~EnabledRowPolicies() = default;
RowPolicyFilterPtr EnabledRowPolicies::getFilter(const String & database, const String & table_name, RowPolicyFilterType filter_type) const
{
/// We don't lock `mutex` here.
auto loaded = mixed_filters.load();
{
for (auto it = loaded->begin(); it != loaded->end(); ++it)
{
LOG_TRACE((&Poco::Logger::get("EnabledRowPolicies::getFilter")), " db: {}, table {}", it->first.database, it->first.table_name);
}
}
auto it = loaded->find({database, table_name, filter_type});
if (it == loaded->end())
return {};
{
it = loaded->find({database, "*", filter_type});
if (it == loaded->end())
{
LOG_TRACE((&Poco::Logger::get("EnabledRowPolicies::getFilter")), "db: {}, table {} - not found ({} records)",
database, table_name, loaded->size());
return {};
}
}
LOG_TRACE((&Poco::Logger::get("EnabledRowPolicies::getFilter")), "db: {}, table {} - found ({} records)",
database, table_name, loaded->size());
return it->second;
}

View File

@ -72,6 +72,14 @@ private:
auto toTuple() const { return std::tie(database, table_name, filter_type); }
friend bool operator==(const MixedFiltersKey & left, const MixedFiltersKey & right) { return left.toTuple() == right.toTuple(); }
friend bool operator!=(const MixedFiltersKey & left, const MixedFiltersKey & right) { return left.toTuple() != right.toTuple(); }
// friend bool operator==(const MixedFiltersKey & left, const MixedFiltersKey & right)
// {
// return left.database == right.database && left.filter_type == right.filter_type;
// }
// friend bool operator!=(const MixedFiltersKey & left, const MixedFiltersKey & right)
// {
// return left.database != right.database || left.filter_type != right.filter_type;
// }
};
struct Hash

View File

@ -34,6 +34,7 @@ struct RowPolicy : public IAccessEntity
/// in addition to all the restrictive policies.
void setPermissive(bool permissive_ = true) { setRestrictive(!permissive_); }
bool isPermissive() const { return !isRestrictive(); }
bool isDatabase() const { return full_name.table_name == "*"; }
/// Sets that the policy is restrictive.
/// A row is only accessible if at least one of the permissive policies passes,

View File

@ -11,6 +11,8 @@
#include <boost/smart_ptr/make_shared.hpp>
#include <Core/Defines.h>
#include <Common/logger_useful.h>
namespace DB
{
@ -148,9 +150,19 @@ void RowPolicyCache::ensureAllRowPoliciesRead()
for (const UUID & id : access_control.findAll<RowPolicy>())
{
auto quota = access_control.tryRead<RowPolicy>(id);
if (quota)
all_policies.emplace(id, PolicyInfo(quota));
auto policy = access_control.tryRead<RowPolicy>(id);
if (policy)
{
PolicyInfo policy_info(policy);
if (policy_info.database_and_table_name->second == "*")
{
database_policies.emplace(id, std::move(policy_info));
}
else
{
table_policies.emplace(id, std::move(policy_info));
}
}
}
}
@ -158,15 +170,23 @@ void RowPolicyCache::ensureAllRowPoliciesRead()
void RowPolicyCache::rowPolicyAddedOrChanged(const UUID & policy_id, const RowPolicyPtr & new_policy)
{
std::lock_guard lock{mutex};
auto it = all_policies.find(policy_id);
if (it == all_policies.end())
bool found = true;
auto it = table_policies.find(policy_id);
if (it == table_policies.end())
{
it = all_policies.emplace(policy_id, PolicyInfo(new_policy)).first;
it = database_policies.find(policy_id);
if (it == database_policies.end())
{
PolicyMap & policy_map = new_policy->isDatabase() ? database_policies : table_policies;
it = policy_map.emplace(policy_id, PolicyInfo(new_policy)).first;
found = false;
}
}
else
if (found && it->second.policy == new_policy)
{
if (it->second.policy == new_policy)
return;
return;
}
auto & info = it->second;
@ -178,7 +198,15 @@ void RowPolicyCache::rowPolicyAddedOrChanged(const UUID & policy_id, const RowPo
void RowPolicyCache::rowPolicyRemoved(const UUID & policy_id)
{
std::lock_guard lock{mutex};
all_policies.erase(policy_id);
auto it = database_policies.find(policy_id);
if (it != database_policies.end())
{
database_policies.erase(it);
}
else
{
table_policies.erase(policy_id);
}
mixFilters();
}
@ -215,22 +243,71 @@ void RowPolicyCache::mixFiltersFor(EnabledRowPolicies & enabled)
std::vector<RowPolicyPtr> policies;
};
std::unordered_map<MixedFiltersKey, MixerWithNames, Hash> mixers;
std::unordered_map<MixedFiltersKey, MixerWithNames, Hash> table_mixers;
std::unordered_map<MixedFiltersKey, MixerWithNames, Hash> database_mixers;
for (const auto & [policy_id, info] : all_policies)
for (const auto & [policy_id, info] : database_policies)
{
const auto & policy = *info.policy;
bool match = info.roles->match(enabled.params.user_id, enabled.params.enabled_roles);
MixedFiltersKey key;
key.database = info.database_and_table_name->first;
key.table_name = info.database_and_table_name->second;
for (auto filter_type : collections::range(0, RowPolicyFilterType::MAX))
{
auto filter_type_i = static_cast<size_t>(filter_type);
if (info.parsed_filters[filter_type_i])
{
key.filter_type = filter_type;
auto & mixer = mixers[key];
MixedFiltersKey key{info.database_and_table_name->first,
info.database_and_table_name->second,
filter_type};
LOG_TRACE((&Poco::Logger::get("mixFiltersFor")), "db: {} : {}", key.database, key.table_name);
auto & mixer = database_mixers[key]; // getting database level mixer
mixer.database_and_table_name = info.database_and_table_name;
if (match)
{
mixer.mixer.add(info.parsed_filters[filter_type_i], policy.isRestrictive());
mixer.policies.push_back(info.policy);
}
}
}
}
for (const auto & [policy_id, info] : table_policies)
{
const auto & policy = *info.policy;
bool match = info.roles->match(enabled.params.user_id, enabled.params.enabled_roles);
for (auto filter_type : collections::range(0, RowPolicyFilterType::MAX))
{
auto filter_type_i = static_cast<size_t>(filter_type);
if (info.parsed_filters[filter_type_i])
{
MixedFiltersKey key{info.database_and_table_name->first,
info.database_and_table_name->second,
filter_type};
LOG_TRACE((&Poco::Logger::get("mixFiltersFor")), "table: {} : {}", key.database, key.table_name);
auto table_it = table_mixers.find(key);
if (table_it == table_mixers.end())
{
LOG_TRACE((&Poco::Logger::get("mixFiltersFor")), "table: not found, looking for db");
MixedFiltersKey database_key = key;
database_key.table_name = "*";
auto database_it = database_mixers.find(database_key);
if (database_it == database_mixers.end())
{
LOG_TRACE((&Poco::Logger::get("mixFiltersFor")), "table: not found, database not found");
table_it = table_mixers.try_emplace(key).first;
}
else
{
LOG_TRACE((&Poco::Logger::get("mixFiltersFor")), "table: not found, database found");
table_it = table_mixers.insert({key, database_it->second}).first;
}
}
auto & mixer = table_it->second; // table_mixers[key]; getting table level mixer
mixer.database_and_table_name = info.database_and_table_name;
if (match)
{
@ -242,15 +319,20 @@ void RowPolicyCache::mixFiltersFor(EnabledRowPolicies & enabled)
}
auto mixed_filters = boost::make_shared<MixedFiltersMap>();
for (auto & [key, mixer] : mixers)
for (auto mixer_map_ptr : { &table_mixers, &database_mixers})
{
auto mixed_filter = std::make_shared<RowPolicyFilter>();
mixed_filter->database_and_table_name = std::move(mixer.database_and_table_name);
mixed_filter->expression = std::move(mixer.mixer).getResult(access_control.isEnabledUsersWithoutRowPoliciesCanReadRows());
mixed_filter->policies = std::move(mixer.policies);
mixed_filters->emplace(key, std::move(mixed_filter));
for (auto & [key, mixer] : *mixer_map_ptr)
{
auto mixed_filter = std::make_shared<RowPolicyFilter>();
mixed_filter->database_and_table_name = std::move(mixer.database_and_table_name);
mixed_filter->expression = std::move(mixer.mixer).getResult(access_control.isEnabledUsersWithoutRowPoliciesCanReadRows());
mixed_filter->policies = std::move(mixer.policies);
mixed_filters->emplace(key, std::move(mixed_filter));
}
}
enabled.mixed_filters.store(mixed_filters);
}

View File

@ -12,6 +12,7 @@ namespace DB
class AccessControl;
struct RolesOrUsersSet;
struct RowPolicy;
using RowPolicyPtr = std::shared_ptr<const RowPolicy>;
/// Stores read and parsed row policies.
@ -35,14 +36,18 @@ private:
ASTPtr parsed_filters[static_cast<size_t>(RowPolicyFilterType::MAX)];
};
using PolicyMap = std::unordered_map<UUID, PolicyInfo>;
void ensureAllRowPoliciesRead();
void rowPolicyAddedOrChanged(const UUID & policy_id, const RowPolicyPtr & new_policy);
void rowPolicyRemoved(const UUID & policy_id);
void mixFilters();
void mixFiltersFor(EnabledRowPolicies & enabled);
const AccessControl & access_control;
std::unordered_map<UUID, PolicyInfo> all_policies;
PolicyMap database_policies;
PolicyMap table_policies;
bool all_policies_read = false;
scope_guard subscription;
std::map<EnabledRowPolicies::Params, std::weak_ptr<EnabledRowPolicies>> enabled_row_policies;

View File

@ -203,6 +203,7 @@ namespace
bool ParserCreateRowPolicyQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
{
// poco_assert("ParserCreateRowPolicyQuery::parseImpl" == nullptr);
bool alter = false;
if (attach_mode)
{

View File

@ -7,6 +7,8 @@
#include <Parsers/parseDatabaseAndTableName.h>
#include <base/insertAtEnd.h>
#include <Common/logger_useful.h>
namespace DB
{
@ -26,8 +28,19 @@ namespace
return IParserBase::wrapParseImpl(pos, [&]
{
String res_database, res_table_name;
if (!parseDatabaseAndTableName(pos, expected, res_database, res_table_name))
// if (!parseDatabaseAndTableName(pos, expected, res_database, res_table_name))
bool any_database = false;
bool any_table = true;
if (!parseDatabaseAndTableNameOrAsterisks(pos, expected, res_database, any_database, res_table_name, any_table))
{
// poco_assert("parseDatabaseAndTableName failed" == nullptr);
LOG_TRACE((&Poco::Logger::get("ParserRowPolicyName")), "parseDatabaseAndTableName failed");
return false;
}
if (any_table)
res_table_name = "*";
/// If table is specified without DB it cannot be followed by "ON"
/// (but can be followed by "ON CLUSTER").
@ -51,8 +64,10 @@ namespace
}
bool parseOnDBAndTableName(IParser::Pos & pos, Expected & expected, String & database, String & table_name)
{
// poco_assert("parseOnDBAndTableNames" == nullptr);
return IParserBase::wrapParseImpl(pos, [&]
{
return ParserKeyword{"ON"}.ignore(pos, expected) && parseDBAndTableName(pos, expected, database, table_name);
@ -62,6 +77,9 @@ namespace
bool parseOnDBAndTableNames(IParser::Pos & pos, Expected & expected, std::vector<std::pair<String, String>> & database_and_table_names)
{
// poco_assert("parseOnDBAndTableNames" == nullptr);
return IParserBase::wrapParseImpl(pos, [&]
{
if (!ParserKeyword{"ON"}.ignore(pos, expected))
@ -146,6 +164,7 @@ namespace
bool ParserRowPolicyName::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
{
// poco_assert("ParserRowPolicyName::parseImpl" == nullptr);
std::vector<RowPolicyName> full_names;
String cluster;
if (!parseRowPolicyNamesAroundON(pos, expected, false, false, allow_on_cluster, full_names, cluster))
@ -162,6 +181,7 @@ bool ParserRowPolicyName::parseImpl(Pos & pos, ASTPtr & node, Expected & expecte
bool ParserRowPolicyNames::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
{
// poco_assert("ParserRowPolicyName::parseImpl" == nullptr);
std::vector<RowPolicyName> full_names;
size_t num_added_names_last_time = 0;
String cluster;

View File

@ -139,6 +139,8 @@ void writeCommonErrorMessage(
if (!query_description.empty())
out << " (" << query_description << ")";
// poco_assert("writeCommonErrorMessage" == nullptr);
out << ": failed at position " << (last_token.begin - begin + 1);
if (last_token.type == TokenType::EndOfStream || last_token.type == TokenType::Semicolon)