Improve system table for row policies. Remove function currentRowPolicies().

This commit is contained in:
Vitaly Baranov 2020-05-07 05:45:27 +03:00
parent 9a89a04c1f
commit e64e2ebdf6
33 changed files with 681 additions and 836 deletions

View File

@ -485,7 +485,7 @@ std::shared_ptr<const EnabledRolesInfo> ContextAccess::getRolesInfo() const
return roles_info; return roles_info;
} }
std::shared_ptr<const EnabledRowPolicies> ContextAccess::getRowPolicies() const std::shared_ptr<const EnabledRowPolicies> ContextAccess::getEnabledRowPolicies() const
{ {
std::lock_guard lock{mutex}; std::lock_guard lock{mutex};
return enabled_row_policies; return enabled_row_policies;

View File

@ -70,7 +70,7 @@ public:
/// Returns information about enabled row policies. /// Returns information about enabled row policies.
/// The function can return nullptr. /// The function can return nullptr.
std::shared_ptr<const EnabledRowPolicies> getRowPolicies() const; std::shared_ptr<const EnabledRowPolicies> getEnabledRowPolicies() const;
/// Returns the row policy filter for a specified table. /// Returns the row policy filter for a specified table.
/// The function returns nullptr if there is no filter to apply. /// The function returns nullptr if there is no filter to apply.

View File

@ -6,9 +6,9 @@
namespace DB namespace DB
{ {
size_t EnabledRowPolicies::Hash::operator()(const DatabaseAndTableNameRef & database_and_table_name) const size_t EnabledRowPolicies::Hash::operator()(const MixedConditionKey & key) const
{ {
return std::hash<std::string_view>{}(database_and_table_name.first) - std::hash<std::string_view>{}(database_and_table_name.second); return std::hash<std::string_view>{}(key.database) - std::hash<std::string_view>{}(key.table_name) + static_cast<size_t>(key.condition_type);
} }
@ -20,16 +20,22 @@ EnabledRowPolicies::EnabledRowPolicies(const Params & params_)
EnabledRowPolicies::~EnabledRowPolicies() = default; EnabledRowPolicies::~EnabledRowPolicies() = default;
ASTPtr EnabledRowPolicies::getCondition(const String & database, const String & table_name, ConditionType type) const ASTPtr EnabledRowPolicies::getCondition(const String & database, const String & table_name, ConditionType condition_type) const
{ {
/// We don't lock `mutex` here. /// We don't lock `mutex` here.
auto loaded = map_of_mixed_conditions.load(); auto loaded = map_of_mixed_conditions.load();
auto it = loaded->find({database, table_name}); auto it = loaded->find({database, table_name, condition_type});
if (it == loaded->end()) if (it == loaded->end())
return {}; return {};
return it->second.mixed_conditions[type];
}
auto condition = it->second.ast;
bool value;
if (tryGetLiteralBool(condition.get(), value) && value)
return nullptr; /// The condition is always true, no need to check it.
return condition;
}
ASTPtr EnabledRowPolicies::getCondition(const String & database, const String & table_name, ConditionType type, const ASTPtr & extra_condition) const ASTPtr EnabledRowPolicies::getCondition(const String & database, const String & table_name, ConditionType type, const ASTPtr & extra_condition) const
{ {
@ -41,31 +47,9 @@ ASTPtr EnabledRowPolicies::getCondition(const String & database, const String &
bool value; bool value;
if (tryGetLiteralBool(condition.get(), value) && value) if (tryGetLiteralBool(condition.get(), value) && value)
condition = nullptr; /// The condition is always true, no need to check it. return nullptr; /// The condition is always true, no need to check it.
return condition; return condition;
} }
std::vector<UUID> EnabledRowPolicies::getCurrentPolicyIDs() const
{
/// We don't lock `mutex` here.
auto loaded = map_of_mixed_conditions.load();
std::vector<UUID> policy_ids;
for (const auto & mixed_conditions : *loaded | boost::adaptors::map_values)
boost::range::copy(mixed_conditions.policy_ids, std::back_inserter(policy_ids));
return policy_ids;
}
std::vector<UUID> EnabledRowPolicies::getCurrentPolicyIDs(const String & database, const String & table_name) const
{
/// We don't lock `mutex` here.
auto loaded = map_of_mixed_conditions.load();
auto it = loaded->find({database, table_name});
if (it == loaded->end())
return {};
return it->second.policy_ids;
}
} }

View File

@ -4,8 +4,8 @@
#include <Core/Types.h> #include <Core/Types.h>
#include <Core/UUID.h> #include <Core/UUID.h>
#include <boost/smart_ptr/atomic_shared_ptr.hpp> #include <boost/smart_ptr/atomic_shared_ptr.hpp>
#include <memory>
#include <unordered_map> #include <unordered_map>
#include <memory>
namespace DB namespace DB
@ -42,30 +42,32 @@ public:
ASTPtr getCondition(const String & database, const String & table_name, ConditionType type) const; ASTPtr getCondition(const String & database, const String & table_name, ConditionType type) const;
ASTPtr getCondition(const String & database, const String & table_name, ConditionType type, const ASTPtr & extra_condition) const; ASTPtr getCondition(const String & database, const String & table_name, ConditionType type, const ASTPtr & extra_condition) const;
/// Returns IDs of all the policies used by the current user.
std::vector<UUID> getCurrentPolicyIDs() const;
/// Returns IDs of the policies used by a concrete table.
std::vector<UUID> getCurrentPolicyIDs(const String & database, const String & table_name) const;
private: private:
friend class RowPolicyCache; friend class RowPolicyCache;
EnabledRowPolicies(const Params & params_); EnabledRowPolicies(const Params & params_);
using DatabaseAndTableName = std::pair<String, String>; struct MixedConditionKey
using DatabaseAndTableNameRef = std::pair<std::string_view, std::string_view>; {
std::string_view database;
std::string_view table_name;
ConditionType condition_type;
auto toTuple() const { return std::tie(database, table_name, condition_type); }
friend bool operator==(const MixedConditionKey & left, const MixedConditionKey & right) { return left.toTuple() == right.toTuple(); }
friend bool operator!=(const MixedConditionKey & left, const MixedConditionKey & right) { return left.toTuple() != right.toTuple(); }
};
struct Hash struct Hash
{ {
size_t operator()(const DatabaseAndTableNameRef & database_and_table_name) const; size_t operator()(const MixedConditionKey & key) const;
}; };
using ParsedConditions = std::array<ASTPtr, RowPolicy::MAX_CONDITION_TYPE>;
struct MixedConditions struct MixedCondition
{ {
std::unique_ptr<DatabaseAndTableName> database_and_table_name_keeper; ASTPtr ast;
ParsedConditions mixed_conditions; std::shared_ptr<const std::pair<String, String>> database_and_table_name;
std::vector<UUID> policy_ids;
}; };
using MapOfMixedConditions = std::unordered_map<DatabaseAndTableNameRef, MixedConditions, Hash>; using MapOfMixedConditions = std::unordered_map<MixedConditionKey, MixedCondition, Hash>;
const Params params; const Params params;
mutable boost::atomic_shared_ptr<const MapOfMixedConditions> map_of_mixed_conditions; mutable boost::atomic_shared_ptr<const MapOfMixedConditions> map_of_mixed_conditions;

View File

@ -8,7 +8,6 @@ namespace DB
namespace ErrorCodes namespace ErrorCodes
{ {
extern const int NOT_IMPLEMENTED; extern const int NOT_IMPLEMENTED;
extern const int LOGICAL_ERROR;
} }
@ -75,34 +74,4 @@ bool RowPolicy::equal(const IAccessEntity & other) const
&& restrictive == other_policy.restrictive && (to_roles == other_policy.to_roles); && restrictive == other_policy.restrictive && (to_roles == other_policy.to_roles);
} }
const char * RowPolicy::conditionTypeToString(ConditionType index)
{
switch (index)
{
case SELECT_FILTER: return "SELECT_FILTER";
case INSERT_CHECK: return "INSERT_CHECK";
case UPDATE_FILTER: return "UPDATE_FILTER";
case UPDATE_CHECK: return "UPDATE_CHECK";
case DELETE_FILTER: return "DELETE_FILTER";
case MAX_CONDITION_TYPE: break;
}
throw Exception("Unexpected condition type: " + std::to_string(static_cast<int>(index)), ErrorCodes::LOGICAL_ERROR);
}
const char * RowPolicy::conditionTypeToColumnName(ConditionType index)
{
switch (index)
{
case SELECT_FILTER: return "select_filter";
case INSERT_CHECK: return "insert_check";
case UPDATE_FILTER: return "update_filter";
case UPDATE_CHECK: return "update_check";
case DELETE_FILTER: return "delete_filter";
case MAX_CONDITION_TYPE: break;
}
throw Exception("Unexpected condition type: " + std::to_string(static_cast<int>(index)), ErrorCodes::LOGICAL_ERROR);
}
} }

View File

@ -2,10 +2,16 @@
#include <Access/IAccessEntity.h> #include <Access/IAccessEntity.h>
#include <Access/ExtendedRoleSet.h> #include <Access/ExtendedRoleSet.h>
#include <array>
namespace DB namespace DB
{ {
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
}
/** Represents a row level security policy for a table. /** Represents a row level security policy for a table.
*/ */
@ -43,17 +49,27 @@ struct RowPolicy : public IAccessEntity
enum ConditionType enum ConditionType
{ {
SELECT_FILTER, SELECT_FILTER,
#if 0 /// Row-level security for INSERT, UPDATE, DELETE is not implemented yet.
INSERT_CHECK, INSERT_CHECK,
UPDATE_FILTER, UPDATE_FILTER,
UPDATE_CHECK, UPDATE_CHECK,
DELETE_FILTER, DELETE_FILTER,
#endif
MAX_CONDITION_TYPE MAX_CONDITION_TYPE
}; };
static const char * conditionTypeToString(ConditionType index);
static const char * conditionTypeToColumnName(ConditionType index);
String conditions[MAX_CONDITION_TYPE]; struct ConditionTypeInfo
{
const char * const raw_name;
const String name; /// Lowercased with underscores, e.g. "select_filter".
const String command; /// Uppercased without last word, e.g. "SELECT".
const bool is_check; /// E.g. false for SELECT_FILTER.
static const ConditionTypeInfo & get(ConditionType type);
};
std::array<String, MAX_CONDITION_TYPE> conditions;
/// Sets that the policy is permissive. /// Sets that the policy is permissive.
/// A row is only accessible if at least one of the permissive policies passes, /// A row is only accessible if at least one of the permissive policies passes,
@ -83,4 +99,58 @@ private:
}; };
using RowPolicyPtr = std::shared_ptr<const RowPolicy>; using RowPolicyPtr = std::shared_ptr<const RowPolicy>;
inline const RowPolicy::ConditionTypeInfo & RowPolicy::ConditionTypeInfo::get(ConditionType type_)
{
static constexpr auto make_info = [](const char * raw_name_)
{
String init_name = raw_name_;
boost::to_lower(init_name);
size_t underscore_pos = init_name.find('_');
String init_command = init_name.substr(0, underscore_pos);
boost::to_upper(init_command);
bool init_is_check = (std::string_view{init_name}.substr(underscore_pos + 1) == "check");
return ConditionTypeInfo{raw_name_, std::move(init_name), std::move(init_command), init_is_check};
};
switch (type_)
{
case SELECT_FILTER:
{
static const ConditionTypeInfo info = make_info("SELECT_FILTER");
return info;
}
#if 0 /// Row-level security for INSERT, UPDATE, DELETE is not implemented yet.
case INSERT_CHECK:
{
static const ConditionTypeInfo info = make_info("INSERT_CHECK");
return info;
}
case UPDATE_FILTER:
{
static const ConditionTypeInfo info = make_info("UPDATE_FILTER");
return info;
}
case UPDATE_CHECK:
{
static const ConditionTypeInfo info = make_info("UPDATE_CHECK");
return info;
}
case DELETE_FILTER:
{
static const ConditionTypeInfo info = make_info("DELETE_FILTER");
return info;
}
#endif
case MAX_CONDITION_TYPE: break;
}
throw Exception("Unknown type: " + std::to_string(static_cast<size_t>(type_)), ErrorCodes::LOGICAL_ERROR);
}
inline String toString(RowPolicy::ConditionType type)
{
return RowPolicy::ConditionTypeInfo::get(type).raw_name;
}
} }

View File

@ -57,6 +57,7 @@ void RowPolicyCache::PolicyInfo::setPolicy(const RowPolicyPtr & policy_)
{ {
policy = policy_; policy = policy_;
roles = &policy->to_roles; roles = &policy->to_roles;
database_and_table_name = std::make_shared<std::pair<String, String>>(policy->getDatabase(), policy->getTableName());
for (auto type : ext::range(0, MAX_CONDITION_TYPE)) for (auto type : ext::range(0, MAX_CONDITION_TYPE))
{ {
@ -84,7 +85,7 @@ void RowPolicyCache::PolicyInfo::setPolicy(const RowPolicyPtr & policy_)
{ {
tryLogCurrentException( tryLogCurrentException(
&Poco::Logger::get("RowPolicy"), &Poco::Logger::get("RowPolicy"),
String("Could not parse the condition ") + RowPolicy::conditionTypeToString(type) + " of row policy " String("Could not parse the condition ") + toString(type) + " of row policy "
+ backQuote(policy->getName())); + backQuote(policy->getName()));
} }
} }
@ -196,43 +197,45 @@ void RowPolicyCache::mixConditions()
void RowPolicyCache::mixConditionsFor(EnabledRowPolicies & enabled) void RowPolicyCache::mixConditionsFor(EnabledRowPolicies & enabled)
{ {
/// `mutex` is already locked. /// `mutex` is already locked.
struct Mixers
{
ConditionsMixer mixers[MAX_CONDITION_TYPE];
std::vector<UUID> policy_ids;
};
using MapOfMixedConditions = EnabledRowPolicies::MapOfMixedConditions; using MapOfMixedConditions = EnabledRowPolicies::MapOfMixedConditions;
using DatabaseAndTableName = EnabledRowPolicies::DatabaseAndTableName; using MixedConditionKey = EnabledRowPolicies::MixedConditionKey;
using DatabaseAndTableNameRef = EnabledRowPolicies::DatabaseAndTableNameRef;
using Hash = EnabledRowPolicies::Hash; using Hash = EnabledRowPolicies::Hash;
std::unordered_map<DatabaseAndTableName, Mixers, Hash> map_of_mixers; struct MixerWithNames
{
ConditionsMixer mixer;
std::shared_ptr<const std::pair<String, String>> database_and_table_name;
};
std::unordered_map<MixedConditionKey, MixerWithNames, Hash> map_of_mixers;
for (const auto & [policy_id, info] : all_policies) for (const auto & [policy_id, info] : all_policies)
{ {
const auto & policy = *info.policy; const auto & policy = *info.policy;
auto & mixers = map_of_mixers[std::pair{policy.getDatabase(), policy.getTableName()}]; bool match = info.roles->match(enabled.params.user_id, enabled.params.enabled_roles);
if (info.roles->match(enabled.params.user_id, enabled.params.enabled_roles)) MixedConditionKey key;
key.database = info.database_and_table_name->first;
key.table_name = info.database_and_table_name->second;
for (auto type : ext::range(0, MAX_CONDITION_TYPE))
{ {
mixers.policy_ids.push_back(policy_id); if (info.parsed_conditions[type])
for (auto type : ext::range(0, MAX_CONDITION_TYPE)) {
if (info.parsed_conditions[type]) key.condition_type = type;
mixers.mixers[type].add(info.parsed_conditions[type], policy.isRestrictive()); auto & mixer = map_of_mixers[key];
mixer.database_and_table_name = info.database_and_table_name;
if (match)
mixer.mixer.add(info.parsed_conditions[type], policy.isRestrictive());
}
} }
} }
auto map_of_mixed_conditions = boost::make_shared<MapOfMixedConditions>(); auto map_of_mixed_conditions = boost::make_shared<MapOfMixedConditions>();
for (auto & [database_and_table_name, mixers] : map_of_mixers) for (auto & [key, mixer] : map_of_mixers)
{ {
auto database_and_table_name_keeper = std::make_unique<DatabaseAndTableName>(); auto & mixed_condition = (*map_of_mixed_conditions)[key];
database_and_table_name_keeper->first = database_and_table_name.first; mixed_condition.database_and_table_name = mixer.database_and_table_name;
database_and_table_name_keeper->second = database_and_table_name.second; mixed_condition.ast = std::move(mixer.mixer).getResult();
auto & mixed_conditions = (*map_of_mixed_conditions)[DatabaseAndTableNameRef{database_and_table_name_keeper->first,
database_and_table_name_keeper->second}];
mixed_conditions.database_and_table_name_keeper = std::move(database_and_table_name_keeper);
mixed_conditions.policy_ids = std::move(mixers.policy_ids);
for (auto type : ext::range(0, MAX_CONDITION_TYPE))
mixed_conditions.mixed_conditions[type] = std::move(mixers.mixers[type]).getResult();
} }
enabled.map_of_mixed_conditions.store(map_of_mixed_conditions); enabled.map_of_mixed_conditions.store(map_of_mixed_conditions);

View File

@ -21,8 +21,6 @@ public:
std::shared_ptr<const EnabledRowPolicies> getEnabledRowPolicies(const UUID & user_id, const boost::container::flat_set<UUID> & enabled_roles); std::shared_ptr<const EnabledRowPolicies> getEnabledRowPolicies(const UUID & user_id, const boost::container::flat_set<UUID> & enabled_roles);
private: private:
using ParsedConditions = EnabledRowPolicies::ParsedConditions;
struct PolicyInfo struct PolicyInfo
{ {
PolicyInfo(const RowPolicyPtr & policy_) { setPolicy(policy_); } PolicyInfo(const RowPolicyPtr & policy_) { setPolicy(policy_); }
@ -30,7 +28,8 @@ private:
RowPolicyPtr policy; RowPolicyPtr policy;
const ExtendedRoleSet * roles = nullptr; const ExtendedRoleSet * roles = nullptr;
ParsedConditions parsed_conditions; std::shared_ptr<const std::pair<String, String>> database_and_table_name;
ASTPtr parsed_conditions[RowPolicy::MAX_CONDITION_TYPE];
}; };
void ensureAllRowPoliciesRead(); void ensureAllRowPoliciesRead();

View File

@ -1,237 +0,0 @@
#include <Functions/IFunction.h>
#include <Functions/FunctionFactory.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypeUUID.h>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnConst.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnTuple.h>
#include <Interpreters/Context.h>
#include <Access/EnabledRowPolicies.h>
#include <Access/AccessControlManager.h>
#include <ext/range.h>
#include <IO/WriteHelpers.h>
namespace DB
{
namespace ErrorCodes
{
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
/// The currentRowPolicies() function can be called with 0..2 arguments:
/// currentRowPolicies() returns array of tuples (database, table_name, row_policy_name) for all the row policies applied for the current user;
/// currentRowPolicies(table_name) is equivalent to currentRowPolicies(currentDatabase(), table_name);
/// currentRowPolicies(database, table_name) returns array of names of the row policies applied to a specific table and for the current user.
class FunctionCurrentRowPolicies : public IFunction
{
public:
static constexpr auto name = "currentRowPolicies";
static FunctionPtr create(const Context & context_) { return std::make_shared<FunctionCurrentRowPolicies>(context_); }
explicit FunctionCurrentRowPolicies(const Context & context_) : context(context_) {}
String getName() const override { return name; }
size_t getNumberOfArguments() const override { return 0; }
bool isVariadic() const override { return true; }
void checkNumberOfArgumentsIfVariadic(size_t number_of_arguments) const override
{
if (number_of_arguments > 2)
throw Exception("Number of arguments for function " + String(name) + " doesn't match: passed "
+ toString(number_of_arguments) + ", should be 0..2",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
}
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (arguments.empty())
return std::make_shared<DataTypeArray>(std::make_shared<DataTypeTuple>(
DataTypes{std::make_shared<DataTypeString>(), std::make_shared<DataTypeString>(), std::make_shared<DataTypeString>()}));
else
return std::make_shared<DataTypeArray>(std::make_shared<DataTypeString>());
}
bool isDeterministic() const override { return false; }
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result_pos, size_t input_rows_count) override
{
if (arguments.empty())
{
auto database_column = ColumnString::create();
auto table_name_column = ColumnString::create();
auto policy_name_column = ColumnString::create();
if (auto policies = context.getRowPolicies())
{
for (const auto & policy_id : policies->getCurrentPolicyIDs())
{
const auto policy = context.getAccessControlManager().tryRead<RowPolicy>(policy_id);
if (policy)
{
const String database = policy->getDatabase();
const String table_name = policy->getTableName();
const String policy_name = policy->getShortName();
database_column->insertData(database.data(), database.length());
table_name_column->insertData(table_name.data(), table_name.length());
policy_name_column->insertData(policy_name.data(), policy_name.length());
}
}
}
auto offset_column = ColumnArray::ColumnOffsets::create();
offset_column->insertValue(policy_name_column->size());
block.getByPosition(result_pos).column = ColumnConst::create(
ColumnArray::create(
ColumnTuple::create(Columns{std::move(database_column), std::move(table_name_column), std::move(policy_name_column)}),
std::move(offset_column)),
input_rows_count);
return;
}
const IColumn * database_column = nullptr;
if (arguments.size() == 2)
{
const auto & database_column_with_type = block.getByPosition(arguments[0]);
if (!isStringOrFixedString(database_column_with_type.type))
throw Exception{"The first argument of function " + String(name)
+ " should be a string containing database name, illegal type: "
+ database_column_with_type.type->getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
database_column = database_column_with_type.column.get();
}
const auto & table_name_column_with_type = block.getByPosition(arguments[arguments.size() - 1]);
if (!isStringOrFixedString(table_name_column_with_type.type))
throw Exception{"The" + String(database_column ? " last" : "") + " argument of function " + String(name)
+ " should be a string containing table name, illegal type: " + table_name_column_with_type.type->getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
const IColumn * table_name_column = table_name_column_with_type.column.get();
auto policy_name_column = ColumnString::create();
auto offset_column = ColumnArray::ColumnOffsets::create();
for (const auto i : ext::range(0, input_rows_count))
{
String database = database_column ? database_column->getDataAt(i).toString() : context.getCurrentDatabase();
String table_name = table_name_column->getDataAt(i).toString();
if (auto policies = context.getRowPolicies())
{
for (const auto & policy_id : policies->getCurrentPolicyIDs(database, table_name))
{
const auto policy = context.getAccessControlManager().tryRead<RowPolicy>(policy_id);
if (policy)
{
const String policy_name = policy->getShortName();
policy_name_column->insertData(policy_name.data(), policy_name.length());
}
}
}
offset_column->insertValue(policy_name_column->size());
}
block.getByPosition(result_pos).column = ColumnArray::create(std::move(policy_name_column), std::move(offset_column));
}
private:
const Context & context;
};
/// The currentRowPolicyIDs() function can be called with 0..2 arguments:
/// currentRowPolicyIDs() returns array of IDs of all the row policies applied for the current user;
/// currentRowPolicyIDs(table_name) is equivalent to currentRowPolicyIDs(currentDatabase(), table_name);
/// currentRowPolicyIDs(database, table_name) returns array of IDs of the row policies applied to a specific table and for the current user.
class FunctionCurrentRowPolicyIDs : public IFunction
{
public:
static constexpr auto name = "currentRowPolicyIDs";
static FunctionPtr create(const Context & context_) { return std::make_shared<FunctionCurrentRowPolicyIDs>(context_); }
explicit FunctionCurrentRowPolicyIDs(const Context & context_) : context(context_) {}
String getName() const override { return name; }
size_t getNumberOfArguments() const override { return 0; }
bool isVariadic() const override { return true; }
void checkNumberOfArgumentsIfVariadic(size_t number_of_arguments) const override
{
if (number_of_arguments > 2)
throw Exception("Number of arguments for function " + String(name) + " doesn't match: passed "
+ toString(number_of_arguments) + ", should be 0..2",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
}
DataTypePtr getReturnTypeImpl(const DataTypes & /* arguments */) const override
{
return std::make_shared<DataTypeArray>(std::make_shared<DataTypeUUID>());
}
bool isDeterministic() const override { return false; }
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result_pos, size_t input_rows_count) override
{
if (arguments.empty())
{
auto policy_id_column = ColumnVector<UInt128>::create();
if (auto policies = context.getRowPolicies())
{
for (const auto & policy_id : policies->getCurrentPolicyIDs())
policy_id_column->insertValue(policy_id);
}
auto offset_column = ColumnArray::ColumnOffsets::create();
offset_column->insertValue(policy_id_column->size());
block.getByPosition(result_pos).column
= ColumnConst::create(ColumnArray::create(std::move(policy_id_column), std::move(offset_column)), input_rows_count);
return;
}
const IColumn * database_column = nullptr;
if (arguments.size() == 2)
{
const auto & database_column_with_type = block.getByPosition(arguments[0]);
if (!isStringOrFixedString(database_column_with_type.type))
throw Exception{"The first argument of function " + String(name)
+ " should be a string containing database name, illegal type: "
+ database_column_with_type.type->getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
database_column = database_column_with_type.column.get();
}
const auto & table_name_column_with_type = block.getByPosition(arguments[arguments.size() - 1]);
if (!isStringOrFixedString(table_name_column_with_type.type))
throw Exception{"The" + String(database_column ? " last" : "") + " argument of function " + String(name)
+ " should be a string containing table name, illegal type: " + table_name_column_with_type.type->getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
const IColumn * table_name_column = table_name_column_with_type.column.get();
auto policy_id_column = ColumnVector<UInt128>::create();
auto offset_column = ColumnArray::ColumnOffsets::create();
for (const auto i : ext::range(0, input_rows_count))
{
String database = database_column ? database_column->getDataAt(i).toString() : context.getCurrentDatabase();
String table_name = table_name_column->getDataAt(i).toString();
if (auto policies = context.getRowPolicies())
{
for (const auto & policy_id : policies->getCurrentPolicyIDs(database, table_name))
policy_id_column->insertValue(policy_id);
}
offset_column->insertValue(policy_id_column->size());
}
block.getByPosition(result_pos).column = ColumnArray::create(std::move(policy_id_column), std::move(offset_column));
}
private:
const Context & context;
};
void registerFunctionCurrentRowPolicies(FunctionFactory & factory)
{
factory.registerFunction<FunctionCurrentRowPolicies>();
factory.registerFunction<FunctionCurrentRowPolicyIDs>();
}
}

View File

@ -10,7 +10,6 @@ class FunctionFactory;
void registerFunctionCurrentDatabase(FunctionFactory &); void registerFunctionCurrentDatabase(FunctionFactory &);
void registerFunctionCurrentUser(FunctionFactory &); void registerFunctionCurrentUser(FunctionFactory &);
void registerFunctionCurrentQuota(FunctionFactory &); void registerFunctionCurrentQuota(FunctionFactory &);
void registerFunctionCurrentRowPolicies(FunctionFactory &);
void registerFunctionHostName(FunctionFactory &); void registerFunctionHostName(FunctionFactory &);
void registerFunctionFQDN(FunctionFactory &); void registerFunctionFQDN(FunctionFactory &);
void registerFunctionVisibleWidth(FunctionFactory &); void registerFunctionVisibleWidth(FunctionFactory &);
@ -69,7 +68,6 @@ void registerFunctionsMiscellaneous(FunctionFactory & factory)
registerFunctionCurrentDatabase(factory); registerFunctionCurrentDatabase(factory);
registerFunctionCurrentUser(factory); registerFunctionCurrentUser(factory);
registerFunctionCurrentQuota(factory); registerFunctionCurrentQuota(factory);
registerFunctionCurrentRowPolicies(factory);
registerFunctionHostName(factory); registerFunctionHostName(factory);
registerFunctionFQDN(factory); registerFunctionFQDN(factory);
registerFunctionVisibleWidth(factory); registerFunctionVisibleWidth(factory);

View File

@ -128,7 +128,6 @@ SRCS(
CRC.cpp CRC.cpp
currentDatabase.cpp currentDatabase.cpp
currentQuota.cpp currentQuota.cpp
currentRowPolicies.cpp
currentUser.cpp currentUser.cpp
dateDiff.cpp dateDiff.cpp
defaultValueOfArgumentType.cpp defaultValueOfArgumentType.cpp

View File

@ -776,11 +776,6 @@ ASTPtr Context::getRowPolicyCondition(const String & database, const String & ta
return getAccess()->getRowPolicyCondition(database, table_name, type, initial_condition); return getAccess()->getRowPolicyCondition(database, table_name, type, initial_condition);
} }
std::shared_ptr<const EnabledRowPolicies> Context::getRowPolicies() const
{
return getAccess()->getRowPolicies();
}
void Context::setInitialRowPolicy() void Context::setInitialRowPolicy()
{ {
auto lock = getLock(); auto lock = getLock();

View File

@ -278,7 +278,6 @@ public:
std::shared_ptr<const ContextAccess> getAccess() const; std::shared_ptr<const ContextAccess> getAccess() const;
std::shared_ptr<const EnabledRowPolicies> getRowPolicies() const;
ASTPtr getRowPolicyCondition(const String & database, const String & table_name, RowPolicy::ConditionType type) const; ASTPtr getRowPolicyCondition(const String & database, const String & table_name, RowPolicy::ConditionType type) const;
/// Sets an extra row policy based on `client_info.initial_user`, if it exists. /// Sets an extra row policy based on `client_info.initial_user`, if it exists.

View File

@ -29,8 +29,12 @@ namespace
if (query.is_restrictive) if (query.is_restrictive)
policy.setRestrictive(*query.is_restrictive); policy.setRestrictive(*query.is_restrictive);
for (const auto & [index, condition] : query.conditions) for (auto condition_type : ext::range(RowPolicy::MAX_CONDITION_TYPE))
policy.conditions[index] = condition ? serializeAST(*condition) : String{}; {
const auto & condition = query.conditions[condition_type];
if (condition)
policy.conditions[condition_type] = *condition ? serializeAST(**condition) : String{};
}
const ExtendedRoleSet * roles = nullptr; const ExtendedRoleSet * roles = nullptr;
std::optional<ExtendedRoleSet> temp_role_set; std::optional<ExtendedRoleSet> temp_role_set;

View File

@ -16,11 +16,11 @@
#include <Parsers/ASTSelectWithUnionQuery.h> #include <Parsers/ASTSelectWithUnionQuery.h>
#include <Parsers/ASTSetQuery.h> #include <Parsers/ASTSetQuery.h>
#include <Parsers/ASTSetRoleQuery.h> #include <Parsers/ASTSetRoleQuery.h>
#include <Parsers/ASTShowAccessEntitiesQuery.h>
#include <Parsers/ASTShowCreateAccessEntityQuery.h> #include <Parsers/ASTShowCreateAccessEntityQuery.h>
#include <Parsers/ASTShowProcesslistQuery.h>
#include <Parsers/ASTShowGrantsQuery.h> #include <Parsers/ASTShowGrantsQuery.h>
#include <Parsers/ASTShowQuotasQuery.h> #include <Parsers/ASTShowQuotasQuery.h>
#include <Parsers/ASTShowRowPoliciesQuery.h> #include <Parsers/ASTShowProcesslistQuery.h>
#include <Parsers/ASTShowTablesQuery.h> #include <Parsers/ASTShowTablesQuery.h>
#include <Parsers/ASTUseQuery.h> #include <Parsers/ASTUseQuery.h>
#include <Parsers/ASTExplainQuery.h> #include <Parsers/ASTExplainQuery.h>
@ -50,12 +50,12 @@
#include <Interpreters/InterpreterSelectWithUnionQuery.h> #include <Interpreters/InterpreterSelectWithUnionQuery.h>
#include <Interpreters/InterpreterSetQuery.h> #include <Interpreters/InterpreterSetQuery.h>
#include <Interpreters/InterpreterSetRoleQuery.h> #include <Interpreters/InterpreterSetRoleQuery.h>
#include <Interpreters/InterpreterShowAccessEntitiesQuery.h>
#include <Interpreters/InterpreterShowCreateAccessEntityQuery.h> #include <Interpreters/InterpreterShowCreateAccessEntityQuery.h>
#include <Interpreters/InterpreterShowCreateQuery.h>
#include <Interpreters/InterpreterShowProcesslistQuery.h>
#include <Interpreters/InterpreterShowGrantsQuery.h> #include <Interpreters/InterpreterShowGrantsQuery.h>
#include <Interpreters/InterpreterShowQuotasQuery.h> #include <Interpreters/InterpreterShowQuotasQuery.h>
#include <Interpreters/InterpreterShowRowPoliciesQuery.h> #include <Interpreters/InterpreterShowCreateQuery.h>
#include <Interpreters/InterpreterShowProcesslistQuery.h>
#include <Interpreters/InterpreterShowTablesQuery.h> #include <Interpreters/InterpreterShowTablesQuery.h>
#include <Interpreters/InterpreterSystemQuery.h> #include <Interpreters/InterpreterSystemQuery.h>
#include <Interpreters/InterpreterUseQuery.h> #include <Interpreters/InterpreterUseQuery.h>
@ -230,9 +230,9 @@ std::unique_ptr<IInterpreter> InterpreterFactory::get(ASTPtr & query, Context &
{ {
return std::make_unique<InterpreterShowQuotasQuery>(query, context); return std::make_unique<InterpreterShowQuotasQuery>(query, context);
} }
else if (query->as<ASTShowRowPoliciesQuery>()) else if (query->as<ASTShowAccessEntitiesQuery>())
{ {
return std::make_unique<InterpreterShowRowPoliciesQuery>(query, context); return std::make_unique<InterpreterShowAccessEntitiesQuery>(query, context);
} }
else else
throw Exception("Unknown type of query: " + query->getID(), ErrorCodes::UNKNOWN_TYPE_OF_QUERY); throw Exception("Unknown type of query: " + query->getID(), ErrorCodes::UNKNOWN_TYPE_OF_QUERY);

View File

@ -0,0 +1,67 @@
#include <Interpreters/InterpreterShowAccessEntitiesQuery.h>
#include <Parsers/ASTShowAccessEntitiesQuery.h>
#include <Parsers/formatAST.h>
#include <Interpreters/executeQuery.h>
#include <Common/StringUtils/StringUtils.h>
#include <Common/quoteString.h>
#include <Interpreters/Context.h>
namespace DB
{
namespace ErrorCodes
{
extern const int NOT_IMPLEMENTED;
}
using EntityType = IAccessEntity::Type;
InterpreterShowAccessEntitiesQuery::InterpreterShowAccessEntitiesQuery(const ASTPtr & query_ptr_, Context & context_)
: query_ptr(query_ptr_), context(context_)
{
}
BlockIO InterpreterShowAccessEntitiesQuery::execute()
{
return executeQuery(getRewrittenQuery(), context, true);
}
String InterpreterShowAccessEntitiesQuery::getRewrittenQuery() const
{
const auto & query = query_ptr->as<ASTShowAccessEntitiesQuery &>();
String origin;
String expr = "name";
String filter;
if (query.type == EntityType::ROW_POLICY)
{
origin = "row_policies";
const String & table_name = query.table_name;
String database;
bool show_short_name = false;
if (!table_name.empty())
{
database = query.database;
if (database.empty())
database = context.getCurrentDatabase();
show_short_name = true;
}
if (!table_name.empty())
filter = "database = " + quoteString(database) + " AND table = " + quoteString(table_name);
if (show_short_name)
expr = "short_name";
}
else
throw Exception(toString(query.type) + ": type is not supported by SHOW query", ErrorCodes::NOT_IMPLEMENTED);
return "SELECT " + expr + " from system." + origin +
(filter.empty() ? "" : " WHERE " + filter) + " ORDER BY " + expr;
}
}

View File

@ -8,15 +8,14 @@ namespace DB
{ {
class Context; class Context;
class InterpreterShowRowPoliciesQuery : public IInterpreter class InterpreterShowAccessEntitiesQuery : public IInterpreter
{ {
public: public:
InterpreterShowRowPoliciesQuery(const ASTPtr & query_ptr_, Context & context_); InterpreterShowAccessEntitiesQuery(const ASTPtr & query_ptr_, Context & context_);
BlockIO execute() override; BlockIO execute() override;
private: private:
String getRewrittenQuery() const; String getRewrittenQuery() const;
String getResultDescription() const;
ASTPtr query_ptr; ASTPtr query_ptr;
Context & context; Context & context;

View File

@ -167,14 +167,14 @@ namespace
if (policy.isRestrictive()) if (policy.isRestrictive())
query->is_restrictive = policy.isRestrictive(); query->is_restrictive = policy.isRestrictive();
for (auto index : ext::range(RowPolicy::MAX_CONDITION_TYPE)) for (auto type : ext::range(RowPolicy::MAX_CONDITION_TYPE))
{ {
const auto & condition = policy.conditions[index]; const auto & condition = policy.conditions[static_cast<size_t>(type)];
if (!condition.empty()) if (!condition.empty())
{ {
ParserExpression parser; ParserExpression parser;
ASTPtr expr = parseQuery(parser, condition, 0, DBMS_DEFAULT_MAX_PARSER_DEPTH); ASTPtr expr = parseQuery(parser, condition, 0, DBMS_DEFAULT_MAX_PARSER_DEPTH);
query->conditions.push_back(std::pair{index, expr}); query->conditions[static_cast<size_t>(type)] = expr;
} }
} }

View File

@ -1,69 +0,0 @@
#include <Interpreters/InterpreterShowRowPoliciesQuery.h>
#include <Parsers/ASTShowRowPoliciesQuery.h>
#include <Parsers/formatAST.h>
#include <Interpreters/executeQuery.h>
#include <Common/StringUtils/StringUtils.h>
#include <Common/quoteString.h>
#include <Interpreters/Context.h>
#include <ext/range.h>
namespace DB
{
InterpreterShowRowPoliciesQuery::InterpreterShowRowPoliciesQuery(const ASTPtr & query_ptr_, Context & context_)
: query_ptr(query_ptr_), context(context_)
{
}
BlockIO InterpreterShowRowPoliciesQuery::execute()
{
return executeQuery(getRewrittenQuery(), context, true);
}
String InterpreterShowRowPoliciesQuery::getRewrittenQuery() const
{
const auto & query = query_ptr->as<ASTShowRowPoliciesQuery &>();
const String & table_name = query.table_name;
String database;
if (!table_name.empty())
{
database = query.database;
if (database.empty())
database = context.getCurrentDatabase();
}
String filter;
if (query.current)
{
if (table_name.empty())
filter = "has(currentRowPolicyIDs(), id)";
else
filter = "has(currentRowPolicyIDs(" + quoteString(database) + ", " + quoteString(table_name) + "), id)";
}
else
{
if (!table_name.empty())
filter = "database = " + quoteString(database) + " AND table = " + quoteString(table_name);
}
String expr = table_name.empty() ? "name" : "short_name";
return "SELECT " + expr + " AS " + backQuote(getResultDescription()) + " from system.row_policies"
+ (filter.empty() ? "" : " WHERE " + filter) + " ORDER BY " + expr;
}
String InterpreterShowRowPoliciesQuery::getResultDescription() const
{
std::stringstream ss;
formatAST(*query_ptr, ss, false, true);
String desc = ss.str();
String prefix = "SHOW ";
if (startsWith(desc, prefix))
desc = desc.substr(prefix.length()); /// `desc` always starts with "SHOW ", so we can trim this prefix.
return desc;
}
}

View File

@ -83,12 +83,12 @@ SRCS(
InterpreterSelectWithUnionQuery.cpp InterpreterSelectWithUnionQuery.cpp
InterpreterSetQuery.cpp InterpreterSetQuery.cpp
InterpreterSetRoleQuery.cpp InterpreterSetRoleQuery.cpp
InterpreterShowAccessEntitiesQuery.cpp
InterpreterShowCreateAccessEntityQuery.cpp InterpreterShowCreateAccessEntityQuery.cpp
InterpreterShowCreateQuery.cpp InterpreterShowCreateQuery.cpp
InterpreterShowGrantsQuery.cpp InterpreterShowGrantsQuery.cpp
InterpreterShowProcesslistQuery.cpp InterpreterShowProcesslistQuery.cpp
InterpreterShowQuotasQuery.cpp InterpreterShowQuotasQuery.cpp
InterpreterShowRowPoliciesQuery.cpp
InterpreterShowTablesQuery.cpp InterpreterShowTablesQuery.cpp
InterpreterSystemQuery.cpp InterpreterSystemQuery.cpp
InterpreterUseQuery.cpp InterpreterUseQuery.cpp

View File

@ -2,6 +2,7 @@
#include <Parsers/ASTExtendedRoleSet.h> #include <Parsers/ASTExtendedRoleSet.h>
#include <Parsers/formatAST.h> #include <Parsers/formatAST.h>
#include <Common/quoteString.h> #include <Common/quoteString.h>
#include <ext/range.h>
#include <boost/range/algorithm/transform.hpp> #include <boost/range/algorithm/transform.hpp>
#include <sstream> #include <sstream>
@ -11,6 +12,9 @@ namespace DB
namespace namespace
{ {
using ConditionType = RowPolicy::ConditionType; using ConditionType = RowPolicy::ConditionType;
using ConditionTypeInfo = RowPolicy::ConditionTypeInfo;
constexpr auto MAX_CONDITION_TYPE = RowPolicy::MAX_CONDITION_TYPE;
void formatRenameTo(const String & new_short_name, const IAST::FormatSettings & settings) void formatRenameTo(const String & new_short_name, const IAST::FormatSettings & settings)
{ {
@ -28,90 +32,89 @@ namespace
void formatConditionalExpression(const ASTPtr & expr, const IAST::FormatSettings & settings) void formatConditionalExpression(const ASTPtr & expr, const IAST::FormatSettings & settings)
{ {
if (!expr) if (expr)
{ expr->format(settings);
else
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " NONE" << (settings.hilite ? IAST::hilite_none : ""); settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " NONE" << (settings.hilite ? IAST::hilite_none : "");
return;
}
expr->format(settings);
} }
std::vector<std::pair<ConditionType, String>> void formatCondition(const boost::container::flat_set<std::string_view> & commands, const String & filter, const String & check, bool alter, const IAST::FormatSettings & settings)
conditionalExpressionsToStrings(const std::vector<std::pair<ConditionType, ASTPtr>> & exprs, const IAST::FormatSettings & settings)
{ {
std::vector<std::pair<ConditionType, String>> result; settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " FOR " << (settings.hilite ? IAST::hilite_none : "");
std::stringstream ss;
IAST::FormatSettings temp_settings(ss, settings);
boost::range::transform(exprs, std::back_inserter(result), [&](const std::pair<ConditionType, ASTPtr> & in)
{
formatConditionalExpression(in.second, temp_settings);
auto out = std::pair{in.first, ss.str()};
ss.str("");
return out;
});
return result;
}
void formatConditions(const char * op, const std::optional<String> & filter, const std::optional<String> & check, bool alter, const IAST::FormatSettings & settings)
{
if (op)
{
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " FOR" << (settings.hilite ? IAST::hilite_none : "");
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << ' ' << op << (settings.hilite ? IAST::hilite_none : "");
}
if (filter)
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " USING " << (settings.hilite ? IAST::hilite_none : "") << *filter;
if (check && (alter || (check != filter)))
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " WITH CHECK " << (settings.hilite ? IAST::hilite_none : "") << *check;
}
void formatMultipleConditions(const std::vector<std::pair<ConditionType, ASTPtr>> & conditions, bool alter, const IAST::FormatSettings & settings)
{
std::optional<String> scond[RowPolicy::MAX_CONDITION_TYPE];
for (const auto & [index, scondition] : conditionalExpressionsToStrings(conditions, settings))
scond[index] = scondition;
if ((scond[RowPolicy::SELECT_FILTER] == scond[RowPolicy::UPDATE_FILTER])
&& (scond[RowPolicy::UPDATE_FILTER] == scond[RowPolicy::DELETE_FILTER])
&& (scond[RowPolicy::INSERT_CHECK] == scond[RowPolicy::UPDATE_CHECK])
&& (scond[RowPolicy::SELECT_FILTER] || scond[RowPolicy::INSERT_CHECK]))
{
formatConditions(nullptr, scond[RowPolicy::SELECT_FILTER], scond[RowPolicy::INSERT_CHECK], alter, settings);
return;
}
bool need_comma = false; bool need_comma = false;
if (scond[RowPolicy::SELECT_FILTER]) for (const auto & command : commands)
{ {
if (std::exchange(need_comma, true)) if (std::exchange(need_comma, true))
settings.ostr << ','; settings.ostr << ", ";
formatConditions("SELECT", scond[RowPolicy::SELECT_FILTER], {}, alter, settings); settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << command << (settings.hilite ? IAST::hilite_none : "");
}
if (scond[RowPolicy::INSERT_CHECK])
{
if (std::exchange(need_comma, true))
settings.ostr << ',';
formatConditions("INSERT", {}, scond[RowPolicy::INSERT_CHECK], alter, settings);
}
if (scond[RowPolicy::UPDATE_FILTER] || scond[RowPolicy::UPDATE_CHECK])
{
if (std::exchange(need_comma, true))
settings.ostr << ',';
formatConditions("UPDATE", scond[RowPolicy::UPDATE_FILTER], scond[RowPolicy::UPDATE_CHECK], alter, settings);
}
if (scond[RowPolicy::DELETE_FILTER])
{
if (std::exchange(need_comma, true))
settings.ostr << ',';
formatConditions("DELETE", scond[RowPolicy::DELETE_FILTER], {}, alter, settings);
} }
if (!filter.empty())
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " USING " << (settings.hilite ? IAST::hilite_none : "") << filter;
if (!check.empty() && (alter || (check != filter)))
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " WITH CHECK " << (settings.hilite ? IAST::hilite_none : "") << check;
} }
void formatMultipleConditions(const std::array<std::optional<ASTPtr>, MAX_CONDITION_TYPE> & conditions, bool alter, const IAST::FormatSettings & settings)
{
std::array<String, MAX_CONDITION_TYPE> conditions_as_strings;
std::stringstream temp_sstream;
IAST::FormatSettings temp_settings(temp_sstream, settings);
for (auto condition_type : ext::range(MAX_CONDITION_TYPE))
{
const auto & condition = conditions[condition_type];
if (condition)
{
formatConditionalExpression(*condition, temp_settings);
conditions_as_strings[condition_type] = temp_sstream.str();
temp_sstream.str("");
}
}
boost::container::flat_set<std::string_view> commands;
String filter, check;
do
{
commands.clear();
filter.clear();
check.clear();
/// Collect commands using the same filter and check conditions.
for (auto condition_type : ext::range(MAX_CONDITION_TYPE))
{
const String & condition = conditions_as_strings[condition_type];
if (condition.empty())
continue;
const auto & type_info = ConditionTypeInfo::get(condition_type);
if (type_info.is_check)
{
if (check.empty())
check = condition;
else if (check != condition)
continue;
}
else
{
if (filter.empty())
filter = condition;
else if (filter != condition)
continue;
}
commands.emplace(type_info.command);
conditions_as_strings[condition_type].clear(); /// Skip this condition on the next iteration.
}
if (!filter.empty() || !check.empty())
formatCondition(commands, filter, check, alter, settings);
}
while (!filter.empty() || !check.empty());
}
void formatToRoles(const ASTExtendedRoleSet & roles, const IAST::FormatSettings & settings) void formatToRoles(const ASTExtendedRoleSet & roles, const IAST::FormatSettings & settings)
{ {
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " TO " << (settings.hilite ? IAST::hilite_none : ""); settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " TO " << (settings.hilite ? IAST::hilite_none : "");

View File

@ -3,8 +3,8 @@
#include <Parsers/IAST.h> #include <Parsers/IAST.h>
#include <Parsers/ASTQueryWithOnCluster.h> #include <Parsers/ASTQueryWithOnCluster.h>
#include <Access/RowPolicy.h> #include <Access/RowPolicy.h>
#include <utility> #include <array>
#include <vector> #include <optional>
namespace DB namespace DB
@ -40,8 +40,7 @@ public:
String new_short_name; String new_short_name;
std::optional<bool> is_restrictive; std::optional<bool> is_restrictive;
using ConditionType = RowPolicy::ConditionType; std::array<std::optional<ASTPtr>, RowPolicy::MAX_CONDITION_TYPE> conditions; /// `nullopt` means "not set", `nullptr` means set to NONE.
std::vector<std::pair<ConditionType, ASTPtr>> conditions;
std::shared_ptr<ASTExtendedRoleSet> roles; std::shared_ptr<ASTExtendedRoleSet> roles;

View File

@ -0,0 +1,37 @@
#include <Parsers/ASTShowAccessEntitiesQuery.h>
#include <Common/quoteString.h>
namespace DB
{
namespace ErrorCodes
{
extern const int NOT_IMPLEMENTED;
}
String ASTShowAccessEntitiesQuery::getID(char) const
{
if (type == EntityType::ROW_POLICY)
return "SHOW ROW POLICIES query";
else
throw Exception(toString(type) + ": type is not supported by SHOW query", ErrorCodes::NOT_IMPLEMENTED);
}
void ASTShowAccessEntitiesQuery::formatQueryImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const
{
if (type == EntityType::ROW_POLICY)
settings.ostr << (settings.hilite ? hilite_keyword : "") << "SHOW ROW POLICIES" << (settings.hilite ? hilite_none : "");
else
throw Exception(toString(type) + ": type is not supported by SHOW query", ErrorCodes::NOT_IMPLEMENTED);
if ((type == EntityType::ROW_POLICY) && !table_name.empty())
{
settings.ostr << (settings.hilite ? hilite_keyword : "") << " ON " << (settings.hilite ? hilite_none : "");
if (!database.empty())
settings.ostr << backQuoteIfNeed(database) << ".";
settings.ostr << backQuoteIfNeed(table_name);
}
}
}

View File

@ -1,20 +1,24 @@
#pragma once #pragma once
#include <Parsers/ASTQueryWithOutput.h> #include <Parsers/ASTQueryWithOutput.h>
#include <Access/IAccessEntity.h>
namespace DB namespace DB
{ {
/// SHOW [ROW] POLICIES [CURRENT] [ON [database.]table]
class ASTShowRowPoliciesQuery : public ASTQueryWithOutput /// SHOW [ROW] POLICIES [ON [database.]table]
class ASTShowAccessEntitiesQuery : public ASTQueryWithOutput
{ {
public: public:
bool current = false; using EntityType = IAccessEntity::Type;
EntityType type;
String database; String database;
String table_name; String table_name;
String getID(char) const override { return "SHOW POLICIES query"; } String getID(char) const override;
ASTPtr clone() const override { return std::make_shared<ASTShowRowPoliciesQuery>(*this); } ASTPtr clone() const override { return std::make_shared<ASTShowAccessEntitiesQuery>(*this); }
protected: protected:
void formatQueryImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const override; void formatQueryImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const override;

View File

@ -1,22 +0,0 @@
#include <Parsers/ASTShowRowPoliciesQuery.h>
#include <Common/quoteString.h>
namespace DB
{
void ASTShowRowPoliciesQuery::formatQueryImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const
{
settings.ostr << (settings.hilite ? hilite_keyword : "") << "SHOW POLICIES" << (settings.hilite ? hilite_none : "");
if (current)
settings.ostr << (settings.hilite ? hilite_keyword : "") << " CURRENT" << (settings.hilite ? hilite_none : "");
if (!table_name.empty())
{
settings.ostr << (settings.hilite ? hilite_keyword : "") << " ON " << (settings.hilite ? hilite_none : "");
if (!database.empty())
settings.ostr << backQuoteIfNeed(database) << ".";
settings.ostr << backQuoteIfNeed(table_name);
}
}
}

View File

@ -8,18 +8,17 @@
#include <Parsers/ExpressionListParsers.h> #include <Parsers/ExpressionListParsers.h>
#include <Parsers/ExpressionElementParsers.h> #include <Parsers/ExpressionElementParsers.h>
#include <Parsers/ASTLiteral.h> #include <Parsers/ASTLiteral.h>
#include <ext/range.h>
namespace DB namespace DB
{ {
namespace ErrorCodes
{
}
namespace namespace
{ {
using ConditionType = RowPolicy::ConditionType; using ConditionType = RowPolicy::ConditionType;
using ConditionTypeInfo = RowPolicy::ConditionTypeInfo;
constexpr auto MAX_CONDITION_TYPE = RowPolicy::MAX_CONDITION_TYPE;
bool parseRenameTo(IParserBase::Pos & pos, Expected & expected, String & new_short_name) bool parseRenameTo(IParserBase::Pos & pos, Expected & expected, String & new_short_name)
{ {
@ -73,111 +72,93 @@ namespace
}); });
} }
bool parseConditions(IParserBase::Pos & pos, Expected & expected, bool alter, std::vector<std::pair<ConditionType, ASTPtr>> & conditions) bool parseConditions(
IParserBase::Pos & pos, Expected & expected, bool alter, std::array<std::optional<ASTPtr>, MAX_CONDITION_TYPE> & conditions)
{ {
return IParserBase::wrapParseImpl(pos, [&] return IParserBase::wrapParseImpl(pos, [&]
{ {
static constexpr char select_op[] = "SELECT"; boost::container::flat_set<std::string_view> commands;
static constexpr char insert_op[] = "INSERT";
static constexpr char update_op[] = "UPDATE"; auto add_all_commands = [&]
static constexpr char delete_op[] = "DELETE"; {
std::vector<const char *> ops; for (auto condition_type : ext::range(MAX_CONDITION_TYPE))
{
const std::string_view & command = ConditionTypeInfo::get(condition_type).command;
commands.emplace(command);
}
};
if (ParserKeyword{"FOR"}.ignore(pos, expected)) if (ParserKeyword{"FOR"}.ignore(pos, expected))
{ {
do do
{ {
if (ParserKeyword{"SELECT"}.ignore(pos, expected)) size_t old_size = commands.size();
ops.push_back(select_op); if (ParserKeyword{"ALL"}.ignore(pos, expected))
#if 0 /// INSERT, UPDATE, DELETE are not supported yet
else if (ParserKeyword{"INSERT"}.ignore(pos, expected))
ops.push_back(insert_op);
else if (ParserKeyword{"UPDATE"}.ignore(pos, expected))
ops.push_back(update_op);
else if (ParserKeyword{"DELETE"}.ignore(pos, expected))
ops.push_back(delete_op);
else if (ParserKeyword{"ALL"}.ignore(pos, expected))
{ {
add_all_commands();
} }
#endif
else else
{
for (auto condition_type : ext::range(MAX_CONDITION_TYPE))
{
const std::string_view & command = ConditionTypeInfo::get(condition_type).command;
if (ParserKeyword{command.data()}.ignore(pos, expected))
{
commands.emplace(command);
break;
}
}
}
if (commands.size() == old_size)
return false; return false;
} }
while (ParserToken{TokenType::Comma}.ignore(pos, expected)); while (ParserToken{TokenType::Comma}.ignore(pos, expected));
} }
if (ops.empty())
{
ops.push_back(select_op);
#if 0 /// INSERT, UPDATE, DELETE are not supported yet
ops.push_back(insert_op);
ops.push_back(update_op);
ops.push_back(delete_op);
#endif
}
std::optional<ASTPtr> filter; std::optional<ASTPtr> filter;
std::optional<ASTPtr> check; std::optional<ASTPtr> check;
bool keyword_using = false, keyword_with_check = false;
if (ParserKeyword{"USING"}.ignore(pos, expected)) if (ParserKeyword{"USING"}.ignore(pos, expected))
{ {
keyword_using = true;
if (!parseConditionalExpression(pos, expected, filter)) if (!parseConditionalExpression(pos, expected, filter))
return false; return false;
} }
#if 0 /// INSERT, UPDATE, DELETE are not supported yet
if (ParserKeyword{"WITH CHECK"}.ignore(pos, expected)) if (ParserKeyword{"WITH CHECK"}.ignore(pos, expected))
{ {
keyword_with_check = true;
if (!parseConditionalExpression(pos, expected, check)) if (!parseConditionalExpression(pos, expected, check))
return false; return false;
} }
#endif
if (!keyword_using && !keyword_with_check) if (!filter && !check)
return false; return false;
if (filter && !check && !alter) if (commands.empty())
add_all_commands();
if (!check && !alter)
check = filter; check = filter;
auto set_condition = [&](ConditionType index, const ASTPtr & condition) for (auto condition_type : ext::range(MAX_CONDITION_TYPE))
{ {
auto it = std::find_if(conditions.begin(), conditions.end(), [index](const std::pair<ConditionType, ASTPtr> & element) const auto & type_info = ConditionTypeInfo::get(condition_type);
if (commands.count(type_info.command))
{ {
return element.first == index; if (type_info.is_check && check)
}); conditions[condition_type] = check;
if (it == conditions.end()) else if (filter)
it = conditions.insert(conditions.end(), std::pair<ConditionType, ASTPtr>{index, nullptr}); conditions[condition_type] = filter;
it->second = condition;
};
for (const auto & op : ops)
{
if ((op == select_op) && filter)
set_condition(RowPolicy::SELECT_FILTER, *filter);
else if ((op == insert_op) && check)
set_condition(RowPolicy::INSERT_CHECK, *check);
else if (op == update_op)
{
if (filter)
set_condition(RowPolicy::UPDATE_FILTER, *filter);
if (check)
set_condition(RowPolicy::UPDATE_CHECK, *check);
} }
else if ((op == delete_op) && filter)
set_condition(RowPolicy::DELETE_FILTER, *filter);
else
__builtin_unreachable();
} }
return true; return true;
}); });
} }
bool parseMultipleConditions(IParserBase::Pos & pos, Expected & expected, bool alter, std::vector<std::pair<ConditionType, ASTPtr>> & conditions) bool parseMultipleConditions(
IParserBase::Pos & pos, Expected & expected, bool alter, std::array<std::optional<ASTPtr>, MAX_CONDITION_TYPE> & conditions)
{ {
return IParserBase::wrapParseImpl(pos, [&] return IParserBase::wrapParseImpl(pos, [&]
{ {
std::vector<std::pair<ConditionType, ASTPtr>> res_conditions; std::array<std::optional<ASTPtr>, MAX_CONDITION_TYPE> res_conditions;
do do
{ {
if (!parseConditions(pos, expected, alter, res_conditions)) if (!parseConditions(pos, expected, alter, res_conditions))
@ -256,7 +237,7 @@ bool ParserCreateRowPolicyQuery::parseImpl(Pos & pos, ASTPtr & node, Expected &
String new_short_name; String new_short_name;
std::optional<bool> is_restrictive; std::optional<bool> is_restrictive;
std::vector<std::pair<ConditionType, ASTPtr>> conditions; std::array<std::optional<ASTPtr>, MAX_CONDITION_TYPE> conditions;
String cluster; String cluster;
while (true) while (true)

View File

@ -14,10 +14,10 @@
#include <Parsers/ParserWatchQuery.h> #include <Parsers/ParserWatchQuery.h>
#include <Parsers/ParserSetQuery.h> #include <Parsers/ParserSetQuery.h>
#include <Parsers/ASTExplainQuery.h> #include <Parsers/ASTExplainQuery.h>
#include <Parsers/ParserShowGrantsQuery.h> #include <Parsers/ParserShowAccessEntitiesQuery.h>
#include <Parsers/ParserShowCreateAccessEntityQuery.h> #include <Parsers/ParserShowCreateAccessEntityQuery.h>
#include <Parsers/ParserShowGrantsQuery.h>
#include <Parsers/ParserShowQuotasQuery.h> #include <Parsers/ParserShowQuotasQuery.h>
#include <Parsers/ParserShowRowPoliciesQuery.h>
namespace DB namespace DB
@ -38,10 +38,10 @@ bool ParserQueryWithOutput::parseImpl(Pos & pos, ASTPtr & node, Expected & expec
ParserOptimizeQuery optimize_p; ParserOptimizeQuery optimize_p;
ParserKillQueryQuery kill_query_p; ParserKillQueryQuery kill_query_p;
ParserWatchQuery watch_p; ParserWatchQuery watch_p;
ParserShowAccessEntitiesQuery show_access_entities_p;
ParserShowCreateAccessEntityQuery show_create_access_entity_p; ParserShowCreateAccessEntityQuery show_create_access_entity_p;
ParserShowGrantsQuery show_grants_p; ParserShowGrantsQuery show_grants_p;
ParserShowQuotasQuery show_quotas_p; ParserShowQuotasQuery show_quotas_p;
ParserShowRowPoliciesQuery show_row_policies_p;
ASTPtr query; ASTPtr query;
@ -70,9 +70,9 @@ bool ParserQueryWithOutput::parseImpl(Pos & pos, ASTPtr & node, Expected & expec
|| kill_query_p.parse(pos, query, expected) || kill_query_p.parse(pos, query, expected)
|| optimize_p.parse(pos, query, expected) || optimize_p.parse(pos, query, expected)
|| watch_p.parse(pos, query, expected) || watch_p.parse(pos, query, expected)
|| show_access_entities_p.parse(pos, query, expected)
|| show_grants_p.parse(pos, query, expected) || show_grants_p.parse(pos, query, expected)
|| show_quotas_p.parse(pos, query, expected) || show_quotas_p.parse(pos, query, expected);
|| show_row_policies_p.parse(pos, query, expected);
if (!parsed) if (!parsed)
return false; return false;

View File

@ -1,5 +1,5 @@
#include <Parsers/ParserShowRowPoliciesQuery.h> #include <Parsers/ParserShowAccessEntitiesQuery.h>
#include <Parsers/ASTShowRowPoliciesQuery.h> #include <Parsers/ASTShowAccessEntitiesQuery.h>
#include <Parsers/CommonParsers.h> #include <Parsers/CommonParsers.h>
#include <Parsers/parseDatabaseAndTableName.h> #include <Parsers/parseDatabaseAndTableName.h>
@ -8,6 +8,8 @@ namespace DB
{ {
namespace namespace
{ {
using EntityType = IAccessEntity::Type;
bool parseONDatabaseAndTableName(IParserBase::Pos & pos, Expected & expected, String & database, String & table_name) bool parseONDatabaseAndTableName(IParserBase::Pos & pos, Expected & expected, String & database, String & table_name)
{ {
return IParserBase::wrapParseImpl(pos, [&] return IParserBase::wrapParseImpl(pos, [&]
@ -20,21 +22,21 @@ namespace
} }
bool ParserShowRowPoliciesQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) bool ParserShowAccessEntitiesQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
{ {
if (!ParserKeyword{"SHOW POLICIES"}.ignore(pos, expected) && !ParserKeyword{"SHOW ROW POLICIES"}.ignore(pos, expected)) if (!ParserKeyword{"SHOW POLICIES"}.ignore(pos, expected) && !ParserKeyword{"SHOW ROW POLICIES"}.ignore(pos, expected))
return false; return false;
bool current = ParserKeyword{"CURRENT"}.ignore(pos, expected);
String database, table_name; String database, table_name;
parseONDatabaseAndTableName(pos, expected, database, table_name); parseONDatabaseAndTableName(pos, expected, database, table_name);
auto query = std::make_shared<ASTShowRowPoliciesQuery>(); auto query = std::make_shared<ASTShowAccessEntitiesQuery>();
query->current = current; node = query;
query->type = EntityType::ROW_POLICY;
query->database = std::move(database); query->database = std::move(database);
query->table_name = std::move(table_name); query->table_name = std::move(table_name);
node = query;
return true; return true;
} }
} }

View File

@ -6,12 +6,12 @@
namespace DB namespace DB
{ {
/** Parses queries like /** Parses queries like
* SHOW [ROW] POLICIES [CURRENT] [ON [database.]table] * SHOW [ROW] POLICIES [ON [database.]table]
*/ */
class ParserShowRowPoliciesQuery : public IParserBase class ParserShowAccessEntitiesQuery : public IParserBase
{ {
protected: protected:
const char * getName() const override { return "SHOW POLICIES query"; } const char * getName() const override { return "ShowAccessEntitiesQuery"; }
bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override; bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override;
}; };
} }

View File

@ -42,10 +42,10 @@ SRCS(
ASTSelectWithUnionQuery.cpp ASTSelectWithUnionQuery.cpp
ASTSetRoleQuery.cpp ASTSetRoleQuery.cpp
ASTSettingsProfileElement.cpp ASTSettingsProfileElement.cpp
ASTShowAccessEntitiesQuery.cpp
ASTShowCreateAccessEntityQuery.cpp ASTShowCreateAccessEntityQuery.cpp
ASTShowGrantsQuery.cpp ASTShowGrantsQuery.cpp
ASTShowQuotasQuery.cpp ASTShowQuotasQuery.cpp
ASTShowRowPoliciesQuery.cpp
ASTShowTablesQuery.cpp ASTShowTablesQuery.cpp
ASTSubquery.cpp ASTSubquery.cpp
ASTSystemQuery.cpp ASTSystemQuery.cpp
@ -94,10 +94,10 @@ SRCS(
ParserSetQuery.cpp ParserSetQuery.cpp
ParserSetRoleQuery.cpp ParserSetRoleQuery.cpp
ParserSettingsProfileElement.cpp ParserSettingsProfileElement.cpp
ParserShowAccessEntitiesQuery.cpp
ParserShowCreateAccessEntityQuery.cpp ParserShowCreateAccessEntityQuery.cpp
ParserShowGrantsQuery.cpp ParserShowGrantsQuery.cpp
ParserShowQuotasQuery.cpp ParserShowQuotasQuery.cpp
ParserShowRowPoliciesQuery.cpp
ParserShowTablesQuery.cpp ParserShowTablesQuery.cpp
ParserSystemQuery.cpp ParserSystemQuery.cpp
ParserTablePropertiesQuery.cpp ParserTablePropertiesQuery.cpp

View File

@ -2,32 +2,52 @@
#include <DataTypes/DataTypeString.h> #include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypesNumber.h> #include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypeUUID.h> #include <DataTypes/DataTypeUUID.h>
#include <DataTypes/DataTypeDateTime.h>
#include <DataTypes/DataTypeNullable.h> #include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypeArray.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnsNumber.h>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnNullable.h>
#include <Interpreters/Context.h> #include <Interpreters/Context.h>
#include <Parsers/ASTExtendedRoleSet.h>
#include <Access/AccessControlManager.h> #include <Access/AccessControlManager.h>
#include <Access/RowPolicy.h> #include <Access/RowPolicy.h>
#include <Access/AccessFlags.h> #include <Access/AccessFlags.h>
#include <ext/range.h> #include <ext/range.h>
#include <boost/range/algorithm_ext/push_back.hpp>
namespace DB namespace DB
{ {
using ConditionTypeInfo = RowPolicy::ConditionTypeInfo;
constexpr auto MAX_CONDITION_TYPE = RowPolicy::MAX_CONDITION_TYPE;
NamesAndTypesList StorageSystemRowPolicies::getNamesAndTypes() NamesAndTypesList StorageSystemRowPolicies::getNamesAndTypes()
{ {
NamesAndTypesList names_and_types{ NamesAndTypesList names_and_types{
{"name", std::make_shared<DataTypeString>()},
{"short_name", std::make_shared<DataTypeString>()},
{"database", std::make_shared<DataTypeString>()}, {"database", std::make_shared<DataTypeString>()},
{"table", std::make_shared<DataTypeString>()}, {"table", std::make_shared<DataTypeString>()},
{"short_name", std::make_shared<DataTypeString>()},
{"name", std::make_shared<DataTypeString>()},
{"id", std::make_shared<DataTypeUUID>()}, {"id", std::make_shared<DataTypeUUID>()},
{"source", std::make_shared<DataTypeString>()}, {"source", std::make_shared<DataTypeString>()},
{"restrictive", std::make_shared<DataTypeUInt8>()},
}; };
for (auto index : ext::range(RowPolicy::MAX_CONDITION_TYPE)) for (auto type : ext::range(MAX_CONDITION_TYPE))
names_and_types.push_back({RowPolicy::conditionTypeToColumnName(index), std::make_shared<DataTypeString>()}); {
const String & column_name = ConditionTypeInfo::get(type).name;
names_and_types.push_back({column_name, std::make_shared<DataTypeNullable>(std::make_shared<DataTypeString>())});
}
NamesAndTypesList extra_names_and_types{
{"is_restrictive", std::make_shared<DataTypeUInt8>()},
{"apply_to_all", std::make_shared<DataTypeUInt8>()},
{"apply_to_list", std::make_shared<DataTypeArray>(std::make_shared<DataTypeString>())},
{"apply_to_except", std::make_shared<DataTypeArray>(std::make_shared<DataTypeString>())}
};
boost::range::push_back(names_and_types, std::move(extra_names_and_types));
return names_and_types; return names_and_types;
} }
@ -38,24 +58,83 @@ void StorageSystemRowPolicies::fillData(MutableColumns & res_columns, const Cont
const auto & access_control = context.getAccessControlManager(); const auto & access_control = context.getAccessControlManager();
std::vector<UUID> ids = access_control.findAll<RowPolicy>(); std::vector<UUID> ids = access_control.findAll<RowPolicy>();
size_t column_index = 0;
auto & column_name = assert_cast<ColumnString &>(*res_columns[column_index++]);
auto & column_short_name = assert_cast<ColumnString &>(*res_columns[column_index++]);
auto & column_database = assert_cast<ColumnString &>(*res_columns[column_index++]);
auto & column_table = assert_cast<ColumnString &>(*res_columns[column_index++]);
auto & column_id = assert_cast<ColumnUInt128 &>(*res_columns[column_index++]).getData();
auto & column_storage = assert_cast<ColumnString &>(*res_columns[column_index++]);
ColumnString * column_condition[MAX_CONDITION_TYPE];
NullMap * column_condition_null_map[MAX_CONDITION_TYPE];
for (auto condition_type : ext::range(MAX_CONDITION_TYPE))
{
column_condition[condition_type] = &assert_cast<ColumnString &>(assert_cast<ColumnNullable &>(*res_columns[column_index]).getNestedColumn());
column_condition_null_map[condition_type] = &assert_cast<ColumnNullable &>(*res_columns[column_index++]).getNullMapData();
}
auto & column_is_restrictive = assert_cast<ColumnUInt8 &>(*res_columns[column_index++]).getData();
auto & column_apply_to_all = assert_cast<ColumnUInt8 &>(*res_columns[column_index++]).getData();
auto & column_apply_to_list = assert_cast<ColumnString &>(assert_cast<ColumnArray &>(*res_columns[column_index]).getData());
auto & column_apply_to_list_offsets = assert_cast<ColumnArray &>(*res_columns[column_index++]).getOffsets();
auto & column_apply_to_except = assert_cast<ColumnString &>(assert_cast<ColumnArray &>(*res_columns[column_index]).getData());
auto & column_apply_to_except_offsets = assert_cast<ColumnArray &>(*res_columns[column_index++]).getOffsets();
auto add_row = [&](const String & name,
const RowPolicy::NameParts & name_parts,
const UUID & id,
const String & storage_name,
const std::array<String, MAX_CONDITION_TYPE> & conditions,
bool is_restrictive,
const ExtendedRoleSet & apply_to)
{
column_name.insertData(name.data(), name.length());
column_short_name.insertData(name_parts.short_name.data(), name_parts.short_name.length());
column_database.insertData(name_parts.database.data(), name_parts.database.length());
column_table.insertData(name_parts.table_name.data(), name_parts.table_name.length());
column_id.push_back(id);
column_storage.insertData(storage_name.data(), storage_name.length());
for (auto condition_type : ext::range(MAX_CONDITION_TYPE))
{
const String & condition = conditions[condition_type];
if (condition.empty())
{
column_condition[condition_type]->insertDefault();
column_condition_null_map[condition_type]->push_back(true);
}
else
{
column_condition[condition_type]->insertData(condition.data(), condition.length());
column_condition_null_map[condition_type]->push_back(false);
}
}
column_is_restrictive.push_back(is_restrictive);
auto apply_to_ast = apply_to.toASTWithNames(access_control);
column_apply_to_all.push_back(apply_to_ast->all);
for (const auto & role_name : apply_to_ast->names)
column_apply_to_list.insertData(role_name.data(), role_name.length());
column_apply_to_list_offsets.push_back(column_apply_to_list.size());
for (const auto & role_name : apply_to_ast->except_names)
column_apply_to_except.insertData(role_name.data(), role_name.length());
column_apply_to_except_offsets.push_back(column_apply_to_except.size());
};
for (const auto & id : ids) for (const auto & id : ids)
{ {
auto policy = access_control.tryRead<RowPolicy>(id); auto policy = access_control.tryRead<RowPolicy>(id);
if (!policy) if (!policy)
continue; continue;
const auto * storage = access_control.findStorage(id); const auto * storage = access_control.findStorage(id);
if (!storage)
continue;
size_t i = 0; add_row(policy->getName(), policy->getNameParts(), id, storage->getStorageName(), policy->conditions, policy->isRestrictive(), policy->to_roles);
res_columns[i++]->insert(policy->getDatabase());
res_columns[i++]->insert(policy->getTableName());
res_columns[i++]->insert(policy->getShortName());
res_columns[i++]->insert(policy->getName());
res_columns[i++]->insert(id);
res_columns[i++]->insert(storage ? storage->getStorageName() : "");
res_columns[i++]->insert(policy->isRestrictive());
for (auto index : ext::range(RowPolicy::MAX_CONDITION_TYPE))
res_columns[i++]->insert(policy->conditions[index]);
} }
} }
} }

View File

@ -3,13 +3,13 @@
<test_local_cluster> <test_local_cluster>
<shard> <shard>
<replica> <replica>
<host>instance1</host> <host>node</host>
<port>9000</port> <port>9000</port>
</replica> </replica>
</shard> </shard>
<shard> <shard>
<replica> <replica>
<host>instance2</host> <host>node2</host>
<port>9000</port> <port>9000</port>
</replica> </replica>
</shard> </shard>

View File

@ -1,21 +1,22 @@
import pytest import pytest
from helpers.cluster import ClickHouseCluster from helpers.cluster import ClickHouseCluster
from helpers.test_tools import assert_eq_with_retry from helpers.test_tools import assert_eq_with_retry, TSV
import os import os
import re import re
import time import time
cluster = ClickHouseCluster(__file__) cluster = ClickHouseCluster(__file__)
instance = cluster.add_instance('instance1', config_dir="configs", with_zookeeper=True) node = cluster.add_instance('node', config_dir="configs", with_zookeeper=True)
instance2 = cluster.add_instance('instance2', config_dir="configs", with_zookeeper=True) node2 = cluster.add_instance('node2', config_dir="configs", with_zookeeper=True)
nodes = [node, node2]
def copy_policy_xml(local_file_name, reload_immediately = True): def copy_policy_xml(local_file_name, reload_immediately = True):
script_dir = os.path.dirname(os.path.realpath(__file__)) script_dir = os.path.dirname(os.path.realpath(__file__))
instance.copy_file_to_container(os.path.join(script_dir, local_file_name), '/etc/clickhouse-server/users.d/row_policy.xml') for current_node in nodes:
instance2.copy_file_to_container(os.path.join(script_dir, local_file_name), '/etc/clickhouse-server/users.d/row_policy.xml') current_node.copy_file_to_container(os.path.join(script_dir, local_file_name), '/etc/clickhouse-server/users.d/row_policy.xml')
if reload_immediately: if reload_immediately:
instance.query("SYSTEM RELOAD CONFIG") current_node.query("SYSTEM RELOAD CONFIG")
@pytest.fixture(scope="module", autouse=True) @pytest.fixture(scope="module", autouse=True)
@ -23,42 +24,25 @@ def started_cluster():
try: try:
cluster.start() cluster.start()
instance.query(''' for current_node in nodes:
CREATE DATABASE mydb ENGINE=Ordinary; current_node.query('''
CREATE DATABASE mydb ENGINE=Ordinary;
CREATE TABLE mydb.filtered_table1 (a UInt8, b UInt8) ENGINE MergeTree ORDER BY a; CREATE TABLE mydb.filtered_table1 (a UInt8, b UInt8) ENGINE MergeTree ORDER BY a;
INSERT INTO mydb.filtered_table1 values (0, 0), (0, 1), (1, 0), (1, 1); INSERT INTO mydb.filtered_table1 values (0, 0), (0, 1), (1, 0), (1, 1);
CREATE TABLE mydb.table (a UInt8, b UInt8) ENGINE MergeTree ORDER BY a; CREATE TABLE mydb.table (a UInt8, b UInt8) ENGINE MergeTree ORDER BY a;
INSERT INTO mydb.table values (0, 0), (0, 1), (1, 0), (1, 1); INSERT INTO mydb.table values (0, 0), (0, 1), (1, 0), (1, 1);
CREATE TABLE mydb.filtered_table2 (a UInt8, b UInt8, c UInt8, d UInt8) ENGINE MergeTree ORDER BY a; CREATE TABLE mydb.filtered_table2 (a UInt8, b UInt8, c UInt8, d UInt8) ENGINE MergeTree ORDER BY a;
INSERT INTO mydb.filtered_table2 values (0, 0, 0, 0), (1, 2, 3, 4), (4, 3, 2, 1), (0, 0, 6, 0); INSERT INTO mydb.filtered_table2 values (0, 0, 0, 0), (1, 2, 3, 4), (4, 3, 2, 1), (0, 0, 6, 0);
CREATE TABLE mydb.filtered_table3 (a UInt8, b UInt8, c UInt16 ALIAS a + b) ENGINE MergeTree ORDER BY a; CREATE TABLE mydb.filtered_table3 (a UInt8, b UInt8, c UInt16 ALIAS a + b) ENGINE MergeTree ORDER BY a;
INSERT INTO mydb.filtered_table3 values (0, 0), (0, 1), (1, 0), (1, 1); INSERT INTO mydb.filtered_table3 values (0, 0), (0, 1), (1, 0), (1, 1);
CREATE TABLE mydb.`.filtered_table4` (a UInt8, b UInt8, c UInt16 ALIAS a + b) ENGINE MergeTree ORDER BY a; CREATE TABLE mydb.`.filtered_table4` (a UInt8, b UInt8, c UInt16 ALIAS a + b) ENGINE MergeTree ORDER BY a;
INSERT INTO mydb.`.filtered_table4` values (0, 0), (0, 1), (1, 0), (1, 1); INSERT INTO mydb.`.filtered_table4` values (0, 0), (0, 1), (1, 0), (1, 1);
''') ''')
instance2.query('''
CREATE DATABASE mydb ENGINE=Ordinary;
CREATE TABLE mydb.filtered_table1 (a UInt8, b UInt8) ENGINE MergeTree ORDER BY a;
INSERT INTO mydb.filtered_table1 values (0, 0), (0, 1), (1, 0), (1, 1);
CREATE TABLE mydb.table (a UInt8, b UInt8) ENGINE MergeTree ORDER BY a;
INSERT INTO mydb.table values (0, 0), (0, 1), (1, 0), (1, 1);
CREATE TABLE mydb.filtered_table2 (a UInt8, b UInt8, c UInt8, d UInt8) ENGINE MergeTree ORDER BY a;
INSERT INTO mydb.filtered_table2 values (0, 0, 0, 0), (1, 2, 3, 4), (4, 3, 2, 1), (0, 0, 6, 0);
CREATE TABLE mydb.filtered_table3 (a UInt8, b UInt8, c UInt16 ALIAS a + b) ENGINE MergeTree ORDER BY a;
INSERT INTO mydb.filtered_table3 values (0, 0), (0, 1), (1, 0), (1, 1);
CREATE TABLE mydb.`.filtered_table4` (a UInt8, b UInt8, c UInt16 ALIAS a + b) ENGINE MergeTree ORDER BY a;
INSERT INTO mydb.`.filtered_table4` values (0, 0), (0, 1), (1, 0), (1, 1);
''')
yield cluster yield cluster
@ -72,243 +56,239 @@ def reset_policies():
yield yield
finally: finally:
copy_policy_xml('normal_filters.xml') copy_policy_xml('normal_filters.xml')
instance.query("DROP POLICY IF EXISTS pA, pB ON mydb.filtered_table1") for current_node in nodes:
current_node.query("DROP POLICY IF EXISTS pA, pB ON mydb.filtered_table1")
def test_smoke(): def test_smoke():
assert instance.query("SELECT * FROM mydb.filtered_table1") == "1\t0\n1\t1\n" assert node.query("SELECT * FROM mydb.filtered_table1") == TSV([[1,0], [1, 1]])
assert instance.query("SELECT * FROM mydb.filtered_table2") == "0\t0\t0\t0\n0\t0\t6\t0\n" assert node.query("SELECT * FROM mydb.filtered_table2") == TSV([[0, 0, 0, 0], [0, 0, 6, 0]])
assert instance.query("SELECT * FROM mydb.filtered_table3") == "0\t1\n1\t0\n" assert node.query("SELECT * FROM mydb.filtered_table3") == TSV([[0, 1], [1, 0]])
assert instance.query("SELECT a FROM mydb.filtered_table1") == "1\n1\n" assert node.query("SELECT a FROM mydb.filtered_table1") == TSV([[1], [1]])
assert instance.query("SELECT b FROM mydb.filtered_table1") == "0\n1\n" assert node.query("SELECT b FROM mydb.filtered_table1") == TSV([[0], [1]])
assert instance.query("SELECT a FROM mydb.filtered_table1 WHERE a = 1") == "1\n1\n" assert node.query("SELECT a FROM mydb.filtered_table1 WHERE a = 1") == TSV([[1], [1]])
assert instance.query("SELECT a FROM mydb.filtered_table1 WHERE a IN (1)") == "1\n1\n" assert node.query("SELECT a FROM mydb.filtered_table1 WHERE a IN (1)") == TSV([[1], [1]])
assert instance.query("SELECT a = 1 FROM mydb.filtered_table1") == "1\n1\n" assert node.query("SELECT a = 1 FROM mydb.filtered_table1") == TSV([[1], [1]])
assert instance.query("SELECT a FROM mydb.filtered_table3") == "0\n1\n" assert node.query("SELECT a FROM mydb.filtered_table3") == TSV([[0], [1]])
assert instance.query("SELECT b FROM mydb.filtered_table3") == "1\n0\n" assert node.query("SELECT b FROM mydb.filtered_table3") == TSV([[1], [0]])
assert instance.query("SELECT c FROM mydb.filtered_table3") == "1\n1\n" assert node.query("SELECT c FROM mydb.filtered_table3") == TSV([[1], [1]])
assert instance.query("SELECT a + b FROM mydb.filtered_table3") == "1\n1\n" assert node.query("SELECT a + b FROM mydb.filtered_table3") == TSV([[1], [1]])
assert instance.query("SELECT a FROM mydb.filtered_table3 WHERE c = 1") == "0\n1\n" assert node.query("SELECT a FROM mydb.filtered_table3 WHERE c = 1") == TSV([[0], [1]])
assert instance.query("SELECT c = 1 FROM mydb.filtered_table3") == "1\n1\n" assert node.query("SELECT c = 1 FROM mydb.filtered_table3") == TSV([[1], [1]])
assert instance.query("SELECT a + b = 1 FROM mydb.filtered_table3") == "1\n1\n" assert node.query("SELECT a + b = 1 FROM mydb.filtered_table3") == TSV([[1], [1]])
def test_join(): def test_join():
assert instance.query("SELECT * FROM mydb.filtered_table1 as t1 ANY LEFT JOIN mydb.filtered_table1 as t2 ON t1.a = t2.b") == "1\t0\t1\t1\n1\t1\t1\t1\n" assert node.query("SELECT * FROM mydb.filtered_table1 as t1 ANY LEFT JOIN mydb.filtered_table1 as t2 ON t1.a = t2.b") == TSV([[1, 0, 1, 1], [1, 1, 1, 1]])
assert instance.query("SELECT * FROM mydb.filtered_table1 as t2 ANY RIGHT JOIN mydb.filtered_table1 as t1 ON t2.b = t1.a") == "1\t1\t1\t0\n" assert node.query("SELECT * FROM mydb.filtered_table1 as t2 ANY RIGHT JOIN mydb.filtered_table1 as t1 ON t2.b = t1.a") == TSV([[1, 1, 1, 0]])
def test_cannot_trick_row_policy_with_keyword_with(): def test_cannot_trick_row_policy_with_keyword_with():
assert instance.query("WITH 0 AS a SELECT * FROM mydb.filtered_table1") == "1\t0\n1\t1\n" assert node.query("WITH 0 AS a SELECT * FROM mydb.filtered_table1") == TSV([[1, 0], [1, 1]])
assert instance.query("WITH 0 AS a SELECT a, b FROM mydb.filtered_table1") == "0\t0\n0\t1\n" assert node.query("WITH 0 AS a SELECT a, b FROM mydb.filtered_table1") == TSV([[0, 0], [0, 1]])
assert instance.query("WITH 0 AS a SELECT a FROM mydb.filtered_table1") == "0\n0\n" assert node.query("WITH 0 AS a SELECT a FROM mydb.filtered_table1") == TSV([[0], [0]])
assert instance.query("WITH 0 AS a SELECT b FROM mydb.filtered_table1") == "0\n1\n" assert node.query("WITH 0 AS a SELECT b FROM mydb.filtered_table1") == TSV([[0], [1]])
def test_prewhere_not_supported(): def test_prewhere_not_supported():
expected_error = "PREWHERE is not supported if the table is filtered by row-level security" expected_error = "PREWHERE is not supported if the table is filtered by row-level security"
assert expected_error in instance.query_and_get_error("SELECT * FROM mydb.filtered_table1 PREWHERE 1") assert expected_error in node.query_and_get_error("SELECT * FROM mydb.filtered_table1 PREWHERE 1")
assert expected_error in instance.query_and_get_error("SELECT * FROM mydb.filtered_table2 PREWHERE 1") assert expected_error in node.query_and_get_error("SELECT * FROM mydb.filtered_table2 PREWHERE 1")
assert expected_error in instance.query_and_get_error("SELECT * FROM mydb.filtered_table3 PREWHERE 1") assert expected_error in node.query_and_get_error("SELECT * FROM mydb.filtered_table3 PREWHERE 1")
# However PREWHERE should still work for user without filtering. # However PREWHERE should still work for user without filtering.
assert instance.query("SELECT * FROM mydb.filtered_table1 PREWHERE 1", user="another") == "0\t0\n0\t1\n1\t0\n1\t1\n" assert node.query("SELECT * FROM mydb.filtered_table1 PREWHERE 1", user="another") == TSV([[0, 0], [0, 1], [1, 0], [1, 1]])
def test_single_table_name(): def test_single_table_name():
copy_policy_xml('tag_with_table_name.xml') copy_policy_xml('tag_with_table_name.xml')
assert instance.query("SELECT * FROM mydb.table") == "1\t0\n1\t1\n" assert node.query("SELECT * FROM mydb.table") == TSV([[1, 0], [1, 1]])
assert instance.query("SELECT * FROM mydb.filtered_table2") == "0\t0\t0\t0\n0\t0\t6\t0\n" assert node.query("SELECT * FROM mydb.filtered_table2") == TSV([[0, 0, 0, 0], [0, 0, 6, 0]])
assert instance.query("SELECT * FROM mydb.filtered_table3") == "0\t1\n1\t0\n" assert node.query("SELECT * FROM mydb.filtered_table3") == TSV([[0, 1], [1, 0]])
assert instance.query("SELECT a FROM mydb.table") == "1\n1\n" assert node.query("SELECT a FROM mydb.table") == TSV([[1], [1]])
assert instance.query("SELECT b FROM mydb.table") == "0\n1\n" assert node.query("SELECT b FROM mydb.table") == TSV([[0], [1]])
assert instance.query("SELECT a FROM mydb.table WHERE a = 1") == "1\n1\n" assert node.query("SELECT a FROM mydb.table WHERE a = 1") == TSV([[1], [1]])
assert instance.query("SELECT a = 1 FROM mydb.table") == "1\n1\n" assert node.query("SELECT a = 1 FROM mydb.table") == TSV([[1], [1]])
assert instance.query("SELECT a FROM mydb.filtered_table3") == "0\n1\n" assert node.query("SELECT a FROM mydb.filtered_table3") == TSV([[0], [1]])
assert instance.query("SELECT b FROM mydb.filtered_table3") == "1\n0\n" assert node.query("SELECT b FROM mydb.filtered_table3") == TSV([[1], [0]])
assert instance.query("SELECT c FROM mydb.filtered_table3") == "1\n1\n" assert node.query("SELECT c FROM mydb.filtered_table3") == TSV([[1], [1]])
assert instance.query("SELECT a + b FROM mydb.filtered_table3") == "1\n1\n" assert node.query("SELECT a + b FROM mydb.filtered_table3") == TSV([[1], [1]])
assert instance.query("SELECT a FROM mydb.filtered_table3 WHERE c = 1") == "0\n1\n" assert node.query("SELECT a FROM mydb.filtered_table3 WHERE c = 1") == TSV([[0], [1]])
assert instance.query("SELECT c = 1 FROM mydb.filtered_table3") == "1\n1\n" assert node.query("SELECT c = 1 FROM mydb.filtered_table3") == TSV([[1], [1]])
assert instance.query("SELECT a + b = 1 FROM mydb.filtered_table3") == "1\n1\n" assert node.query("SELECT a + b = 1 FROM mydb.filtered_table3") == TSV([[1], [1]])
def test_custom_table_name(): def test_custom_table_name():
copy_policy_xml('multiple_tags_with_table_names.xml') copy_policy_xml('multiple_tags_with_table_names.xml')
assert instance.query("SELECT * FROM mydb.table") == "1\t0\n1\t1\n" assert node.query("SELECT * FROM mydb.table") == TSV([[1, 0], [1, 1]])
assert instance.query("SELECT * FROM mydb.filtered_table2") == "0\t0\t0\t0\n0\t0\t6\t0\n" assert node.query("SELECT * FROM mydb.filtered_table2") == TSV([[0, 0, 0, 0], [0, 0, 6, 0]])
assert instance.query("SELECT * FROM mydb.`.filtered_table4`") == "0\t1\n1\t0\n" assert node.query("SELECT * FROM mydb.`.filtered_table4`") == TSV([[0, 1], [1, 0]])
assert instance.query("SELECT a FROM mydb.table") == "1\n1\n" assert node.query("SELECT a FROM mydb.table") == TSV([[1], [1]])
assert instance.query("SELECT b FROM mydb.table") == "0\n1\n" assert node.query("SELECT b FROM mydb.table") == TSV([[0], [1]])
assert instance.query("SELECT a FROM mydb.table WHERE a = 1") == "1\n1\n" assert node.query("SELECT a FROM mydb.table WHERE a = 1") == TSV([[1], [1]])
assert instance.query("SELECT a = 1 FROM mydb.table") == "1\n1\n" assert node.query("SELECT a = 1 FROM mydb.table") == TSV([[1], [1]])
assert instance.query("SELECT a FROM mydb.`.filtered_table4`") == "0\n1\n" assert node.query("SELECT a FROM mydb.`.filtered_table4`") == TSV([[0], [1]])
assert instance.query("SELECT b FROM mydb.`.filtered_table4`") == "1\n0\n" assert node.query("SELECT b FROM mydb.`.filtered_table4`") == TSV([[1], [0]])
assert instance.query("SELECT c FROM mydb.`.filtered_table4`") == "1\n1\n" assert node.query("SELECT c FROM mydb.`.filtered_table4`") == TSV([[1], [1]])
assert instance.query("SELECT a + b FROM mydb.`.filtered_table4`") == "1\n1\n" assert node.query("SELECT a + b FROM mydb.`.filtered_table4`") == TSV([[1], [1]])
assert instance.query("SELECT a FROM mydb.`.filtered_table4` WHERE c = 1") == "0\n1\n" assert node.query("SELECT a FROM mydb.`.filtered_table4` WHERE c = 1") == TSV([[0], [1]])
assert instance.query("SELECT c = 1 FROM mydb.`.filtered_table4`") == "1\n1\n" assert node.query("SELECT c = 1 FROM mydb.`.filtered_table4`") == TSV([[1], [1]])
assert instance.query("SELECT a + b = 1 FROM mydb.`.filtered_table4`") == "1\n1\n" assert node.query("SELECT a + b = 1 FROM mydb.`.filtered_table4`") == TSV([[1], [1]])
def test_change_of_users_xml_changes_row_policies(): def test_change_of_users_xml_changes_row_policies():
copy_policy_xml('normal_filters.xml') copy_policy_xml('normal_filters.xml')
assert instance.query("SELECT * FROM mydb.filtered_table1") == "1\t0\n1\t1\n" assert node.query("SELECT * FROM mydb.filtered_table1") == TSV([[1, 0], [1, 1]])
assert instance.query("SELECT * FROM mydb.filtered_table2") == "0\t0\t0\t0\n0\t0\t6\t0\n" assert node.query("SELECT * FROM mydb.filtered_table2") == TSV([[0, 0, 0, 0], [0, 0, 6, 0]])
assert instance.query("SELECT * FROM mydb.filtered_table3") == "0\t1\n1\t0\n" assert node.query("SELECT * FROM mydb.filtered_table3") == TSV([[0, 1], [1, 0]])
copy_policy_xml('all_rows.xml') copy_policy_xml('all_rows.xml')
assert instance.query("SELECT * FROM mydb.filtered_table1") == "0\t0\n0\t1\n1\t0\n1\t1\n" assert node.query("SELECT * FROM mydb.filtered_table1") == TSV([[0, 0], [0, 1], [1, 0], [1, 1]])
assert instance.query("SELECT * FROM mydb.filtered_table2") == "0\t0\t0\t0\n0\t0\t6\t0\n1\t2\t3\t4\n4\t3\t2\t1\n" assert node.query("SELECT * FROM mydb.filtered_table2") == TSV([[0, 0, 0, 0], [0, 0, 6, 0], [1, 2, 3, 4], [4, 3, 2, 1]])
assert instance.query("SELECT * FROM mydb.filtered_table3") == "0\t0\n0\t1\n1\t0\n1\t1\n" assert node.query("SELECT * FROM mydb.filtered_table3") == TSV([[0, 0], [0, 1], [1, 0], [1, 1]])
copy_policy_xml('no_rows.xml') copy_policy_xml('no_rows.xml')
assert instance.query("SELECT * FROM mydb.filtered_table1") == "" assert node.query("SELECT * FROM mydb.filtered_table1") == ""
assert instance.query("SELECT * FROM mydb.filtered_table2") == "" assert node.query("SELECT * FROM mydb.filtered_table2") == ""
assert instance.query("SELECT * FROM mydb.filtered_table3") == "" assert node.query("SELECT * FROM mydb.filtered_table3") == ""
copy_policy_xml('normal_filters.xml') copy_policy_xml('normal_filters.xml')
assert instance.query("SELECT * FROM mydb.filtered_table1") == "1\t0\n1\t1\n" assert node.query("SELECT * FROM mydb.filtered_table1") == TSV([[1, 0], [1, 1]])
assert instance.query("SELECT * FROM mydb.filtered_table2") == "0\t0\t0\t0\n0\t0\t6\t0\n" assert node.query("SELECT * FROM mydb.filtered_table2") == TSV([[0, 0, 0, 0], [0, 0, 6, 0]])
assert instance.query("SELECT * FROM mydb.filtered_table3") == "0\t1\n1\t0\n" assert node.query("SELECT * FROM mydb.filtered_table3") == TSV([[0, 1], [1, 0]])
copy_policy_xml('no_filters.xml') copy_policy_xml('no_filters.xml')
assert instance.query("SELECT * FROM mydb.filtered_table1") == "0\t0\n0\t1\n1\t0\n1\t1\n" assert node.query("SELECT * FROM mydb.filtered_table1") == TSV([[0, 0], [0, 1], [1, 0], [1, 1]])
assert instance.query("SELECT * FROM mydb.filtered_table2") == "0\t0\t0\t0\n0\t0\t6\t0\n1\t2\t3\t4\n4\t3\t2\t1\n" assert node.query("SELECT * FROM mydb.filtered_table2") == TSV([[0, 0, 0, 0], [0, 0, 6, 0], [1, 2, 3, 4], [4, 3, 2, 1]])
assert instance.query("SELECT * FROM mydb.filtered_table3") == "0\t0\n0\t1\n1\t0\n1\t1\n" assert node.query("SELECT * FROM mydb.filtered_table3") == TSV([[0, 0], [0, 1], [1, 0], [1, 1]])
copy_policy_xml('normal_filters.xml') copy_policy_xml('normal_filters.xml')
assert instance.query("SELECT * FROM mydb.filtered_table1") == "1\t0\n1\t1\n" assert node.query("SELECT * FROM mydb.filtered_table1") == TSV([[1, 0], [1, 1]])
assert instance.query("SELECT * FROM mydb.filtered_table2") == "0\t0\t0\t0\n0\t0\t6\t0\n" assert node.query("SELECT * FROM mydb.filtered_table2") == TSV([[0, 0, 0, 0], [0, 0, 6, 0]])
assert instance.query("SELECT * FROM mydb.filtered_table3") == "0\t1\n1\t0\n" assert node.query("SELECT * FROM mydb.filtered_table3") == TSV([[0, 1], [1, 0]])
def test_reload_users_xml_by_timer(): def test_reload_users_xml_by_timer():
copy_policy_xml('normal_filters.xml') copy_policy_xml('normal_filters.xml')
assert instance.query("SELECT * FROM mydb.filtered_table1") == "1\t0\n1\t1\n" assert node.query("SELECT * FROM mydb.filtered_table1") == TSV([[1, 0], [1, 1]])
assert instance.query("SELECT * FROM mydb.filtered_table2") == "0\t0\t0\t0\n0\t0\t6\t0\n" assert node.query("SELECT * FROM mydb.filtered_table2") == TSV([[0, 0, 0, 0], [0, 0, 6, 0]])
assert instance.query("SELECT * FROM mydb.filtered_table3") == "0\t1\n1\t0\n" assert node.query("SELECT * FROM mydb.filtered_table3") == TSV([[0, 1], [1, 0]])
time.sleep(1) # The modification time of the 'row_policy.xml' file should be different. time.sleep(1) # The modification time of the 'row_policy.xml' file should be different.
copy_policy_xml('all_rows.xml', False) copy_policy_xml('all_rows.xml', False)
assert_eq_with_retry(instance, "SELECT * FROM mydb.filtered_table1", "0\t0\n0\t1\n1\t0\n1\t1") assert_eq_with_retry(node, "SELECT * FROM mydb.filtered_table1", [[0, 0], [0, 1], [1, 0], [1, 1]])
assert_eq_with_retry(instance, "SELECT * FROM mydb.filtered_table2", "0\t0\t0\t0\n0\t0\t6\t0\n1\t2\t3\t4\n4\t3\t2\t1") assert_eq_with_retry(node, "SELECT * FROM mydb.filtered_table2", [[0, 0, 0, 0], [0, 0, 6, 0], [1, 2, 3, 4], [4, 3, 2, 1]])
assert_eq_with_retry(instance, "SELECT * FROM mydb.filtered_table3", "0\t0\n0\t1\n1\t0\n1\t1") assert_eq_with_retry(node, "SELECT * FROM mydb.filtered_table3", [[0, 0], [0, 1], [1, 0], [1, 1]])
time.sleep(1) # The modification time of the 'row_policy.xml' file should be different. time.sleep(1) # The modification time of the 'row_policy.xml' file should be different.
copy_policy_xml('normal_filters.xml', False) copy_policy_xml('normal_filters.xml', False)
assert_eq_with_retry(instance, "SELECT * FROM mydb.filtered_table1", "1\t0\n1\t1") assert_eq_with_retry(node, "SELECT * FROM mydb.filtered_table1", [[1, 0], [1, 1]])
assert_eq_with_retry(instance, "SELECT * FROM mydb.filtered_table2", "0\t0\t0\t0\n0\t0\t6\t0") assert_eq_with_retry(node, "SELECT * FROM mydb.filtered_table2", [[0, 0, 0, 0], [0, 0, 6, 0]])
assert_eq_with_retry(instance, "SELECT * FROM mydb.filtered_table3", "0\t1\n1\t0") assert_eq_with_retry(node, "SELECT * FROM mydb.filtered_table3", [[0, 1], [1, 0]])
def test_introspection(): def test_introspection():
assert instance.query("SELECT currentRowPolicies('mydb', 'filtered_table1')") == "['default']\n" policies = [
assert instance.query("SELECT currentRowPolicies('mydb', 'filtered_table2')") == "['default']\n" ["another ON mydb.filtered_table1", "another", "mydb", "filtered_table1", "6068883a-0e9d-f802-7e22-0144f8e66d3c", "users.xml", "1", 0, 0, "['another']", "[]"],
assert instance.query("SELECT currentRowPolicies('mydb', 'filtered_table3')") == "['default']\n" ["another ON mydb.filtered_table2", "another", "mydb", "filtered_table2", "c019e957-c60b-d54e-cc52-7c90dac5fb01", "users.xml", "1", 0, 0, "['another']", "[]"],
assert instance.query("SELECT arraySort(currentRowPolicies())") == "[('mydb','filtered_table1','default'),('mydb','filtered_table2','default'),('mydb','filtered_table3','default'),('mydb','local','default')]\n" ["another ON mydb.filtered_table3", "another", "mydb", "filtered_table3", "4cb080d0-44e8-dbef-6026-346655143628", "users.xml", "1", 0, 0, "['another']", "[]"],
["another ON mydb.local", "another", "mydb", "local", "5b23c389-7e18-06bf-a6bc-dd1afbbc0a97", "users.xml", "a = 1", 0, 0, "['another']", "[]"],
policy1 = "mydb\tfiltered_table1\tdefault\tdefault ON mydb.filtered_table1\t9e8a8f62-4965-2b5e-8599-57c7b99b3549\tusers.xml\t0\ta = 1\t\t\t\t\n" ["default ON mydb.filtered_table1", "default", "mydb", "filtered_table1", "9e8a8f62-4965-2b5e-8599-57c7b99b3549", "users.xml", "a = 1", 0, 0, "['default']", "[]"],
policy2 = "mydb\tfiltered_table2\tdefault\tdefault ON mydb.filtered_table2\tcffae79d-b9bf-a2ef-b798-019c18470b25\tusers.xml\t0\ta + b < 1 or c - d > 5\t\t\t\t\n" ["default ON mydb.filtered_table2", "default", "mydb", "filtered_table2", "cffae79d-b9bf-a2ef-b798-019c18470b25", "users.xml", "a + b < 1 or c - d > 5", 0, 0, "['default']", "[]"],
policy3 = "mydb\tfiltered_table3\tdefault\tdefault ON mydb.filtered_table3\t12fc5cef-e3da-3940-ec79-d8be3911f42b\tusers.xml\t0\tc = 1\t\t\t\t\n" ["default ON mydb.filtered_table3", "default", "mydb", "filtered_table3", "12fc5cef-e3da-3940-ec79-d8be3911f42b", "users.xml", "c = 1", 0, 0, "['default']", "[]"],
policy4 = "mydb\tlocal\tdefault\tdefault ON mydb.local\tcdacaeb5-1d97-f99d-2bb0-4574f290629c\tusers.xml\t0\t1\t\t\t\t\n" ["default ON mydb.local", "default", "mydb", "local", "cdacaeb5-1d97-f99d-2bb0-4574f290629c", "users.xml", "1", 0, 0, "['default']", "[]"]
assert instance.query("SELECT * from system.row_policies WHERE has(currentRowPolicyIDs('mydb', 'filtered_table1'), id) ORDER BY table, name") == policy1 ]
assert instance.query("SELECT * from system.row_policies WHERE has(currentRowPolicyIDs('mydb', 'filtered_table2'), id) ORDER BY table, name") == policy2 assert node.query("SELECT * from system.row_policies ORDER BY short_name, database, table") == TSV(policies)
assert instance.query("SELECT * from system.row_policies WHERE has(currentRowPolicyIDs('mydb', 'filtered_table3'), id) ORDER BY table, name") == policy3
assert instance.query("SELECT * from system.row_policies WHERE has(currentRowPolicyIDs('mydb', 'local'), id) ORDER BY table, name") == policy4
assert instance.query("SELECT * from system.row_policies WHERE has(currentRowPolicyIDs(), id) ORDER BY table, name") == policy1 + policy2 + policy3 + policy4
def test_dcl_introspection(): def test_dcl_introspection():
assert instance.query("SHOW POLICIES ON mydb.filtered_table1") == "another\ndefault\n" assert node.query("SHOW POLICIES") == TSV(["another ON mydb.filtered_table1", "another ON mydb.filtered_table2", "another ON mydb.filtered_table3", "another ON mydb.local", "default ON mydb.filtered_table1", "default ON mydb.filtered_table2", "default ON mydb.filtered_table3", "default ON mydb.local"])
assert instance.query("SHOW POLICIES CURRENT ON mydb.filtered_table2") == "default\n" assert node.query("SHOW POLICIES ON mydb.filtered_table1") == TSV(["another", "default"])
assert instance.query("SHOW POLICIES") == "another ON mydb.filtered_table1\nanother ON mydb.filtered_table2\nanother ON mydb.filtered_table3\nanother ON mydb.local\ndefault ON mydb.filtered_table1\ndefault ON mydb.filtered_table2\ndefault ON mydb.filtered_table3\ndefault ON mydb.local\n" assert node.query("SHOW POLICIES ON mydb.local") == TSV(["another", "default"])
assert instance.query("SHOW POLICIES CURRENT") == "default ON mydb.filtered_table1\ndefault ON mydb.filtered_table2\ndefault ON mydb.filtered_table3\ndefault ON mydb.local\n"
assert instance.query("SHOW CREATE POLICY default ON mydb.filtered_table1") == "CREATE ROW POLICY default ON mydb.filtered_table1 FOR SELECT USING a = 1 TO default\n" assert node.query("SHOW CREATE POLICY default ON mydb.filtered_table1") == "CREATE ROW POLICY default ON mydb.filtered_table1 FOR SELECT USING a = 1 TO default\n"
assert instance.query("SHOW CREATE POLICY default ON mydb.filtered_table2") == "CREATE ROW POLICY default ON mydb.filtered_table2 FOR SELECT USING ((a + b) < 1) OR ((c - d) > 5) TO default\n" assert node.query("SHOW CREATE POLICY default ON mydb.filtered_table2") == "CREATE ROW POLICY default ON mydb.filtered_table2 FOR SELECT USING ((a + b) < 1) OR ((c - d) > 5) TO default\n"
assert instance.query("SHOW CREATE POLICY default ON mydb.filtered_table3") == "CREATE ROW POLICY default ON mydb.filtered_table3 FOR SELECT USING c = 1 TO default\n" assert node.query("SHOW CREATE POLICY default ON mydb.filtered_table3") == "CREATE ROW POLICY default ON mydb.filtered_table3 FOR SELECT USING c = 1 TO default\n"
assert instance.query("SHOW CREATE POLICY default ON mydb.local") == "CREATE ROW POLICY default ON mydb.local FOR SELECT USING 1 TO default\n" assert node.query("SHOW CREATE POLICY default ON mydb.local") == "CREATE ROW POLICY default ON mydb.local FOR SELECT USING 1 TO default\n"
copy_policy_xml('all_rows.xml') copy_policy_xml('all_rows.xml')
assert instance.query("SHOW POLICIES CURRENT") == "default ON mydb.filtered_table1\ndefault ON mydb.filtered_table2\ndefault ON mydb.filtered_table3\n" assert node.query("SHOW POLICIES") == TSV(["another ON mydb.filtered_table1", "another ON mydb.filtered_table2", "another ON mydb.filtered_table3", "default ON mydb.filtered_table1", "default ON mydb.filtered_table2", "default ON mydb.filtered_table3"])
assert instance.query("SHOW CREATE POLICY default ON mydb.filtered_table1") == "CREATE ROW POLICY default ON mydb.filtered_table1 FOR SELECT USING 1 TO default\n" assert node.query("SHOW CREATE POLICY default ON mydb.filtered_table1") == "CREATE ROW POLICY default ON mydb.filtered_table1 FOR SELECT USING 1 TO default\n"
assert instance.query("SHOW CREATE POLICY default ON mydb.filtered_table2") == "CREATE ROW POLICY default ON mydb.filtered_table2 FOR SELECT USING 1 TO default\n" assert node.query("SHOW CREATE POLICY default ON mydb.filtered_table2") == "CREATE ROW POLICY default ON mydb.filtered_table2 FOR SELECT USING 1 TO default\n"
assert instance.query("SHOW CREATE POLICY default ON mydb.filtered_table3") == "CREATE ROW POLICY default ON mydb.filtered_table3 FOR SELECT USING 1 TO default\n" assert node.query("SHOW CREATE POLICY default ON mydb.filtered_table3") == "CREATE ROW POLICY default ON mydb.filtered_table3 FOR SELECT USING 1 TO default\n"
copy_policy_xml('no_rows.xml') copy_policy_xml('no_rows.xml')
assert instance.query("SHOW POLICIES CURRENT") == "default ON mydb.filtered_table1\ndefault ON mydb.filtered_table2\ndefault ON mydb.filtered_table3\n" assert node.query("SHOW POLICIES") == TSV(["another ON mydb.filtered_table1", "another ON mydb.filtered_table2", "another ON mydb.filtered_table3", "default ON mydb.filtered_table1", "default ON mydb.filtered_table2", "default ON mydb.filtered_table3"])
assert instance.query("SHOW CREATE POLICY default ON mydb.filtered_table1") == "CREATE ROW POLICY default ON mydb.filtered_table1 FOR SELECT USING NULL TO default\n" assert node.query("SHOW CREATE POLICY default ON mydb.filtered_table1") == "CREATE ROW POLICY default ON mydb.filtered_table1 FOR SELECT USING NULL TO default\n"
assert instance.query("SHOW CREATE POLICY default ON mydb.filtered_table2") == "CREATE ROW POLICY default ON mydb.filtered_table2 FOR SELECT USING NULL TO default\n" assert node.query("SHOW CREATE POLICY default ON mydb.filtered_table2") == "CREATE ROW POLICY default ON mydb.filtered_table2 FOR SELECT USING NULL TO default\n"
assert instance.query("SHOW CREATE POLICY default ON mydb.filtered_table3") == "CREATE ROW POLICY default ON mydb.filtered_table3 FOR SELECT USING NULL TO default\n" assert node.query("SHOW CREATE POLICY default ON mydb.filtered_table3") == "CREATE ROW POLICY default ON mydb.filtered_table3 FOR SELECT USING NULL TO default\n"
copy_policy_xml('no_filters.xml') copy_policy_xml('no_filters.xml')
assert instance.query("SHOW POLICIES") == "" assert node.query("SHOW POLICIES") == ""
def test_dcl_management(): def test_dcl_management():
copy_policy_xml('no_filters.xml') copy_policy_xml('no_filters.xml')
assert instance.query("SHOW POLICIES") == "" assert node.query("SHOW POLICIES") == ""
instance.query("CREATE POLICY pA ON mydb.filtered_table1 FOR SELECT USING a<b") node.query("CREATE POLICY pA ON mydb.filtered_table1 FOR SELECT USING a<b")
assert instance.query("SELECT * FROM mydb.filtered_table1") == "" assert node.query("SELECT * FROM mydb.filtered_table1") == ""
assert instance.query("SHOW POLICIES CURRENT ON mydb.filtered_table1") == "" assert node.query("SHOW POLICIES ON mydb.filtered_table1") == "pA\n"
assert instance.query("SHOW POLICIES ON mydb.filtered_table1") == "pA\n"
instance.query("ALTER POLICY pA ON mydb.filtered_table1 TO default") node.query("ALTER POLICY pA ON mydb.filtered_table1 TO default")
assert instance.query("SELECT * FROM mydb.filtered_table1") == "0\t1\n" assert node.query("SELECT * FROM mydb.filtered_table1") == TSV([[0, 1]])
assert instance.query("SHOW POLICIES CURRENT ON mydb.filtered_table1") == "pA\n" assert node.query("SHOW POLICIES ON mydb.filtered_table1") == "pA\n"
instance.query("ALTER POLICY pA ON mydb.filtered_table1 FOR SELECT USING a>b") node.query("ALTER POLICY pA ON mydb.filtered_table1 FOR SELECT USING a>b")
assert instance.query("SELECT * FROM mydb.filtered_table1") == "1\t0\n" assert node.query("SELECT * FROM mydb.filtered_table1") == TSV([[1, 0]])
instance.query("ALTER POLICY pA ON mydb.filtered_table1 RENAME TO pB") node.query("ALTER POLICY pA ON mydb.filtered_table1 RENAME TO pB")
assert instance.query("SELECT * FROM mydb.filtered_table1") == "1\t0\n" assert node.query("SELECT * FROM mydb.filtered_table1") == TSV([[1, 0]])
assert instance.query("SHOW POLICIES CURRENT ON mydb.filtered_table1") == "pB\n" assert node.query("SHOW POLICIES ON mydb.filtered_table1") == "pB\n"
assert instance.query("SHOW CREATE POLICY pB ON mydb.filtered_table1") == "CREATE ROW POLICY pB ON mydb.filtered_table1 FOR SELECT USING a > b TO default\n" assert node.query("SHOW CREATE POLICY pB ON mydb.filtered_table1") == "CREATE ROW POLICY pB ON mydb.filtered_table1 FOR SELECT USING a > b TO default\n"
instance.query("DROP POLICY pB ON mydb.filtered_table1") node.query("DROP POLICY pB ON mydb.filtered_table1")
assert instance.query("SELECT * FROM mydb.filtered_table1") == "0\t0\n0\t1\n1\t0\n1\t1\n" assert node.query("SELECT * FROM mydb.filtered_table1") == TSV([[0, 0], [0, 1], [1, 0], [1, 1]])
assert instance.query("SHOW POLICIES") == "" assert node.query("SHOW POLICIES") == ""
def test_users_xml_is_readonly(): def test_users_xml_is_readonly():
assert re.search("storage is readonly", instance.query_and_get_error("DROP POLICY default ON mydb.filtered_table1")) assert re.search("storage is readonly", node.query_and_get_error("DROP POLICY default ON mydb.filtered_table1"))
def test_miscellaneous_engines(): def test_miscellaneous_engines():
copy_policy_xml('normal_filters.xml') copy_policy_xml('normal_filters.xml')
# ReplicatedMergeTree # ReplicatedMergeTree
instance.query("DROP TABLE mydb.filtered_table1") node.query("DROP TABLE mydb.filtered_table1")
instance.query("CREATE TABLE mydb.filtered_table1 (a UInt8, b UInt8) ENGINE ReplicatedMergeTree('/clickhouse/tables/00-00/filtered_table1', 'replica1') ORDER BY a") node.query("CREATE TABLE mydb.filtered_table1 (a UInt8, b UInt8) ENGINE ReplicatedMergeTree('/clickhouse/tables/00-00/filtered_table1', 'replica1') ORDER BY a")
instance.query("INSERT INTO mydb.filtered_table1 values (0, 0), (0, 1), (1, 0), (1, 1)") node.query("INSERT INTO mydb.filtered_table1 values (0, 0), (0, 1), (1, 0), (1, 1)")
assert instance.query("SELECT * FROM mydb.filtered_table1") == "1\t0\n1\t1\n" assert node.query("SELECT * FROM mydb.filtered_table1") == TSV([[1, 0], [1, 1]])
# CollapsingMergeTree # CollapsingMergeTree
instance.query("DROP TABLE mydb.filtered_table1") node.query("DROP TABLE mydb.filtered_table1")
instance.query("CREATE TABLE mydb.filtered_table1 (a UInt8, b Int8) ENGINE CollapsingMergeTree(b) ORDER BY a") node.query("CREATE TABLE mydb.filtered_table1 (a UInt8, b Int8) ENGINE CollapsingMergeTree(b) ORDER BY a")
instance.query("INSERT INTO mydb.filtered_table1 values (0, 1), (0, 1), (1, 1), (1, 1)") node.query("INSERT INTO mydb.filtered_table1 values (0, 1), (0, 1), (1, 1), (1, 1)")
assert instance.query("SELECT * FROM mydb.filtered_table1") == "1\t1\n1\t1\n" assert node.query("SELECT * FROM mydb.filtered_table1") == TSV([[1, 1], [1, 1]])
# ReplicatedCollapsingMergeTree # ReplicatedCollapsingMergeTree
instance.query("DROP TABLE mydb.filtered_table1") node.query("DROP TABLE mydb.filtered_table1")
instance.query("CREATE TABLE mydb.filtered_table1 (a UInt8, b Int8) ENGINE ReplicatedCollapsingMergeTree('/clickhouse/tables/00-00/filtered_table1', 'replica1', b) ORDER BY a") node.query("CREATE TABLE mydb.filtered_table1 (a UInt8, b Int8) ENGINE ReplicatedCollapsingMergeTree('/clickhouse/tables/00-00/filtered_table1', 'replica1', b) ORDER BY a")
instance.query("INSERT INTO mydb.filtered_table1 values (0, 1), (0, 1), (1, 1), (1, 1)") node.query("INSERT INTO mydb.filtered_table1 values (0, 1), (0, 1), (1, 1), (1, 1)")
assert instance.query("SELECT * FROM mydb.filtered_table1") == "1\t1\n1\t1\n" assert node.query("SELECT * FROM mydb.filtered_table1") == TSV([[1, 1], [1, 1]])
# DistributedMergeTree # DistributedMergeTree
instance.query("DROP TABLE IF EXISTS mydb.not_filtered_table") node.query("DROP TABLE IF EXISTS mydb.not_filtered_table")
instance.query("CREATE TABLE mydb.not_filtered_table (a UInt8, b UInt8) ENGINE Distributed('test_local_cluster', mydb, local)") node.query("CREATE TABLE mydb.not_filtered_table (a UInt8, b UInt8) ENGINE Distributed('test_local_cluster', mydb, local)")
instance.query("CREATE TABLE mydb.local (a UInt8, b UInt8) ENGINE MergeTree ORDER BY a") node.query("CREATE TABLE mydb.local (a UInt8, b UInt8) ENGINE MergeTree ORDER BY a")
instance2.query("CREATE TABLE mydb.local (a UInt8, b UInt8) ENGINE MergeTree ORDER BY a") node2.query("CREATE TABLE mydb.local (a UInt8, b UInt8) ENGINE MergeTree ORDER BY a")
instance.query("INSERT INTO mydb.local values (2, 0), (2, 1), (1, 0), (1, 1)") node.query("INSERT INTO mydb.local values (2, 0), (2, 1), (1, 0), (1, 1)")
instance2.query("INSERT INTO mydb.local values (3, 0), (3, 1), (1, 0), (1, 1)") node2.query("INSERT INTO mydb.local values (3, 0), (3, 1), (1, 0), (1, 1)")
assert instance.query("SELECT * FROM mydb.not_filtered_table", user="another") == "1\t0\n1\t1\n1\t0\n1\t1\n" assert node.query("SELECT * FROM mydb.not_filtered_table", user="another") == TSV([[1, 0], [1, 1], [1, 0], [1, 1]])
assert instance.query("SELECT sum(a), b FROM mydb.not_filtered_table GROUP BY b ORDER BY b", user="another") == "2\t0\n2\t1\n" assert node.query("SELECT sum(a), b FROM mydb.not_filtered_table GROUP BY b ORDER BY b", user="another") == TSV([[2, 0], [2, 1]])