Reworked changes to std::shared_ptr<const RowPolicyFilter>.

This commit is contained in:
Vladimir Chebotaryov 2022-10-24 10:58:14 +03:00
parent 6d5d31e49c
commit d17b7387f9
11 changed files with 69 additions and 55 deletions

View File

@ -379,7 +379,7 @@ std::shared_ptr<const EnabledRowPolicies> ContextAccess::getEnabledRowPolicies()
return no_row_policies;
}
RowPolicyFilter ContextAccess::getRowPolicyFilter(const String & database, const String & table_name, RowPolicyFilterType filter_type, const RowPolicyFilter & combine_with_filter) const
RowPolicyFilterPtr ContextAccess::getRowPolicyFilter(const String & database, const String & table_name, RowPolicyFilterType filter_type, RowPolicyFilterPtr combine_with_filter) const
{
std::lock_guard lock{mutex};
if (enabled_row_policies)

View File

@ -87,7 +87,7 @@ public:
/// Returns the row policy filter for a specified table.
/// The function returns nullptr if there is no filter to apply.
RowPolicyFilter getRowPolicyFilter(const String & database, const String & table_name, RowPolicyFilterType filter_type, const RowPolicyFilter & combine_with_filter = {}) const;
RowPolicyFilterPtr getRowPolicyFilter(const String & database, const String & table_name, RowPolicyFilterType filter_type, RowPolicyFilterPtr combine_with_filter = {}) const;
/// Returns the quota to track resource consumption.
std::shared_ptr<const EnabledQuota> getQuota() const;

View File

@ -7,11 +7,10 @@
namespace DB
{
void RowPolicyFilter::optimize()
bool RowPolicyFilter::empty() const
{
bool value;
if (tryGetLiteralBool(expression.get(), value) && value)
expression.reset(); /// The condition is always true, no need to check it.
return !expression || (tryGetLiteralBool(expression.get(), value) && value);
}
size_t EnabledRowPolicies::Hash::operator()(const MixedFiltersKey & key) const
@ -30,7 +29,7 @@ EnabledRowPolicies::EnabledRowPolicies(const Params & params_) : params(params_)
EnabledRowPolicies::~EnabledRowPolicies() = default;
RowPolicyFilter EnabledRowPolicies::getFilter(const String & database, const String & table_name, RowPolicyFilterType filter_type) const
RowPolicyFilterPtr EnabledRowPolicies::getFilter(const String & database, const String & table_name, RowPolicyFilterType filter_type) const
{
/// We don't lock `mutex` here.
auto loaded = mixed_filters.load();
@ -38,26 +37,36 @@ RowPolicyFilter EnabledRowPolicies::getFilter(const String & database, const Str
if (it == loaded->end())
return {};
RowPolicyFilter filter = {it->second.ast, it->second.policies};
filter.optimize();
return filter;
return it->second;
}
RowPolicyFilter EnabledRowPolicies::getFilter(const String & database, const String & table_name, RowPolicyFilterType filter_type, const RowPolicyFilter & combine_with_filter) const
RowPolicyFilterPtr EnabledRowPolicies::getFilter(const String & database, const String & table_name, RowPolicyFilterType filter_type, RowPolicyFilterPtr combine_with_filter) const
{
RowPolicyFilter filter = getFilter(database, table_name, filter_type);
if (filter.expression && combine_with_filter.expression)
RowPolicyFilterPtr filter = getFilter(database, table_name, filter_type);
if (filter && combine_with_filter)
{
filter.expression = makeASTForLogicalAnd({filter.expression, combine_with_filter.expression});
}
else if (!filter.expression)
{
filter.expression = combine_with_filter.expression;
}
auto new_filter = std::make_shared<RowPolicyFilter>(*filter);
std::copy(combine_with_filter.policies.begin(), combine_with_filter.policies.end(), std::back_inserter(filter.policies));
filter.optimize();
if (filter->empty())
{
new_filter->expression = combine_with_filter->expression;
}
else if (combine_with_filter->empty())
{
new_filter->expression = filter->expression;
}
else
{
new_filter->expression = makeASTForLogicalAnd({filter->expression, combine_with_filter->expression});
}
std::copy(combine_with_filter->policies.begin(), combine_with_filter->policies.end(), std::back_inserter(new_filter->policies));
filter = new_filter;
}
else if (!filter)
{
filter = combine_with_filter;
}
return filter;
}

View File

@ -18,13 +18,17 @@ namespace DB
class IAST;
using ASTPtr = std::shared_ptr<IAST>;
struct RowPolicyFilter;
using RowPolicyFilterPtr = std::shared_ptr<const RowPolicyFilter>;
struct RowPolicyFilter
{
ASTPtr expression;
std::shared_ptr<const std::pair<String, String>> database_and_table_name;
std::vector<RowPolicyPtr> policies;
void optimize();
bool empty() const;
};
@ -52,8 +56,8 @@ public:
/// Returns prepared filter for a specific table and operations.
/// The function can return nullptr, that means there is no filters applied.
/// The returned filter can be a combination of the filters defined by multiple row policies.
RowPolicyFilter getFilter(const String & database, const String & table_name, RowPolicyFilterType filter_type) const;
RowPolicyFilter getFilter(const String & database, const String & table_name, RowPolicyFilterType filter_type, const RowPolicyFilter & combine_with_filter) const;
RowPolicyFilterPtr getFilter(const String & database, const String & table_name, RowPolicyFilterType filter_type) const;
RowPolicyFilterPtr getFilter(const String & database, const String & table_name, RowPolicyFilterType filter_type, RowPolicyFilterPtr combine_with_filter) const;
private:
friend class RowPolicyCache;
@ -70,19 +74,12 @@ private:
friend bool operator!=(const MixedFiltersKey & left, const MixedFiltersKey & right) { return left.toTuple() != right.toTuple(); }
};
struct MixedFiltersResult
{
ASTPtr ast;
std::shared_ptr<const std::pair<String, String>> database_and_table_name;
std::vector<RowPolicyPtr> policies;
};
struct Hash
{
size_t operator()(const MixedFiltersKey & key) const;
};
using MixedFiltersMap = std::unordered_map<MixedFiltersKey, MixedFiltersResult, Hash>;
using MixedFiltersMap = std::unordered_map<MixedFiltersKey, RowPolicyFilterPtr, Hash>;
const Params params;
mutable boost::atomic_shared_ptr<const MixedFiltersMap> mixed_filters;

View File

@ -244,10 +244,11 @@ void RowPolicyCache::mixFiltersFor(EnabledRowPolicies & enabled)
auto mixed_filters = boost::make_shared<MixedFiltersMap>();
for (auto & [key, mixer] : mixers)
{
auto & mixed_filter = (*mixed_filters)[key];
mixed_filter.database_and_table_name = std::move(mixer.database_and_table_name);
mixed_filter.ast = std::move(mixer.mixer).getResult(access_control.isEnabledUsersWithoutRowPoliciesCanReadRows());
mixed_filter.policies = std::move(mixer.policies);
auto mixed_filter = std::make_shared<RowPolicyFilter>();
mixed_filter->database_and_table_name = std::move(mixer.database_and_table_name);
mixed_filter->expression = std::move(mixer.mixer).getResult(access_control.isEnabledUsersWithoutRowPoliciesCanReadRows());
mixed_filter->policies = std::move(mixer.policies);
mixed_filters->emplace(key, std::move(mixed_filter));
}
enabled.mixed_filters.store(mixed_filters);

View File

@ -987,10 +987,10 @@ std::shared_ptr<const ContextAccess> Context::getAccess() const
return access ? access : ContextAccess::getFullAccess();
}
RowPolicyFilter Context::getRowPolicyFilter(const String & database, const String & table_name, RowPolicyFilterType filter_type) const
RowPolicyFilterPtr Context::getRowPolicyFilter(const String & database, const String & table_name, RowPolicyFilterType filter_type) const
{
auto lock = getLock();
RowPolicyFilter row_filter_of_initial_user;
RowPolicyFilterPtr row_filter_of_initial_user;
if (row_policies_of_initial_user)
row_filter_of_initial_user = row_policies_of_initial_user->getFilter(database, table_name, filter_type);
return getAccess()->getRowPolicyFilter(database, table_name, filter_type, row_filter_of_initial_user);

View File

@ -46,6 +46,7 @@ using UserPtr = std::shared_ptr<const User>;
struct EnabledRolesInfo;
class EnabledRowPolicies;
struct RowPolicyFilter;
using RowPolicyFilterPtr = std::shared_ptr<const RowPolicyFilter>;
class EnabledQuota;
struct QuotaUsage;
class AccessFlags;
@ -517,7 +518,7 @@ public:
std::shared_ptr<const ContextAccess> getAccess() const;
RowPolicyFilter getRowPolicyFilter(const String & database, const String & table_name, RowPolicyFilterType filter_type) const;
RowPolicyFilterPtr getRowPolicyFilter(const String & database, const String & table_name, RowPolicyFilterType filter_type) const;
/// Finds and sets extra row policies to be used based on `client_info.initial_user`,
/// if the initial user exists.

View File

@ -199,11 +199,14 @@ void InterpreterSelectIntersectExceptQuery::extendQueryLogElemImpl(QueryLogEleme
{
if (auto select_interpreter = dynamic_cast<InterpreterSelectQuery *>(interpreter.get()))
{
auto policies = select_interpreter->getUsedRowPolicies();
for (const auto & row_policy : policies)
auto filter = select_interpreter->getRowPolicyFilter();
if (filter)
{
auto name = row_policy->getFullName().toString();
elem.used_row_policies.emplace(std::move(name));
for (const auto & row_policy : filter->policies)
{
auto name = row_policy->getFullName().toString();
elem.used_row_policies.emplace(std::move(name));
}
}
}
}

View File

@ -616,13 +616,13 @@ InterpreterSelectQuery::InterpreterSelectQuery(
query_info.filter_asts.clear();
/// Fix source_header for filter actions.
if (row_policy_filter.expression)
if (row_policy_filter && !row_policy_filter->empty())
{
filter_info = generateFilterActions(
table_id, row_policy_filter.expression, context, storage, storage_snapshot, metadata_snapshot, required_columns,
table_id, row_policy_filter->expression, context, storage, storage_snapshot, metadata_snapshot, required_columns,
prepared_sets);
query_info.filter_asts.push_back(row_policy_filter.expression);
query_info.filter_asts.push_back(row_policy_filter->expression);
}
if (query_info.additional_filter_ast)
@ -1869,16 +1869,16 @@ void InterpreterSelectQuery::setProperClientInfo(size_t replica_num, size_t repl
context->getClientInfo().number_of_current_replica = replica_num;
}
const std::vector<RowPolicyPtr> & InterpreterSelectQuery::getUsedRowPolicies() const
RowPolicyFilterPtr InterpreterSelectQuery::getRowPolicyFilter() const
{
return row_policy_filter.policies;
return row_policy_filter;
}
void InterpreterSelectQuery::extendQueryLogElemImpl(QueryLogElement & elem, const ASTPtr & /*ast*/, ContextPtr /*context_*/) const
{
elem.query_kind = "Select";
for (const auto & row_policy : row_policy_filter.policies)
for (const auto & row_policy : row_policy_filter->policies)
{
auto name = row_policy->getFullName().toString();
elem.used_row_policies.emplace(std::move(name));

View File

@ -134,7 +134,7 @@ public:
FilterDAGInfoPtr getAdditionalQueryInfo() const { return additional_filter_info; }
const std::vector<RowPolicyPtr> & getUsedRowPolicies() const;
RowPolicyFilterPtr getRowPolicyFilter() const;
void extendQueryLogElemImpl(QueryLogElement & elem, const ASTPtr & ast, ContextPtr context) const override;
@ -218,7 +218,7 @@ private:
/// Is calculated in getSampleBlock. Is used later in readImpl.
ExpressionAnalysisResult analysis_result;
/// For row-level security.
RowPolicyFilter row_policy_filter;
RowPolicyFilterPtr row_policy_filter;
FilterDAGInfoPtr filter_info;
/// For additional_filter setting.

View File

@ -394,11 +394,14 @@ void InterpreterSelectWithUnionQuery::extendQueryLogElemImpl(QueryLogElement & e
{
if (auto select_interpreter = dynamic_cast<InterpreterSelectQuery *>(interpreter.get()))
{
auto policies = select_interpreter->getUsedRowPolicies();
for (const auto & row_policy : policies)
auto filter = select_interpreter->getRowPolicyFilter();
if (filter)
{
auto name = row_policy->getFullName().toString();
elem.used_row_policies.emplace(std::move(name));
for (const auto & row_policy : filter->policies)
{
auto name = row_policy->getFullName().toString();
elem.used_row_policies.emplace(std::move(name));
}
}
}
}