diff --git a/src/Common/ZooKeeper/IKeeper.cpp b/src/Common/ZooKeeper/IKeeper.cpp index 94fd291bd12..4f0c5efe680 100644 --- a/src/Common/ZooKeeper/IKeeper.cpp +++ b/src/Common/ZooKeeper/IKeeper.cpp @@ -142,6 +142,8 @@ void GetRequest::addRootPath(const String & root_path) { Coordination::addRootPa void SetRequest::addRootPath(const String & root_path) { Coordination::addRootPath(path, root_path); } void ListRequest::addRootPath(const String & root_path) { Coordination::addRootPath(path, root_path); } void CheckRequest::addRootPath(const String & root_path) { Coordination::addRootPath(path, root_path); } +void SetACLRequest::addRootPath(const String & root_path) { Coordination::addRootPath(path, root_path); } +void GetACLRequest::addRootPath(const String & root_path) { Coordination::addRootPath(path, root_path); } void MultiRequest::addRootPath(const String & root_path) { diff --git a/src/Common/ZooKeeper/IKeeper.h b/src/Common/ZooKeeper/IKeeper.h index 2d947bb402c..5e11687eab5 100644 --- a/src/Common/ZooKeeper/IKeeper.h +++ b/src/Common/ZooKeeper/IKeeper.h @@ -148,6 +148,40 @@ struct WatchResponse : virtual Response using WatchCallback = std::function; +struct SetACLRequest : virtual Request +{ + String path; + ACLs acls; + int32_t version = -1; + + void addRootPath(const String & root_path) override; + String getPath() const override { return path; } + size_t bytesSize() const override { return path.size() + sizeof(version) + acls.size() * sizeof(ACL); } +}; + +struct SetACLResponse : virtual Response +{ + Stat stat; + + size_t bytesSize() const override { return sizeof(Stat); } +}; + +struct GetACLRequest : virtual Request +{ + String path; + + void addRootPath(const String & root_path) override; + String getPath() const override { return path; } + size_t bytesSize() const override { return path.size(); } +}; + +struct GetACLResponse : virtual Response +{ + ACLs acl; + Stat stat; + size_t bytesSize() const override { return sizeof(Stat) + acl.size() * sizeof(ACL); } +}; + struct CreateRequest : virtual Request { String path; diff --git a/src/Coordination/KeeperStorage.cpp b/src/Coordination/KeeperStorage.cpp index c77e989b580..2a01d60f227 100644 --- a/src/Coordination/KeeperStorage.cpp +++ b/src/Coordination/KeeperStorage.cpp @@ -8,6 +8,8 @@ #include #include #include +#include +#include namespace DB { @@ -32,6 +34,17 @@ static std::string getBaseName(const String & path) return std::string{&path[basename_start + 1], path.length() - basename_start - 1}; } +static String base64Encode(const String & decoded) +{ + std::ostringstream ostr; // STYLE_CHECK_ALLOW_STD_STRING_STREAM + ostr.exceptions(std::ios::failbit); + Poco::Base64Encoder encoder(ostr); + encoder.rdbuf()->setLineLength(0); + encoder << decoded; + encoder.close(); + return ostr.str(); +} + static String getSHA1(const String & userdata) { Poco::SHA1Engine engine; @@ -40,18 +53,32 @@ static String getSHA1(const String & userdata) return String{digest_id.begin(), digest_id.end()}; } -static bool checkACL(int32_t permission, const Coordination::ACLs & node_acls, const std::vector & session_auths) +static String generateDigest(const String & userdata) +{ + std::vector user_password; + boost::split(user_password, userdata, [](char c) { return c == ':'; }); + return user_password[0] + base64Encode(getSHA1(user_password[1])); +} + +static bool checkACL(int32_t permission, const Coordination::ACLs & node_acls, const std::vector & session_auths) { if (node_acls.empty()) return true; - for (size_t i = 0; i < node_acls.size(); ++i) - { - if (!(node_acls[i].permissions & permission)) + for (const auto & session_auth : session_auths) + if (session_auth.scheme == "super") return true; - if (node_acls[i].id == session_auths[i]) - return true; + for (size_t i = 0; i < node_acls.size(); ++i) + { + if (node_acls[i].permissions & permission) + { + if (node_acls[i].scheme == "world" && node_acls[i].id == "anyone") + return true; + + if (node_acls[i].scheme == session_auths[i].scheme && node_acls[i].id == session_auths[i].id) + return true; + } } return false; @@ -93,6 +120,34 @@ static KeeperStorage::ResponsesForSessions processWatchesImpl(const String & pat return result; } +static bool fixupACL( + const std::vector & request_acls, + const std::vector & current_ids, + std::vector & result_acls) +{ + if (request_acls.empty()) + return false; + + for (const auto & request_acl : request_acls) + { + if (request_acl.scheme == "world" && request_acl.id == "anyone") + { + result_acls.push_back(request_acl); + } + else if (request_acl.scheme == "auth") + { + for (const auto & current_id : current_ids) + { + Coordination::ACL new_acl = request_acl; + new_acl.scheme = current_id.scheme; + new_acl.id = current_id.id; + result_acls.push_back(new_acl); + } + } + } + return !result_acls.empty(); +} + KeeperStorage::KeeperStorage(int64_t tick_time_ms) : session_expiry_queue(tick_time_ms) { @@ -193,31 +248,15 @@ struct KeeperStorageCreateRequest final : public KeeperStorageRequest else { auto & session_auth_ids = storage.session_and_auth[session_id]; - for (size_t i = 0; i < request.acls.size(); ++i) - { - if (request.acls[i].id.empty()) - { - if (!session_auth_ids[i].empty()) - request.acls[i].id = session_auth_ids[i]; - } - else - { - auto request_sha = getSHA1(request.acls[i].id); - if (session_auth_ids[i] != request_sha) /// User specified strange user:password in request - { - /// User specified strange user:password in request - response.error = Coordination::Error::ZAUTHFAILED; - return { response_ptr, {} }; - } - else - { - request.acls[i].id = request_sha; - break; - } - } - } KeeperStorage::Node created_node; + + if (!fixupACL(request.acls, session_auth_ids, created_node.acls)) + { + response.error = Coordination::Error::ZINVALIDACL; + return {response_ptr, {}}; + } + created_node.stat.czxid = zxid; created_node.stat.mzxid = zxid; created_node.stat.ctime = std::chrono::system_clock::now().time_since_epoch() / std::chrono::milliseconds(1); @@ -227,7 +266,6 @@ struct KeeperStorageCreateRequest final : public KeeperStorageRequest created_node.stat.ephemeralOwner = request.is_ephemeral ? session_id : 0; created_node.data = request.data; created_node.is_sequental = request.is_sequential; - created_node.acls = request.acls; std::string path_created = request.path; @@ -722,18 +760,35 @@ struct KeeperStorageAuthRequest final : public KeeperStorageRequest Coordination::ZooKeeperAuthRequest & auth_request = dynamic_cast(*zk_request); Coordination::ZooKeeperResponsePtr response_ptr = zk_request->makeResponse(); Coordination::ZooKeeperAuthResponse & auth_response = dynamic_cast(*response_ptr); + auto & sessions_and_auth = storage.session_and_auth; - if (auth_request.scheme != "digest") + if (auth_request.scheme == "super") + { + if (generateDigest(auth_request.data) == storage.superdigest) + { + KeeperStorage::AuthID auth{"super", ""}; + sessions_and_auth[session_id].emplace_back(auth); + } + else + { + auth_response.error = Coordination::Error::ZAUTHFAILED; + } + } + else if (auth_request.scheme == "world" && auth_request.data == "anyone") + { + KeeperStorage::AuthID auth{"world", "anyone"}; + sessions_and_auth[session_id].emplace_back(auth); + } + else if (auth_request.scheme != "digest") { auth_response.error = Coordination::Error::ZAUTHFAILED; } else { - auto & sessions_and_auth = storage.session_and_auth; - std::string id = getSHA1(auth_request.data); + KeeperStorage::AuthID auth{auth_request.scheme, generateDigest(auth_request.data)}; auto & session_ids = sessions_and_auth[session_id]; - if (std::find(session_ids.begin(), session_ids.end(), id) == session_ids.end()) - sessions_and_auth[session_id].emplace_back(id); + if (std::find(session_ids.begin(), session_ids.end(), auth) == session_ids.end()) + sessions_and_auth[session_id].emplace_back(auth); } return { response_ptr, {} }; diff --git a/src/Coordination/KeeperStorage.h b/src/Coordination/KeeperStorage.h index acd52ee4a01..f247364f611 100644 --- a/src/Coordination/KeeperStorage.h +++ b/src/Coordination/KeeperStorage.h @@ -51,6 +51,17 @@ public: Coordination::ZooKeeperRequestPtr request; }; + struct AuthID + { + std::string scheme; + std::string id; + + bool operator==(const AuthID & other) const + { + return scheme == other.scheme && id == other.id; + } + }; + using RequestsForSessions = std::vector; using Container = SnapshotableHashTable; @@ -59,13 +70,12 @@ public: using SessionIDs = std::vector; /// Just vector of SHA1 from user:password - using AuthIDs = std::vector; + using AuthIDs = std::vector; using SessionAndAuth = std::unordered_map; SessionAndAuth session_and_auth; using Watches = std::map; - Container container; Ephemerals ephemerals; SessionAndWatcher sessions_and_watchers; @@ -85,6 +95,8 @@ public: return zxid; } + const String superdigest; + public: KeeperStorage(int64_t tick_time_ms);