Support oauth_server_uri

This commit is contained in:
kssenii 2024-12-04 19:04:43 +01:00
parent eb9d5705a9
commit c434a3f87a
4 changed files with 41 additions and 10 deletions

View File

@ -36,6 +36,7 @@ namespace DatabaseIcebergSetting
extern const DatabaseIcebergSettingsString auth_header; extern const DatabaseIcebergSettingsString auth_header;
extern const DatabaseIcebergSettingsString auth_scope; extern const DatabaseIcebergSettingsString auth_scope;
extern const DatabaseIcebergSettingsString storage_endpoint; extern const DatabaseIcebergSettingsString storage_endpoint;
extern const DatabaseIcebergSettingsString oauth_server_uri;
extern const DatabaseIcebergSettingsBool vended_credentials; extern const DatabaseIcebergSettingsBool vended_credentials;
} }
namespace Setting namespace Setting
@ -119,6 +120,7 @@ std::shared_ptr<Iceberg::ICatalog> DatabaseIceberg::getCatalog(ContextPtr) const
settings[DatabaseIcebergSetting::catalog_credential].value, settings[DatabaseIcebergSetting::catalog_credential].value,
settings[DatabaseIcebergSetting::auth_scope].value, settings[DatabaseIcebergSetting::auth_scope].value,
settings[DatabaseIcebergSetting::auth_header], settings[DatabaseIcebergSetting::auth_header],
settings[DatabaseIcebergSetting::oauth_server_uri].value,
Context::getGlobalContextInstance()); Context::getGlobalContextInstance());
} }
} }
@ -364,7 +366,7 @@ void registerDatabaseIceberg(DatabaseFactory & factory)
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Engine `{}` must have arguments", database_engine_name); throw Exception(ErrorCodes::BAD_ARGUMENTS, "Engine `{}` must have arguments", database_engine_name);
ASTs & engine_args = function_define->arguments->children; ASTs & engine_args = function_define->arguments->children;
if (engine_args.size() < 1) if (engine_args.empty())
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Engine `{}` must have arguments", database_engine_name); throw Exception(ErrorCodes::BAD_ARGUMENTS, "Engine `{}` must have arguments", database_engine_name);
for (auto & engine_arg : engine_args) for (auto & engine_arg : engine_args)

View File

@ -20,6 +20,7 @@ namespace ErrorCodes
DECLARE(String, catalog_credential, "", "", 0) \ DECLARE(String, catalog_credential, "", "", 0) \
DECLARE(Bool, vended_credentials, true, "Use vended credentials (storage credentials) from catalog", 0) \ DECLARE(Bool, vended_credentials, true, "Use vended credentials (storage credentials) from catalog", 0) \
DECLARE(String, auth_scope, "PRINCIPAL_ROLE:ALL", "Authorization scope for client credentials or token exchange", 0) \ DECLARE(String, auth_scope, "PRINCIPAL_ROLE:ALL", "Authorization scope for client credentials or token exchange", 0) \
DECLARE(String, oauth_server_uri, "", "OAuth server uri", 0) \
DECLARE(String, warehouse, "", "Warehouse name inside the catalog", 0) \ DECLARE(String, warehouse, "", "Warehouse name inside the catalog", 0) \
DECLARE(String, auth_header, "", "Authorization header of format 'Authorization: <scheme> <auth_info>'", 0) \ DECLARE(String, auth_header, "", "Authorization header of format 'Authorization: <scheme> <auth_info>'", 0) \
DECLARE(String, storage_endpoint, "", "Object storage endpoint", 0) \ DECLARE(String, storage_endpoint, "", "Object storage endpoint", 0) \

View File

@ -101,6 +101,13 @@ StorageType parseStorageTypeFromLocation(const std::string & location)
return *storage_type; return *storage_type;
} }
std::string correctAPIURI(const std::string & uri)
{
if (uri.ends_with("v1"))
return uri;
return std::filesystem::path(uri) / "v1";
}
} }
std::string RestCatalog::Config::toString() const std::string RestCatalog::Config::toString() const
@ -122,12 +129,14 @@ RestCatalog::RestCatalog(
const std::string & catalog_credential_, const std::string & catalog_credential_,
const std::string & auth_scope_, const std::string & auth_scope_,
const std::string & auth_header_, const std::string & auth_header_,
const std::string & oauth_server_uri_,
DB::ContextPtr context_) DB::ContextPtr context_)
: ICatalog(warehouse_) : ICatalog(warehouse_)
, DB::WithContext(context_) , DB::WithContext(context_)
, base_url(base_url_) , base_url(correctAPIURI(base_url_))
, log(getLogger("RestCatalog(" + warehouse_ + ")")) , log(getLogger("RestCatalog(" + warehouse_ + ")"))
, auth_scope(auth_scope_) , auth_scope(auth_scope_)
, oauth_server_uri(oauth_server_uri_)
{ {
if (!catalog_credential_.empty()) if (!catalog_credential_.empty())
{ {
@ -217,7 +226,12 @@ std::string RestCatalog::retrieveAccessToken() const
headers.emplace_back("Content-Type", "application/x-www-form-urlencoded"); headers.emplace_back("Content-Type", "application/x-www-form-urlencoded");
headers.emplace_back("Accepts", "application/json; charset=UTF-8"); headers.emplace_back("Accepts", "application/json; charset=UTF-8");
Poco::URI url(base_url / oauth_tokens_endpoint); Poco::URI url;
DB::ReadWriteBufferFromHTTP::OutStreamCallback out_stream_callback;
if (oauth_server_uri.empty())
{
url = Poco::URI(base_url / oauth_tokens_endpoint);
Poco::URI::QueryParameters params = { Poco::URI::QueryParameters params = {
{"grant_type", "client_credentials"}, {"grant_type", "client_credentials"},
{"scope", auth_scope}, {"scope", auth_scope},
@ -225,6 +239,17 @@ std::string RestCatalog::retrieveAccessToken() const
{"client_secret", client_secret}, {"client_secret", client_secret},
}; };
url.setQueryParameters(params); url.setQueryParameters(params);
}
else
{
url = Poco::URI(oauth_server_uri);
out_stream_callback = [&](std::ostream & os)
{
os << fmt::format(
"grant_type=client_credentials&scope={}&client_id={}&client_secret={}",
auth_scope, client_id, client_secret);
};
}
const auto & context = getContext(); const auto & context = getContext();
auto wb = DB::BuilderRWBufferFromHTTP(url) auto wb = DB::BuilderRWBufferFromHTTP(url)
@ -233,6 +258,7 @@ std::string RestCatalog::retrieveAccessToken() const
.withSettings(context->getReadSettings()) .withSettings(context->getReadSettings())
.withTimeouts(DB::ConnectionTimeouts::getHTTPTimeouts(context->getSettingsRef(), context->getServerSettings())) .withTimeouts(DB::ConnectionTimeouts::getHTTPTimeouts(context->getSettingsRef(), context->getServerSettings()))
.withHostFilter(&context->getRemoteHostFilter()) .withHostFilter(&context->getRemoteHostFilter())
.withOutCallback(std::move(out_stream_callback))
.withSkipNotFound(false) .withSkipNotFound(false)
.withHeaders(headers) .withHeaders(headers)
.create(credentials); .create(credentials);

View File

@ -24,6 +24,7 @@ public:
const std::string & catalog_credential_, const std::string & catalog_credential_,
const std::string & auth_scope_, const std::string & auth_scope_,
const std::string & auth_header_, const std::string & auth_header_,
const std::string & oauth_server_uri_,
DB::ContextPtr context_); DB::ContextPtr context_);
~RestCatalog() override = default; ~RestCatalog() override = default;
@ -73,6 +74,7 @@ private:
std::string client_id; std::string client_id;
std::string client_secret; std::string client_secret;
std::string auth_scope; std::string auth_scope;
std::string oauth_server_uri;
mutable std::optional<std::string> access_token; mutable std::optional<std::string> access_token;
Poco::Net::HTTPBasicCredentials credentials{}; Poco::Net::HTTPBasicCredentials credentials{};