From c434a3f87af75e131bc0bb6e21967f1640ffd159 Mon Sep 17 00:00:00 2001 From: kssenii Date: Wed, 4 Dec 2024 19:04:43 +0100 Subject: [PATCH] Support oauth_server_uri --- src/Databases/Iceberg/DatabaseIceberg.cpp | 4 +- .../Iceberg/DatabaseIcebergSettings.cpp | 1 + src/Databases/Iceberg/RestCatalog.cpp | 44 +++++++++++++++---- src/Databases/Iceberg/RestCatalog.h | 2 + 4 files changed, 41 insertions(+), 10 deletions(-) diff --git a/src/Databases/Iceberg/DatabaseIceberg.cpp b/src/Databases/Iceberg/DatabaseIceberg.cpp index dd559ba8b92..8544239d6b1 100644 --- a/src/Databases/Iceberg/DatabaseIceberg.cpp +++ b/src/Databases/Iceberg/DatabaseIceberg.cpp @@ -36,6 +36,7 @@ namespace DatabaseIcebergSetting extern const DatabaseIcebergSettingsString auth_header; extern const DatabaseIcebergSettingsString auth_scope; extern const DatabaseIcebergSettingsString storage_endpoint; + extern const DatabaseIcebergSettingsString oauth_server_uri; extern const DatabaseIcebergSettingsBool vended_credentials; } namespace Setting @@ -119,6 +120,7 @@ std::shared_ptr DatabaseIceberg::getCatalog(ContextPtr) const settings[DatabaseIcebergSetting::catalog_credential].value, settings[DatabaseIcebergSetting::auth_scope].value, settings[DatabaseIcebergSetting::auth_header], + settings[DatabaseIcebergSetting::oauth_server_uri].value, Context::getGlobalContextInstance()); } } @@ -364,7 +366,7 @@ void registerDatabaseIceberg(DatabaseFactory & factory) throw Exception(ErrorCodes::BAD_ARGUMENTS, "Engine `{}` must have arguments", database_engine_name); ASTs & engine_args = function_define->arguments->children; - if (engine_args.size() < 1) + if (engine_args.empty()) throw Exception(ErrorCodes::BAD_ARGUMENTS, "Engine `{}` must have arguments", database_engine_name); for (auto & engine_arg : engine_args) diff --git a/src/Databases/Iceberg/DatabaseIcebergSettings.cpp b/src/Databases/Iceberg/DatabaseIcebergSettings.cpp index de04fc0bd11..33374edbb6d 100644 --- a/src/Databases/Iceberg/DatabaseIcebergSettings.cpp +++ b/src/Databases/Iceberg/DatabaseIcebergSettings.cpp @@ -20,6 +20,7 @@ namespace ErrorCodes DECLARE(String, catalog_credential, "", "", 0) \ DECLARE(Bool, vended_credentials, true, "Use vended credentials (storage credentials) from catalog", 0) \ DECLARE(String, auth_scope, "PRINCIPAL_ROLE:ALL", "Authorization scope for client credentials or token exchange", 0) \ + DECLARE(String, oauth_server_uri, "", "OAuth server uri", 0) \ DECLARE(String, warehouse, "", "Warehouse name inside the catalog", 0) \ DECLARE(String, auth_header, "", "Authorization header of format 'Authorization: '", 0) \ DECLARE(String, storage_endpoint, "", "Object storage endpoint", 0) \ diff --git a/src/Databases/Iceberg/RestCatalog.cpp b/src/Databases/Iceberg/RestCatalog.cpp index 262ac53e99f..b138cc1c14b 100644 --- a/src/Databases/Iceberg/RestCatalog.cpp +++ b/src/Databases/Iceberg/RestCatalog.cpp @@ -101,6 +101,13 @@ StorageType parseStorageTypeFromLocation(const std::string & location) return *storage_type; } +std::string correctAPIURI(const std::string & uri) +{ + if (uri.ends_with("v1")) + return uri; + return std::filesystem::path(uri) / "v1"; +} + } std::string RestCatalog::Config::toString() const @@ -122,12 +129,14 @@ RestCatalog::RestCatalog( const std::string & catalog_credential_, const std::string & auth_scope_, const std::string & auth_header_, + const std::string & oauth_server_uri_, DB::ContextPtr context_) : ICatalog(warehouse_) , DB::WithContext(context_) - , base_url(base_url_) + , base_url(correctAPIURI(base_url_)) , log(getLogger("RestCatalog(" + warehouse_ + ")")) , auth_scope(auth_scope_) + , oauth_server_uri(oauth_server_uri_) { if (!catalog_credential_.empty()) { @@ -217,14 +226,30 @@ std::string RestCatalog::retrieveAccessToken() const headers.emplace_back("Content-Type", "application/x-www-form-urlencoded"); headers.emplace_back("Accepts", "application/json; charset=UTF-8"); - Poco::URI url(base_url / oauth_tokens_endpoint); - Poco::URI::QueryParameters params = { - {"grant_type", "client_credentials"}, - {"scope", auth_scope}, - {"client_id", client_id}, - {"client_secret", client_secret}, - }; - url.setQueryParameters(params); + Poco::URI url; + DB::ReadWriteBufferFromHTTP::OutStreamCallback out_stream_callback; + if (oauth_server_uri.empty()) + { + url = Poco::URI(base_url / oauth_tokens_endpoint); + + Poco::URI::QueryParameters params = { + {"grant_type", "client_credentials"}, + {"scope", auth_scope}, + {"client_id", client_id}, + {"client_secret", client_secret}, + }; + url.setQueryParameters(params); + } + else + { + url = Poco::URI(oauth_server_uri); + out_stream_callback = [&](std::ostream & os) + { + os << fmt::format( + "grant_type=client_credentials&scope={}&client_id={}&client_secret={}", + auth_scope, client_id, client_secret); + }; + } const auto & context = getContext(); auto wb = DB::BuilderRWBufferFromHTTP(url) @@ -233,6 +258,7 @@ std::string RestCatalog::retrieveAccessToken() const .withSettings(context->getReadSettings()) .withTimeouts(DB::ConnectionTimeouts::getHTTPTimeouts(context->getSettingsRef(), context->getServerSettings())) .withHostFilter(&context->getRemoteHostFilter()) + .withOutCallback(std::move(out_stream_callback)) .withSkipNotFound(false) .withHeaders(headers) .create(credentials); diff --git a/src/Databases/Iceberg/RestCatalog.h b/src/Databases/Iceberg/RestCatalog.h index 4505e020580..aab8be6ed8d 100644 --- a/src/Databases/Iceberg/RestCatalog.h +++ b/src/Databases/Iceberg/RestCatalog.h @@ -24,6 +24,7 @@ public: const std::string & catalog_credential_, const std::string & auth_scope_, const std::string & auth_header_, + const std::string & oauth_server_uri_, DB::ContextPtr context_); ~RestCatalog() override = default; @@ -73,6 +74,7 @@ private: std::string client_id; std::string client_secret; std::string auth_scope; + std::string oauth_server_uri; mutable std::optional access_token; Poco::Net::HTTPBasicCredentials credentials{};