add function getHttpHeader

This commit is contained in:
凌涛 2023-09-20 10:07:02 +08:00
parent 711876dfa8
commit 9e3c54ddb9
9 changed files with 153 additions and 6 deletions

View File

@ -0,0 +1,95 @@
#include <Functions/IFunction.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <DataTypes/DataTypeString.h>
#include <Columns/ColumnString.h>
#include <Interpreters/Context.h>
#include "Common/CurrentThread.h"
#include <Common/Macros.h>
#include "Interpreters/ClientInfo.h"
#include "Interpreters/Context_fwd.h"
#include <Core/Field.h>
#include <Poco/Net/NameValueCollection.h>
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int ILLEGAL_COLUMN;
extern const int FUNCTION_NOT_ALLOWED;
}
namespace
{
/** Get the value of parameter in http headers.
* If there no such parameter or the method of request is not
* http, the function will return empty string.
*/
class FunctionGetHttpHeader : public IFunction
{
private:
public:
FunctionGetHttpHeader() = default;
static constexpr auto name = "getHttpHeader";
static FunctionPtr create(ContextPtr /*context*/)
{
return std::make_shared<FunctionGetHttpHeader>();
}
String getName() const override { return name; }
bool isDeterministic() const override { return false; }
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; }
size_t getNumberOfArguments() const override
{
return 1;
}
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!isString(arguments[0]))
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "The argument of function {} must have String type", getName());
return std::make_shared<DataTypeString>();
}
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override
{
const auto & query_context = DB::CurrentThread::getQueryContext();
const auto & method = query_context->getClientInfo().http_method;
const auto & headers = DB::CurrentThread::getQueryContext()->getClientInfo().headers;
const IColumn * arg_column = arguments[0].column.get();
const ColumnString * arg_string = checkAndGetColumnConstData<ColumnString>(arg_column);
if (!arg_string)
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "The argument of function {} must be constant String", getName());
if (method != ClientInfo::HTTPMethod::GET && method != ClientInfo::HTTPMethod::POST)
return result_type->createColumnConst(input_rows_count, "");
if (!headers.has(arg_string->getDataAt(0).toString()))
return result_type->createColumnConst(input_rows_count, "");
return result_type->createColumnConst(input_rows_count, headers[arg_string->getDataAt(0).toString()]);
}
};
}
REGISTER_FUNCTION(GetHttpHeader)
{
factory.registerFunction<FunctionGetHttpHeader>();
}
}

View File

@ -1,6 +1,7 @@
#pragma once
#include <Core/UUID.h>
#include <Poco/Net/NameValueCollection.h>
#include <Poco/Net/SocketAddress.h>
#include <base/types.h>
#include <Common/OpenTelemetryTraceContext.h>
@ -96,6 +97,8 @@ public:
/// For mysql and postgresql
UInt64 connection_id = 0;
Poco::Net::NameValueCollection headers;
void setHttpHeaders(const Poco::Net::NameValueCollection & _headers) { headers = _headers; }
/// Comma separated list of forwarded IP addresses (from X-Forwarded-For for HTTP interface).
/// It's expected that proxy appends the forwarded address to the end of the list.

View File

@ -2,6 +2,7 @@
#include <set>
#include <optional>
#include <memory>
#include <Poco/Net/NameValueCollection.h>
#include <Poco/UUID.h>
#include <Poco/Util/Application.h>
#include <Common/Macros.h>
@ -85,6 +86,7 @@
#include <Common/logger_useful.h>
#include <Common/RemoteHostFilter.h>
#include <Common/HTTPHeaderFilter.h>
#include "Disks/ObjectStorages/S3/diskSettings.h"
#include <Interpreters/AsynchronousInsertQueue.h>
#include <Interpreters/DatabaseCatalog.h>
#include <Interpreters/JIT/CompiledExpressionCache.h>
@ -4008,12 +4010,18 @@ void Context::setClientConnectionId(uint32_t connection_id_)
client_info.connection_id = connection_id_;
}
void Context::setHttpClientInfo(ClientInfo::HTTPMethod http_method, const String & http_user_agent, const String & http_referer)
void Context::setHttpClientInfo(ClientInfo::HTTPMethod http_method, const String & http_user_agent, const String & http_referer, const Poco::Net::NameValueCollection & http_headers)
{
client_info.http_method = http_method;
client_info.http_user_agent = http_user_agent;
client_info.http_referer = http_referer;
need_recalculate_access = true;
if (!http_headers.empty())
{
for (const auto & http_header : http_headers)
client_info.headers.set(http_header.first, http_header.second);
}
}
void Context::setForwardedFor(const String & forwarded_for)

View File

@ -1,5 +1,7 @@
#pragma once
#include <Poco/Net/NameValueCollection.h>
#include "Core/Types.h"
#ifndef CLICKHOUSE_KEEPER_STANDALONE_BUILD
#include <base/types.h>
@ -609,7 +611,7 @@ public:
void setClientInterface(ClientInfo::Interface interface);
void setClientVersion(UInt64 client_version_major, UInt64 client_version_minor, UInt64 client_version_patch, unsigned client_tcp_protocol_version);
void setClientConnectionId(uint32_t connection_id);
void setHttpClientInfo(ClientInfo::HTTPMethod http_method, const String & http_user_agent, const String & http_referer);
void setHttpClientInfo(ClientInfo::HTTPMethod http_method, const String & http_user_agent, const String & http_referer, const Poco::Net::NameValueCollection & http_headers = {});
void setForwardedFor(const String & forwarded_for);
void setQueryKind(ClientInfo::QueryKind query_kind);
void setQueryKindInitial();
@ -804,6 +806,10 @@ public:
/// Storage of forbidden HTTP headers from config.xml
void setHTTPHeaderFilter(const Poco::Util::AbstractConfiguration & config);
const HTTPHeaderFilter & getHTTPHeaderFilter() const;
const Poco::Net::NameValueCollection & getHttpHeaders() const
{
return client_info.headers;
}
/// The port that the server listens for executing SQL queries.
UInt16 getTCPPort() const;

View File

@ -15,6 +15,7 @@
#include <Interpreters/Cluster.h>
#include <magic_enum.hpp>
#include <Poco/Net/NameValueCollection.h>
#include <atomic>
#include <condition_variable>
@ -428,17 +429,18 @@ void Session::setClientConnectionId(uint32_t connection_id)
prepared_client_info->connection_id = connection_id;
}
void Session::setHttpClientInfo(ClientInfo::HTTPMethod http_method, const String & http_user_agent, const String & http_referer)
void Session::setHttpClientInfo(ClientInfo::HTTPMethod http_method, const String & http_user_agent, const String & http_referer, const Poco::Net::NameValueCollection & http_headers)
{
if (session_context)
{
session_context->setHttpClientInfo(http_method, http_user_agent, http_referer);
session_context->setHttpClientInfo(http_method, http_user_agent, http_referer, http_headers);
}
else
{
prepared_client_info->http_method = http_method;
prepared_client_info->http_user_agent = http_user_agent;
prepared_client_info->http_referer = http_referer;
prepared_client_info->headers = http_headers;
}
}

View File

@ -5,6 +5,7 @@
#include <Interpreters/ClientInfo.h>
#include <Interpreters/Context_fwd.h>
#include <Interpreters/SessionTracker.h>
#include <Poco/Net/NameValueCollection.h>
#include <chrono>
#include <memory>
@ -64,7 +65,7 @@ public:
void setClientInterface(ClientInfo::Interface interface);
void setClientVersion(UInt64 client_version_major, UInt64 client_version_minor, UInt64 client_version_patch, unsigned client_tcp_protocol_version);
void setClientConnectionId(uint32_t connection_id);
void setHttpClientInfo(ClientInfo::HTTPMethod http_method, const String & http_user_agent, const String & http_referer);
void setHttpClientInfo(ClientInfo::HTTPMethod http_method, const String & http_user_agent, const String & http_referer, const Poco::Net::NameValueCollection & http_headers = {});
void setForwardedFor(const String & forwarded_for);
void setQuotaClientKey(const String & quota_key);
void setConnectionClientVersion(UInt64 client_version_major, UInt64 client_version_minor, UInt64 client_version_patch, unsigned client_tcp_protocol_version);

View File

@ -40,6 +40,7 @@
#include <Poco/Net/HTTPBasicCredentials.h>
#include <Poco/Net/HTTPStream.h>
#include <Poco/MemoryStream.h>
#include <Poco/Net/NameValueCollection.h>
#include <Poco/StreamCopier.h>
#include <Poco/String.h>
#include <Poco/Net/SocketAddress.h>
@ -500,7 +501,11 @@ bool HTTPHandler::authenticateUser(
else if (request.getMethod() == HTTPServerRequest::HTTP_POST)
http_method = ClientInfo::HTTPMethod::POST;
session->setHttpClientInfo(http_method, request.get("User-Agent", ""), request.get("Referer", ""));
NameValueCollection http_headers;
for (const auto & it : request)
http_headers.set(it.first, it.second);
session->setHttpClientInfo(http_method, request.get("User-Agent", ""), request.get("Referer", ""), http_headers);
session->setForwardedFor(request.get("X-Forwarded-For", ""));
session->setQuotaClientKey(quota_key);
@ -580,6 +585,10 @@ void HTTPHandler::processQuery(
session->makeSessionContext();
}
NameValueCollection headers;
for (auto it = request.begin(); it != request.end(); ++it)
headers.set(it->first, it->second);
auto context = session->makeQueryContext();
/// This parameter is used to tune the behavior of output formats (such as Native) for compatibility.

View File

@ -0,0 +1,2 @@
Code: 81. DB::Exception: Database `02884_getHttpHeaderFunction` does not exist. (UNKNOWN_DATABASE) (version 23.9.1.1)
default

View File

@ -0,0 +1,21 @@
#!/usr/bin/env bash
CUR_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
# shellcheck source=../shell_config.sh
. "$CUR_DIR"/../shell_config.sh
db="02884_getHttpHeaderFunction"
$CLICKHOUSE_CLIENT -q "CREATE DATABASE IF NOT EXISTS ${db}"
$CLICKHOUSE_CLIENT -q "CREATE TABLE IF NOT EXISTS ${db}.get_http_header (id UInt32, header_value String DEFAULT getHttpHeader('X-Clickhouse-User')) Engine=Memory()"
#Insert data via tcp client
$CLICKHOUSE_CLIENT -q "INSERT INTO ${db}.get_http_header (id) values (1), (2)"
#Insert data via http request
echo "INSERT INTO ${db}.get_http_header (id) values (3), (4)" | curl -H 'X-ClickHouse-User: default' -H 'X-ClickHouse-Key: ' 'http://localhost:8123/' -d @-
$CLICKHOUSE_CLIENT -q "SELECT * FROM ${db}.get_http_header ORDER BY id;"
$CLICKHOUSE_CLIENT -q "DROP DATABASE ${db}"
echo "SELECT getHttpHeader('X-Clickhouse-User')" | curl -H 'X-ClickHouse-User: default' -H 'X-ClickHouse-Key: ' 'http://localhost:8123/' -d @-