mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-12-04 13:32:13 +00:00
923 lines
22 KiB
C++
923 lines
22 KiB
C++
#pragma once
|
|
|
|
#include <functional>
|
|
#include <IO/ReadBuffer.h>
|
|
#include <IO/ReadHelpers.h>
|
|
#include <IO/WriteBuffer.h>
|
|
#include <IO/WriteHelpers.h>
|
|
#include <Interpreters/Session.h>
|
|
#include <Common/logger_useful.h>
|
|
#include <Poco/Format.h>
|
|
#include <Poco/RegularExpression.h>
|
|
#include <Poco/Net/StreamSocket.h>
|
|
#include "Types.h"
|
|
#include <unordered_map>
|
|
#include <utility>
|
|
|
|
namespace DB
|
|
{
|
|
|
|
/// Error codes referenced by the inline code of this header.
/// They are only declared here (extern); their values are defined elsewhere.
namespace ErrorCodes
{
    extern const int UNKNOWN_PACKET_FROM_CLIENT;
    extern const int UNEXPECTED_PACKET_FROM_CLIENT;
    extern const int NOT_IMPLEMENTED;
    extern const int UNKNOWN_TYPE;
}
|
|
|
|
|
|
namespace PostgreSQLProtocol
|
|
{
|
|
|
|
namespace Messaging
|
|
{
|
|
|
|
enum class FrontMessageType : Int32
|
|
{
|
|
// first message types
|
|
CANCEL_REQUEST = 80877102,
|
|
SSL_REQUEST = 80877103,
|
|
GSSENC_REQUEST = 80877104,
|
|
|
|
// other front message types
|
|
PASSWORD_MESSAGE = 'p',
|
|
QUERY = 'Q',
|
|
TERMINATE = 'X',
|
|
PARSE = 'P',
|
|
BIND = 'B',
|
|
DESCRIBE = 'D',
|
|
SYNC = 'S',
|
|
FLUSH = 'H',
|
|
CLOSE = 'C',
|
|
};
|
|
|
|
enum class MessageType : Int32
|
|
{
|
|
// common
|
|
ERROR_RESPONSE = 0,
|
|
CANCEL_REQUEST = 1,
|
|
COMMAND_COMPLETE = 2,
|
|
NOTICE_RESPONSE = 3,
|
|
NOTIFICATION_RESPONSE = 4,
|
|
PARAMETER_STATUS = 5,
|
|
READY_FOR_QUERY = 6,
|
|
SYNC = 7,
|
|
TERMINATE = 8,
|
|
|
|
// start up and authentication
|
|
AUTHENTICATION_OK = 30,
|
|
AUTHENTICATION_KERBEROS_V5 = 31,
|
|
AUTHENTICATION_CLEARTEXT_PASSWORD = 32,
|
|
AUTHENTICATION_MD5_PASSWORD = 33,
|
|
AUTHENTICATION_SCM_CREDENTIAL = 34,
|
|
AUTHENTICATION_GSS = 35,
|
|
AUTHENTICATION_SSPI = 36,
|
|
AUTHENTICATION_GSS_CONTINUE = 37,
|
|
AUTHENTICATION_SASL = 38,
|
|
AUTHENTICATION_SASL_CONTINUE = 39,
|
|
AUTHENTICATION_SASL_FINAL = 40,
|
|
BACKEND_KEY_DATA = 41,
|
|
GSSENC_REQUEST = 42,
|
|
GSS_RESPONSE = 43,
|
|
NEGOTIATE_PROTOCOL_VERSION = 44,
|
|
PASSWORD_MESSAGE = 45,
|
|
SASL_INITIAL_RESPONSE = 46,
|
|
SASL_RESPONSE = 47,
|
|
SSL_REQUEST = 48,
|
|
STARTUP_MESSAGE = 49,
|
|
|
|
// simple query
|
|
DATA_ROW = 100,
|
|
EMPTY_QUERY_RESPONSE = 101,
|
|
ROW_DESCRIPTION = 102,
|
|
QUERY = 103,
|
|
|
|
// extended query
|
|
BIND = 120,
|
|
BIND_COMPLETE = 121,
|
|
CLOSE = 122,
|
|
CLOSE_COMPLETE = 123,
|
|
DESCRIBE = 124,
|
|
EXECUTE = 125,
|
|
FLUSH = 126,
|
|
NODATA = 127,
|
|
PARAMETER_DESCRIPTION = 128,
|
|
PARSE = 129,
|
|
PARSE_COMPLETE = 130,
|
|
PORTAL_SUSPENDED = 131,
|
|
|
|
// copy query
|
|
COPY_DATA = 171,
|
|
COPY_DONE = 172,
|
|
COPY_FAIL = 173,
|
|
COPY_IN_RESPONSE = 174,
|
|
COPY_OUT_RESPONSE = 175,
|
|
COPY_BOTH_RESPONSE = 176,
|
|
|
|
// function query (deprecated by the protocol)
|
|
FUNCTION_CALL = 190,
|
|
FUNCTION_CALL_RESPONSE = 191,
|
|
};
|
|
|
|
//// Column 'typelem' from 'pg_type' table. NB: not all types are compatible with PostgreSQL's ones
|
|
enum class ColumnType : Int32
|
|
{
|
|
CHAR = 18,
|
|
INT8 = 20,
|
|
INT2 = 21,
|
|
INT4 = 23,
|
|
FLOAT4 = 700,
|
|
FLOAT8 = 701,
|
|
VARCHAR = 1043,
|
|
DATE = 1082,
|
|
NUMERIC = 1700,
|
|
UUID = 2950,
|
|
};
|
|
|
|
class ColumnTypeSpec
|
|
{
|
|
public:
|
|
ColumnType type;
|
|
Int16 len;
|
|
|
|
ColumnTypeSpec(ColumnType type_, Int16 len_) : type(type_), len(len_) {}
|
|
};
|
|
|
|
/// Maps a TypeIndex to the column type/length pair announced to the client.
/// Only declared here (the definition is not part of this header); used by FieldDescription's constructor.
ColumnTypeSpec convertTypeIndexToPostgresColumnTypeSpec(TypeIndex type_index);
|
|
|
|
class MessageTransport
|
|
{
|
|
private:
|
|
ReadBuffer * in;
|
|
WriteBuffer * out;
|
|
|
|
public:
|
|
explicit MessageTransport(WriteBuffer * out_) : in(nullptr), out(out_) {}
|
|
|
|
MessageTransport(ReadBuffer * in_, WriteBuffer * out_): in(in_), out(out_) {}
|
|
|
|
template<typename TMessage>
|
|
std::unique_ptr<TMessage> receiveWithPayloadSize(Int32 payload_size)
|
|
{
|
|
std::unique_ptr<TMessage> message = std::make_unique<TMessage>(payload_size);
|
|
message->deserialize(*in);
|
|
return message;
|
|
}
|
|
|
|
template<typename TMessage>
|
|
std::unique_ptr<TMessage> receive()
|
|
{
|
|
std::unique_ptr<TMessage> message = std::make_unique<TMessage>();
|
|
message->deserialize(*in);
|
|
return message;
|
|
}
|
|
|
|
FrontMessageType receiveMessageType()
|
|
{
|
|
char type = 0;
|
|
in->readStrict(type);
|
|
return static_cast<FrontMessageType>(type);
|
|
}
|
|
|
|
template<typename TMessage>
|
|
void send(TMessage & message, bool flush=false)
|
|
{
|
|
message.serialize(*out);
|
|
if (flush)
|
|
out->next();
|
|
}
|
|
|
|
template<typename TMessage>
|
|
void send(TMessage && message, bool flush=false)
|
|
{
|
|
send(message, flush);
|
|
}
|
|
|
|
void send(char message, bool flush=false)
|
|
{
|
|
out->write(message);
|
|
if (flush)
|
|
out->next();
|
|
}
|
|
|
|
void dropMessage()
|
|
{
|
|
Int32 size;
|
|
readBinaryBigEndian(size, *in);
|
|
in->ignore(size - 4);
|
|
}
|
|
|
|
void flush()
|
|
{
|
|
out->next();
|
|
}
|
|
};
|
|
|
|
/** Basic class for messages sent by client or server. */
|
|
class IMessage
|
|
{
|
|
public:
|
|
virtual MessageType getMessageType() const = 0;
|
|
|
|
virtual ~IMessage() = default;
|
|
};
|
|
|
|
class ISerializable
|
|
{
|
|
public:
|
|
/** Should be overridden for sending the message */
|
|
virtual void serialize(WriteBuffer & out) const = 0;
|
|
|
|
/** Size of the message in bytes including message length part (4 bytes) */
|
|
virtual Int32 size() const = 0;
|
|
|
|
ISerializable() = default;
|
|
|
|
ISerializable(const ISerializable &) = default;
|
|
|
|
virtual ~ISerializable() = default;
|
|
};
|
|
|
|
class FrontMessage : public IMessage
|
|
{
|
|
public:
|
|
/** Should be overridden for receiving the message
|
|
* NB: This method should not read the first byte, which means the type of the message
|
|
* (if type is provided for the message by the protocol).
|
|
*/
|
|
virtual void deserialize(ReadBuffer & in) = 0;
|
|
};
|
|
|
|
class BackendMessage : public IMessage, public ISerializable
|
|
{};
|
|
|
|
class FirstMessage : public FrontMessage
|
|
{
|
|
public:
|
|
Int32 payload_size;
|
|
|
|
FirstMessage() = delete;
|
|
explicit FirstMessage(int payload_size_) : payload_size(payload_size_) {}
|
|
};
|
|
|
|
class CancelRequest : public FirstMessage
|
|
{
|
|
public:
|
|
Int32 process_id = 0;
|
|
Int32 secret_key = 0;
|
|
|
|
explicit CancelRequest(int payload_size_) : FirstMessage(payload_size_) {}
|
|
|
|
void deserialize(ReadBuffer & in) override
|
|
{
|
|
readBinaryBigEndian(process_id, in);
|
|
readBinaryBigEndian(secret_key, in);
|
|
}
|
|
|
|
MessageType getMessageType() const override
|
|
{
|
|
return MessageType::CANCEL_REQUEST;
|
|
}
|
|
};
|
|
|
|
class ErrorOrNoticeResponse : BackendMessage
|
|
{
|
|
public:
|
|
enum Severity {ERROR = 0, FATAL = 1, PANIC = 2, WARNING = 3, NOTICE = 4, DEBUG = 5, INFO = 6, LOG = 7};
|
|
|
|
private:
|
|
Severity severity;
|
|
String sql_state;
|
|
String message;
|
|
|
|
String enum_to_string[8] = {"ERROR", "FATAL", "PANIC", "WARNING", "NOTICE", "DEBUG", "INFO", "LOG"};
|
|
|
|
char isErrorOrNotice() const
|
|
{
|
|
switch (severity)
|
|
{
|
|
case ERROR:
|
|
case FATAL:
|
|
case PANIC:
|
|
return 'E';
|
|
case WARNING:
|
|
case NOTICE:
|
|
case DEBUG:
|
|
case INFO:
|
|
case LOG:
|
|
return 'N';
|
|
}
|
|
throw Exception(ErrorCodes::UNKNOWN_TYPE, "Unknown severity type {}", std::to_string(severity));
|
|
}
|
|
|
|
public:
|
|
ErrorOrNoticeResponse(const Severity & severity_, const String & sql_state_, const String & message_)
|
|
: severity(severity_)
|
|
, sql_state(sql_state_)
|
|
, message(message_)
|
|
{}
|
|
|
|
void serialize(WriteBuffer & out) const override
|
|
{
|
|
out.write(isErrorOrNotice());
|
|
Int32 sz = size();
|
|
writeBinaryBigEndian(sz, out);
|
|
|
|
out.write('S');
|
|
writeNullTerminatedString(enum_to_string[severity], out);
|
|
out.write('C');
|
|
writeNullTerminatedString(sql_state, out);
|
|
out.write('M');
|
|
writeNullTerminatedString(message, out);
|
|
|
|
out.write(0);
|
|
}
|
|
|
|
Int32 size() const override
|
|
{
|
|
// message length part + (1 + sizes of other fields + 1) + null byte in the end of the message
|
|
return static_cast<Int32>(
|
|
4 +
|
|
(1 + enum_to_string[severity].size() + 1) +
|
|
(1 + sql_state.size() + 1) +
|
|
(1 + message.size() + 1) +
|
|
1);
|
|
}
|
|
|
|
MessageType getMessageType() const override
|
|
{
|
|
if (isErrorOrNotice() == 'E')
|
|
return MessageType::ERROR_RESPONSE;
|
|
return MessageType::NOTICE_RESPONSE;
|
|
}
|
|
};
|
|
|
|
class ReadyForQuery : BackendMessage
|
|
{
|
|
public:
|
|
void serialize(WriteBuffer &out) const override
|
|
{
|
|
out.write('Z');
|
|
writeBinaryBigEndian(size(), out);
|
|
// 'I' means that we are not in a transaction block. We use it here, because ClickHouse doesn't support transactions.
|
|
out.write('I');
|
|
}
|
|
|
|
Int32 size() const override
|
|
{
|
|
return 4 + 1;
|
|
}
|
|
|
|
MessageType getMessageType() const override
|
|
{
|
|
return MessageType::READY_FOR_QUERY;
|
|
}
|
|
};
|
|
|
|
class Terminate : FrontMessage
|
|
{
|
|
public:
|
|
void deserialize(ReadBuffer & in) override
|
|
{
|
|
in.ignore(4);
|
|
}
|
|
|
|
MessageType getMessageType() const override
|
|
{
|
|
return MessageType::TERMINATE;
|
|
}
|
|
};
|
|
|
|
class StartupMessage : FirstMessage
|
|
{
|
|
public:
|
|
String user;
|
|
String database;
|
|
// includes username, may also include database and other runtime parameters
|
|
std::unordered_map<String, String> parameters;
|
|
|
|
explicit StartupMessage(Int32 payload_size_) : FirstMessage(payload_size_) {}
|
|
|
|
void deserialize(ReadBuffer & in) override
|
|
{
|
|
Int32 ps = payload_size - 1;
|
|
while (ps > 0)
|
|
{
|
|
String parameter_name;
|
|
String parameter_value;
|
|
readNullTerminated(parameter_name, in);
|
|
readNullTerminated(parameter_value, in);
|
|
ps -= parameter_name.size() + 1;
|
|
ps -= parameter_value.size() + 1;
|
|
|
|
if (parameter_name == "user")
|
|
{
|
|
user = parameter_value;
|
|
}
|
|
else if (parameter_name == "database")
|
|
{
|
|
database = parameter_value;
|
|
}
|
|
|
|
parameters.insert({std::move(parameter_name), std::move(parameter_value)});
|
|
|
|
if (payload_size < 0)
|
|
{
|
|
throw Exception(ErrorCodes::UNKNOWN_PACKET_FROM_CLIENT,
|
|
"Size of payload is larger than one declared in the message of type {}.",
|
|
static_cast<UInt64>(getMessageType()));
|
|
}
|
|
}
|
|
in.ignore();
|
|
}
|
|
|
|
MessageType getMessageType() const override
|
|
{
|
|
return MessageType::STARTUP_MESSAGE;
|
|
}
|
|
};
|
|
|
|
class AuthenticationCleartextPassword : public Messaging::BackendMessage
|
|
{
|
|
public:
|
|
void serialize(WriteBuffer & out) const override
|
|
{
|
|
out.write('R');
|
|
writeBinaryBigEndian(size(), out);
|
|
writeBinaryBigEndian(static_cast<Int32>(3), out); // specifies that a clear-text password is required (by protocol)
|
|
}
|
|
|
|
Int32 size() const override
|
|
{
|
|
// length of message + special int32
|
|
return 4 + 4;
|
|
}
|
|
|
|
MessageType getMessageType() const override
|
|
{
|
|
return MessageType::AUTHENTICATION_CLEARTEXT_PASSWORD;
|
|
}
|
|
};
|
|
|
|
class AuthenticationOk : BackendMessage
|
|
{
|
|
public:
|
|
void serialize(WriteBuffer & out) const override
|
|
{
|
|
out.write('R');
|
|
writeBinaryBigEndian(size(), out);
|
|
writeBinaryBigEndian(0, out); // specifies that the authentication was successful (by protocol)
|
|
}
|
|
|
|
Int32 size() const override
|
|
{
|
|
// length of message + special int32
|
|
return 4 + 4;
|
|
}
|
|
|
|
MessageType getMessageType() const override
|
|
{
|
|
return MessageType::AUTHENTICATION_OK;
|
|
}
|
|
};
|
|
|
|
class PasswordMessage : FrontMessage
|
|
{
|
|
public:
|
|
String password;
|
|
|
|
void deserialize(ReadBuffer & in) override
|
|
{
|
|
Int32 sz;
|
|
readBinaryBigEndian(sz, in);
|
|
readNullTerminated(password, in);
|
|
}
|
|
|
|
MessageType getMessageType() const override
|
|
{
|
|
return MessageType::PASSWORD_MESSAGE;
|
|
}
|
|
};
|
|
|
|
class ParameterStatus : BackendMessage
|
|
{
|
|
private:
|
|
String name;
|
|
String value;
|
|
|
|
public:
|
|
ParameterStatus(String name_, String value_)
|
|
: name(name_)
|
|
, value(value_)
|
|
{}
|
|
|
|
void serialize(WriteBuffer & out) const override
|
|
{
|
|
out.write('S');
|
|
writeBinaryBigEndian(size(), out);
|
|
writeNullTerminatedString(name, out);
|
|
writeNullTerminatedString(value, out);
|
|
}
|
|
|
|
Int32 size() const override
|
|
{
|
|
return static_cast<Int32>(4 + name.size() + 1 + value.size() + 1);
|
|
}
|
|
|
|
MessageType getMessageType() const override
|
|
{
|
|
return MessageType::PARAMETER_STATUS;
|
|
}
|
|
};
|
|
|
|
class BackendKeyData : BackendMessage
|
|
{
|
|
private:
|
|
Int32 process_id;
|
|
Int32 secret_key;
|
|
|
|
public:
|
|
BackendKeyData(Int32 process_id_, Int32 secret_key_)
|
|
: process_id(process_id_)
|
|
, secret_key(secret_key_)
|
|
{}
|
|
|
|
void serialize(WriteBuffer & out) const override
|
|
{
|
|
out.write('K');
|
|
writeBinaryBigEndian(size(), out);
|
|
writeBinaryBigEndian(process_id, out);
|
|
writeBinaryBigEndian(secret_key, out);
|
|
}
|
|
|
|
Int32 size() const override
|
|
{
|
|
return 4 + 4 + 4;
|
|
}
|
|
|
|
MessageType getMessageType() const override
|
|
{
|
|
return MessageType::BACKEND_KEY_DATA;
|
|
}
|
|
};
|
|
|
|
class Query : FrontMessage
|
|
{
|
|
public:
|
|
String query;
|
|
|
|
void deserialize(ReadBuffer & in) override
|
|
{
|
|
Int32 sz;
|
|
readBinaryBigEndian(sz, in);
|
|
readNullTerminated(query, in);
|
|
}
|
|
|
|
MessageType getMessageType() const override
|
|
{
|
|
return MessageType::QUERY;
|
|
}
|
|
};
|
|
|
|
class EmptyQueryResponse : public BackendMessage
|
|
{
|
|
public:
|
|
void serialize(WriteBuffer & out) const override
|
|
{
|
|
out.write('I');
|
|
writeBinaryBigEndian(size(), out);
|
|
}
|
|
|
|
Int32 size() const override
|
|
{
|
|
return 4;
|
|
}
|
|
|
|
MessageType getMessageType() const override
|
|
{
|
|
return MessageType::EMPTY_QUERY_RESPONSE;
|
|
}
|
|
};
|
|
|
|
enum class FormatCode : Int16
|
|
{
|
|
TEXT = 0,
|
|
BINARY = 1,
|
|
};
|
|
|
|
class FieldDescription : ISerializable
|
|
{
|
|
private:
|
|
const String & name;
|
|
ColumnTypeSpec type_spec;
|
|
FormatCode format_code;
|
|
|
|
public:
|
|
FieldDescription(const String & name_, TypeIndex type_index, FormatCode format_code_ = FormatCode::TEXT)
|
|
: name(name_)
|
|
, type_spec(convertTypeIndexToPostgresColumnTypeSpec(type_index))
|
|
, format_code(format_code_)
|
|
{}
|
|
|
|
void serialize(WriteBuffer & out) const override
|
|
{
|
|
writeNullTerminatedString(name, out);
|
|
writeBinaryBigEndian(static_cast<Int32>(0), out);
|
|
writeBinaryBigEndian(static_cast<Int16>(0), out);
|
|
writeBinaryBigEndian(static_cast<Int32>(type_spec.type), out);
|
|
writeBinaryBigEndian(type_spec.len, out);
|
|
writeBinaryBigEndian(static_cast<Int32>(-1), out);
|
|
writeBinaryBigEndian(static_cast<Int16>(format_code), out);
|
|
}
|
|
|
|
Int32 size() const override
|
|
{
|
|
// size of name (C string)
|
|
// + object ID of the table (Int32 and always zero) + attribute number of the column (Int16 and always zero)
|
|
// + type object id (Int32) + data type size (Int16)
|
|
// + type modifier (Int32 and always -1) + format code (Int16)
|
|
return static_cast<Int32>((name.size() + 1) + 4 + 2 + 4 + 2 + 4 + 2);
|
|
}
|
|
};
|
|
|
|
class RowDescription : BackendMessage
|
|
{
|
|
private:
|
|
const std::vector<FieldDescription> & fields_descr;
|
|
|
|
public:
|
|
explicit RowDescription(const std::vector<FieldDescription> & fields_descr_) : fields_descr(fields_descr_) {}
|
|
|
|
void serialize(WriteBuffer & out) const override
|
|
{
|
|
out.write('T');
|
|
writeBinaryBigEndian(size(), out);
|
|
writeBinaryBigEndian(static_cast<Int16>(fields_descr.size()), out);
|
|
for (const FieldDescription & field : fields_descr)
|
|
field.serialize(out);
|
|
}
|
|
|
|
Int32 size() const override
|
|
{
|
|
Int32 sz = 4 + 2; // size of message + number of fields
|
|
for (const FieldDescription & field : fields_descr)
|
|
sz += field.size();
|
|
return sz;
|
|
}
|
|
|
|
MessageType getMessageType() const override
|
|
{
|
|
return MessageType::ROW_DESCRIPTION;
|
|
}
|
|
};
|
|
|
|
class StringField : public ISerializable
|
|
{
|
|
private:
|
|
String str;
|
|
public:
|
|
explicit StringField(String str_) : str(str_) {}
|
|
|
|
void serialize(WriteBuffer & out) const override
|
|
{
|
|
writeString(str, out);
|
|
}
|
|
|
|
Int32 size() const override
|
|
{
|
|
return static_cast<Int32>(str.size());
|
|
}
|
|
};
|
|
|
|
class NullField : public ISerializable
|
|
{
|
|
public:
|
|
void serialize(WriteBuffer & /* out */) const override {}
|
|
|
|
Int32 size() const override
|
|
{
|
|
return -1;
|
|
}
|
|
};
|
|
|
|
class DataRow : BackendMessage
|
|
{
|
|
private:
|
|
const std::vector<std::shared_ptr<ISerializable>> & row;
|
|
|
|
public:
|
|
explicit DataRow(const std::vector<std::shared_ptr<ISerializable>> & row_) : row(row_) {}
|
|
|
|
void serialize(WriteBuffer & out) const override
|
|
{
|
|
out.write('D');
|
|
writeBinaryBigEndian(size(), out);
|
|
writeBinaryBigEndian(static_cast<Int16>(row.size()), out);
|
|
for (const std::shared_ptr<ISerializable> & field : row)
|
|
{
|
|
Int32 sz = field->size();
|
|
writeBinaryBigEndian(sz, out);
|
|
if (sz > 0)
|
|
field->serialize(out);
|
|
}
|
|
}
|
|
|
|
Int32 size() const override
|
|
{
|
|
Int32 sz = 4 + 2; // size of message + number of fields
|
|
/// If values is NULL, field size is -1 and data not added.
|
|
for (const std::shared_ptr<ISerializable> & field : row)
|
|
sz += 4 + (field->size() > 0 ? field->size() : 0);
|
|
return sz;
|
|
}
|
|
|
|
MessageType getMessageType() const override
|
|
{
|
|
return MessageType::DATA_ROW;
|
|
}
|
|
};
|
|
|
|
class CommandComplete : BackendMessage
|
|
{
|
|
public:
|
|
enum Command {BEGIN = 0, COMMIT = 1, INSERT = 2, DELETE = 3, UPDATE = 4, SELECT = 5, MOVE = 6, FETCH = 7, COPY = 8};
|
|
private:
|
|
String enum_to_string[9] = {"BEGIN", "COMMIT", "INSERT", "DELETE", "UPDATE", "SELECT", "MOVE", "FETCH", "COPY"};
|
|
|
|
String value;
|
|
|
|
public:
|
|
CommandComplete(Command cmd_, Int32 rows_count_)
|
|
{
|
|
value = enum_to_string[cmd_];
|
|
String add = " ";
|
|
if (cmd_ == Command::INSERT)
|
|
add = " 0 ";
|
|
value += add + std::to_string(rows_count_);
|
|
}
|
|
|
|
void serialize(WriteBuffer & out) const override
|
|
{
|
|
out.write('C');
|
|
writeBinaryBigEndian(size(), out);
|
|
writeNullTerminatedString(value, out);
|
|
}
|
|
|
|
Int32 size() const override
|
|
{
|
|
return static_cast<Int32>(4 + value.size() + 1);
|
|
}
|
|
|
|
MessageType getMessageType() const override
|
|
{
|
|
return MessageType::COMMAND_COMPLETE;
|
|
}
|
|
|
|
static Command classifyQuery(const String & query)
|
|
{
|
|
std::vector<String> query_types({"BEGIN", "COMMIT", "INSERT", "DELETE", "UPDATE", "SELECT", "MOVE", "FETCH", "COPY"});
|
|
for (size_t i = 0; i != query_types.size(); ++i)
|
|
{
|
|
String::const_iterator iter = std::search(
|
|
query.begin(),
|
|
query.end(),
|
|
query_types[i].begin(),
|
|
query_types[i].end(),
|
|
[](char a, char b){return std::toupper(a) == b;});
|
|
|
|
if (iter != query.end())
|
|
return static_cast<Command>(i);
|
|
}
|
|
|
|
return Command::SELECT;
|
|
}
|
|
};
|
|
|
|
}
|
|
|
|
namespace PGAuthentication
|
|
{
|
|
|
|
class AuthenticationMethod
|
|
{
|
|
protected:
|
|
static void setPassword(
|
|
const String & user_name,
|
|
const String & password,
|
|
Session & session,
|
|
const Poco::Net::SocketAddress & address)
|
|
{
|
|
session.authenticate(user_name, password, address);
|
|
}
|
|
|
|
public:
|
|
virtual void authenticate(
|
|
const String & user_name,
|
|
Session & session,
|
|
Messaging::MessageTransport & mt,
|
|
const Poco::Net::SocketAddress & address) = 0;
|
|
|
|
virtual AuthenticationType getType() const = 0;
|
|
|
|
virtual ~AuthenticationMethod() = default;
|
|
};
|
|
|
|
class NoPasswordAuth : public AuthenticationMethod
|
|
{
|
|
public:
|
|
void authenticate(
|
|
const String & user_name,
|
|
Session & session,
|
|
[[maybe_unused]] Messaging::MessageTransport & mt,
|
|
const Poco::Net::SocketAddress & address) override
|
|
{
|
|
return setPassword(user_name, "", session, address);
|
|
}
|
|
|
|
AuthenticationType getType() const override
|
|
{
|
|
return AuthenticationType::NO_PASSWORD;
|
|
}
|
|
};
|
|
|
|
class CleartextPasswordAuth : public AuthenticationMethod
|
|
{
|
|
public:
|
|
void authenticate(
|
|
const String & user_name,
|
|
Session & session,
|
|
Messaging::MessageTransport & mt,
|
|
const Poco::Net::SocketAddress & address) override
|
|
{
|
|
mt.send(Messaging::AuthenticationCleartextPassword(), true);
|
|
|
|
Messaging::FrontMessageType type = mt.receiveMessageType();
|
|
if (type == Messaging::FrontMessageType::PASSWORD_MESSAGE)
|
|
{
|
|
std::unique_ptr<Messaging::PasswordMessage> password = mt.receive<Messaging::PasswordMessage>();
|
|
return setPassword(user_name, password->password, session, address);
|
|
}
|
|
else
|
|
throw Exception(ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT,
|
|
"Client sent wrong message or closed the connection. Message byte was {}.",
|
|
static_cast<Int32>(type));
|
|
}
|
|
|
|
AuthenticationType getType() const override
|
|
{
|
|
return AuthenticationType::PLAINTEXT_PASSWORD;
|
|
}
|
|
};
|
|
|
|
class AuthenticationManager
|
|
{
|
|
private:
|
|
Poco::Logger * log = &Poco::Logger::get("AuthenticationManager");
|
|
std::unordered_map<AuthenticationType, std::shared_ptr<AuthenticationMethod>> type_to_method = {};
|
|
|
|
public:
|
|
explicit AuthenticationManager(const std::vector<std::shared_ptr<AuthenticationMethod>> & auth_methods)
|
|
{
|
|
for (const std::shared_ptr<AuthenticationMethod> & method : auth_methods)
|
|
{
|
|
type_to_method[method->getType()] = method;
|
|
}
|
|
}
|
|
|
|
void authenticate(
|
|
const String & user_name,
|
|
Session & session,
|
|
Messaging::MessageTransport & mt,
|
|
const Poco::Net::SocketAddress & address)
|
|
{
|
|
AuthenticationType user_auth_type;
|
|
try
|
|
{
|
|
user_auth_type = session.getAuthenticationTypeOrLogInFailure(user_name);
|
|
if (type_to_method.find(user_auth_type) != type_to_method.end())
|
|
{
|
|
type_to_method[user_auth_type]->authenticate(user_name, session, mt, address);
|
|
mt.send(Messaging::AuthenticationOk(), true);
|
|
LOG_DEBUG(log, "Authentication for user {} was successful.", user_name);
|
|
return;
|
|
}
|
|
}
|
|
catch (const Exception&)
|
|
{
|
|
mt.send(Messaging::ErrorOrNoticeResponse(Messaging::ErrorOrNoticeResponse::ERROR, "28P01", "Invalid user or password"),
|
|
true);
|
|
|
|
throw;
|
|
}
|
|
|
|
mt.send(Messaging::ErrorOrNoticeResponse(Messaging::ErrorOrNoticeResponse::ERROR, "0A000", "Authentication method is not supported"),
|
|
true);
|
|
|
|
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Authentication method is not supported: {}", user_auth_type);
|
|
}
|
|
};
|
|
}
|
|
|
|
}
|
|
}
|