diff --git a/src/Functions/getHttpHeader.cpp b/src/Functions/getHttpHeader.cpp new file mode 100644 index 00000000000..ee914dd40b9 --- /dev/null +++ b/src/Functions/getHttpHeader.cpp @@ -0,0 +1,95 @@ +#include +#include +#include +#include +#include +#include +#include "Common/CurrentThread.h" +#include +#include "Interpreters/ClientInfo.h" +#include "Interpreters/Context_fwd.h" +#include +#include + + +namespace DB +{ +namespace ErrorCodes +{ + extern const int ILLEGAL_TYPE_OF_ARGUMENT; + extern const int ILLEGAL_COLUMN; + extern const int FUNCTION_NOT_ALLOWED; +} + +namespace +{ + +/** Get the value of parameter in http headers. + * If there no such parameter or the method of request is not + * http, the function will return empty string. + */ +class FunctionGetHttpHeader : public IFunction +{ +private: + +public: + FunctionGetHttpHeader() = default; + + static constexpr auto name = "getHttpHeader"; + + static FunctionPtr create(ContextPtr /*context*/) + { + return std::make_shared(); + } + + + String getName() const override { return name; } + + bool isDeterministic() const override { return false; } + + bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; } + + + size_t getNumberOfArguments() const override + { + return 1; + } + + DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override + { + if (!isString(arguments[0])) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "The argument of function {} must have String type", getName()); + return std::make_shared(); + } + + ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override + { + const auto & query_context = DB::CurrentThread::getQueryContext(); + const auto & method = query_context->getClientInfo().http_method; + + const auto & headers = DB::CurrentThread::getQueryContext()->getClientInfo().headers; + + const IColumn * arg_column = arguments[0].column.get(); + const ColumnString * arg_string = checkAndGetColumnConstData(arg_column); + + if (!arg_string) + throw Exception(ErrorCodes::ILLEGAL_COLUMN, "The argument of function {} must be constant String", getName()); + + if (method != ClientInfo::HTTPMethod::GET && method != ClientInfo::HTTPMethod::POST) + return result_type->createColumnConst(input_rows_count, ""); + + if (!headers.has(arg_string->getDataAt(0).toString())) + return result_type->createColumnConst(input_rows_count, ""); + + return result_type->createColumnConst(input_rows_count, headers[arg_string->getDataAt(0).toString()]); + } +}; + +} + +REGISTER_FUNCTION(GetHttpHeader) +{ + factory.registerFunction(); +} + +} diff --git a/src/Interpreters/ClientInfo.h b/src/Interpreters/ClientInfo.h index 70524333047..5878f0b424e 100644 --- a/src/Interpreters/ClientInfo.h +++ b/src/Interpreters/ClientInfo.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -96,6 +97,8 @@ public: /// For mysql and postgresql UInt64 connection_id = 0; + Poco::Net::NameValueCollection headers; + void setHttpHeaders(const Poco::Net::NameValueCollection & _headers) { headers = _headers; } /// Comma separated list of forwarded IP addresses (from X-Forwarded-For for HTTP interface). /// It's expected that proxy appends the forwarded address to the end of the list. diff --git a/src/Interpreters/Context.cpp b/src/Interpreters/Context.cpp index 8695669a7de..113d862e5d6 100644 --- a/src/Interpreters/Context.cpp +++ b/src/Interpreters/Context.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -85,6 +86,7 @@ #include #include #include +#include "Disks/ObjectStorages/S3/diskSettings.h" #include #include #include @@ -4008,12 +4010,18 @@ void Context::setClientConnectionId(uint32_t connection_id_) client_info.connection_id = connection_id_; } -void Context::setHttpClientInfo(ClientInfo::HTTPMethod http_method, const String & http_user_agent, const String & http_referer) +void Context::setHttpClientInfo(ClientInfo::HTTPMethod http_method, const String & http_user_agent, const String & http_referer, const Poco::Net::NameValueCollection & http_headers) { client_info.http_method = http_method; client_info.http_user_agent = http_user_agent; client_info.http_referer = http_referer; need_recalculate_access = true; + + if (!http_headers.empty()) + { + for (const auto & http_header : http_headers) + client_info.headers.set(http_header.first, http_header.second); + } } void Context::setForwardedFor(const String & forwarded_for) diff --git a/src/Interpreters/Context.h b/src/Interpreters/Context.h index b4a5b3d8c85..b985f45a091 100644 --- a/src/Interpreters/Context.h +++ b/src/Interpreters/Context.h @@ -1,5 +1,7 @@ #pragma once +#include +#include "Core/Types.h" #ifndef CLICKHOUSE_KEEPER_STANDALONE_BUILD #include @@ -609,7 +611,7 @@ public: void setClientInterface(ClientInfo::Interface interface); void setClientVersion(UInt64 client_version_major, UInt64 client_version_minor, UInt64 client_version_patch, unsigned client_tcp_protocol_version); void setClientConnectionId(uint32_t connection_id); - void setHttpClientInfo(ClientInfo::HTTPMethod http_method, const String & http_user_agent, const String & http_referer); + void setHttpClientInfo(ClientInfo::HTTPMethod http_method, const String & http_user_agent, const String & http_referer, const Poco::Net::NameValueCollection & http_headers = {}); void setForwardedFor(const String & forwarded_for); void setQueryKind(ClientInfo::QueryKind query_kind); void setQueryKindInitial(); @@ -804,6 +806,10 @@ public: /// Storage of forbidden HTTP headers from config.xml void setHTTPHeaderFilter(const Poco::Util::AbstractConfiguration & config); const HTTPHeaderFilter & getHTTPHeaderFilter() const; + const Poco::Net::NameValueCollection & getHttpHeaders() const + { + return client_info.headers; + } /// The port that the server listens for executing SQL queries. UInt16 getTCPPort() const; diff --git a/src/Interpreters/Session.cpp b/src/Interpreters/Session.cpp index 439bf6056ba..7105e18ce18 100644 --- a/src/Interpreters/Session.cpp +++ b/src/Interpreters/Session.cpp @@ -15,6 +15,7 @@ #include #include +#include #include #include @@ -428,17 +429,18 @@ void Session::setClientConnectionId(uint32_t connection_id) prepared_client_info->connection_id = connection_id; } -void Session::setHttpClientInfo(ClientInfo::HTTPMethod http_method, const String & http_user_agent, const String & http_referer) +void Session::setHttpClientInfo(ClientInfo::HTTPMethod http_method, const String & http_user_agent, const String & http_referer, const Poco::Net::NameValueCollection & http_headers) { if (session_context) { - session_context->setHttpClientInfo(http_method, http_user_agent, http_referer); + session_context->setHttpClientInfo(http_method, http_user_agent, http_referer, http_headers); } else { prepared_client_info->http_method = http_method; prepared_client_info->http_user_agent = http_user_agent; prepared_client_info->http_referer = http_referer; + prepared_client_info->headers = http_headers; } } diff --git a/src/Interpreters/Session.h b/src/Interpreters/Session.h index 81ef987b428..43e54474bbd 100644 --- a/src/Interpreters/Session.h +++ b/src/Interpreters/Session.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -64,7 +65,7 @@ public: void setClientInterface(ClientInfo::Interface interface); void setClientVersion(UInt64 client_version_major, UInt64 client_version_minor, UInt64 client_version_patch, unsigned client_tcp_protocol_version); void setClientConnectionId(uint32_t connection_id); - void setHttpClientInfo(ClientInfo::HTTPMethod http_method, const String & http_user_agent, const String & http_referer); + void setHttpClientInfo(ClientInfo::HTTPMethod http_method, const String & http_user_agent, const String & http_referer, const Poco::Net::NameValueCollection & http_headers = {}); void setForwardedFor(const String & forwarded_for); void setQuotaClientKey(const String & quota_key); void setConnectionClientVersion(UInt64 client_version_major, UInt64 client_version_minor, UInt64 client_version_patch, unsigned client_tcp_protocol_version); diff --git a/src/Server/HTTPHandler.cpp b/src/Server/HTTPHandler.cpp index ebb7f0d3490..f787d2abe45 100644 --- a/src/Server/HTTPHandler.cpp +++ b/src/Server/HTTPHandler.cpp @@ -40,6 +40,7 @@ #include #include #include +#include #include #include #include @@ -500,7 +501,11 @@ bool HTTPHandler::authenticateUser( else if (request.getMethod() == HTTPServerRequest::HTTP_POST) http_method = ClientInfo::HTTPMethod::POST; - session->setHttpClientInfo(http_method, request.get("User-Agent", ""), request.get("Referer", "")); + NameValueCollection http_headers; + for (const auto & it : request) + http_headers.set(it.first, it.second); + + session->setHttpClientInfo(http_method, request.get("User-Agent", ""), request.get("Referer", ""), http_headers); session->setForwardedFor(request.get("X-Forwarded-For", "")); session->setQuotaClientKey(quota_key); @@ -580,6 +585,10 @@ void HTTPHandler::processQuery( session->makeSessionContext(); } + NameValueCollection headers; + for (auto it = request.begin(); it != request.end(); ++it) + headers.set(it->first, it->second); + auto context = session->makeQueryContext(); /// This parameter is used to tune the behavior of output formats (such as Native) for compatibility. diff --git a/tests/queries/0_stateless/02884_getHttpHeaderFunction.reference b/tests/queries/0_stateless/02884_getHttpHeaderFunction.reference new file mode 100644 index 00000000000..564a057086f --- /dev/null +++ b/tests/queries/0_stateless/02884_getHttpHeaderFunction.reference @@ -0,0 +1,2 @@ +Code: 81. DB::Exception: Database `02884_getHttpHeaderFunction` does not exist. (UNKNOWN_DATABASE) (version 23.9.1.1) +default diff --git a/tests/queries/0_stateless/02884_getHttpHeaderFunction.sh b/tests/queries/0_stateless/02884_getHttpHeaderFunction.sh new file mode 100755 index 00000000000..b03b05f7cdb --- /dev/null +++ b/tests/queries/0_stateless/02884_getHttpHeaderFunction.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env bash + +CUR_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) +# shellcheck source=../shell_config.sh +. "$CUR_DIR"/../shell_config.sh +db="02884_getHttpHeaderFunction" + +$CLICKHOUSE_CLIENT -q "CREATE DATABASE IF NOT EXISTS ${db}" +$CLICKHOUSE_CLIENT -q "CREATE TABLE IF NOT EXISTS ${db}.get_http_header (id UInt32, header_value String DEFAULT getHttpHeader('X-Clickhouse-User')) Engine=Memory()" + +#Insert data via tcp client +$CLICKHOUSE_CLIENT -q "INSERT INTO ${db}.get_http_header (id) values (1), (2)" + +#Insert data via http request +echo "INSERT INTO ${db}.get_http_header (id) values (3), (4)" | curl -H 'X-ClickHouse-User: default' -H 'X-ClickHouse-Key: ' 'http://localhost:8123/' -d @- + +$CLICKHOUSE_CLIENT -q "SELECT * FROM ${db}.get_http_header ORDER BY id;" +$CLICKHOUSE_CLIENT -q "DROP DATABASE ${db}" + +echo "SELECT getHttpHeader('X-Clickhouse-User')" | curl -H 'X-ClickHouse-User: default' -H 'X-ClickHouse-Key: ' 'http://localhost:8123/' -d @- +