implemented ProtocolText

This commit is contained in:
Yuriy 2019-12-01 14:21:43 +03:00
parent afd8bced48
commit 3677d1dcfa
6 changed files with 151 additions and 39 deletions

View File

@ -100,4 +100,71 @@ size_t getLengthEncodedStringSize(const String & s)
return getLengthEncodedNumberSize(s.size()) + s.size(); return getLengthEncodedNumberSize(s.size()) + s.size();
} }
ColumnDefinition getColumnDefinition(const String & column_name, const TypeIndex type_index)
{
ColumnType column_type;
int flags = 0;
switch (type_index)
{
case TypeIndex::UInt8:
column_type = ColumnType::MYSQL_TYPE_TINY;
flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG;
break;
case TypeIndex::UInt16:
column_type = ColumnType::MYSQL_TYPE_SHORT;
flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG;
break;
case TypeIndex::UInt32:
column_type = ColumnType::MYSQL_TYPE_LONG;
flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG;
break;
case TypeIndex::UInt64:
column_type = ColumnType::MYSQL_TYPE_LONGLONG;
flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG;
break;
case TypeIndex::Int8:
column_type = ColumnType::MYSQL_TYPE_TINY;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::Int16:
column_type = ColumnType::MYSQL_TYPE_SHORT;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::Int32:
column_type = ColumnType::MYSQL_TYPE_LONG;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::Int64:
column_type = ColumnType::MYSQL_TYPE_LONGLONG;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::Float32:
column_type = ColumnType::MYSQL_TYPE_FLOAT;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::Float64:
column_type = ColumnType::MYSQL_TYPE_TINY;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::Date:
column_type = ColumnType::MYSQL_TYPE_DATE;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::DateTime:
column_type = ColumnType::MYSQL_TYPE_DATETIME;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::String:
column_type = ColumnType::MYSQL_TYPE_STRING;
break;
case TypeIndex::FixedString:
column_type = ColumnType::MYSQL_TYPE_STRING;
break;
default:
column_type = ColumnType::MYSQL_TYPE_STRING;
break;
}
return ColumnDefinition(column_name, CharacterSet::binary, 0, column_type, flags, 0);
}
} }

View File

@ -130,6 +130,14 @@ enum ColumnType
}; };
// https://dev.mysql.com/doc/dev/mysql-server/latest/group__group__cs__column__definition__flags.html
enum ColumnDefinitionFlags
{
UNSIGNED_FLAG = 32,
BINARY_FLAG = 128
};
class ProtocolError : public DB::Exception class ProtocolError : public DB::Exception
{ {
public: public:
@ -824,19 +832,40 @@ protected:
} }
}; };
ColumnDefinition getColumnDefinition(const String & column_name, const TypeIndex index);
namespace ProtocolText
{
class ResultsetRow : public WritePacket class ResultsetRow : public WritePacket
{ {
std::vector<String> columns; const Columns & columns;
int row_num;
size_t payload_size = 0; size_t payload_size = 0;
std::vector<String> serialized;
public: public:
ResultsetRow() = default; ResultsetRow(const DataTypes & data_types, const Columns & columns_, int row_num_)
: columns(columns_)
void appendColumn(String && value) , row_num(row_num_)
{ {
payload_size += getLengthEncodedStringSize(value); for (size_t i = 0; i < columns.size(); i++)
columns.emplace_back(std::move(value)); {
if (columns[i]->isNullAt(row_num))
{
payload_size += 1;
serialized.emplace_back("\xfb");
}
else
{
WriteBufferFromOwnString ostr;
data_types[i]->serializeAsText(*columns[i], row_num, ostr, FormatSettings());
payload_size += getLengthEncodedStringSize(ostr.str());
serialized.push_back(std::move(ostr.str()));
}
}
} }
protected: protected:
size_t getPayloadSize() const override size_t getPayloadSize() const override
{ {
@ -845,11 +874,18 @@ protected:
void writePayloadImpl(WriteBuffer & buffer) const override void writePayloadImpl(WriteBuffer & buffer) const override
{ {
for (const String & column : columns) for (size_t i = 0; i < columns.size(); i++)
writeLengthEncodedString(column, buffer); {
if (columns[i]->isNullAt(row_num))
buffer.write(serialized[i].data(), 1);
else
writeLengthEncodedString(serialized[i], buffer);
}
} }
}; };
}
namespace Authentication namespace Authentication
{ {

View File

@ -28,18 +28,15 @@ void MySQLOutputFormat::initialize()
initialized = true; initialized = true;
auto & header = getPort(PortKind::Main).getHeader(); auto & header = getPort(PortKind::Main).getHeader();
data_types = header.getDataTypes();
if (header.columns()) if (header.columns())
{ {
packet_sender.sendPacket(LengthEncodedNumber(header.columns())); packet_sender.sendPacket(LengthEncodedNumber(header.columns()));
for (const ColumnWithTypeAndName & column : header.getColumnsWithTypeAndName()) for (size_t i = 0; i < header.columns(); i++) {
{ const auto & column_name = header.getColumnsWithTypeAndName()[i].name;
ColumnDefinition column_definition(column.name, CharacterSet::binary, 0, ColumnType::MYSQL_TYPE_STRING, packet_sender.sendPacket(getColumnDefinition(column_name, data_types[i]->getTypeId()));
0, 0);
packet_sender.sendPacket(column_definition);
} }
if (!(context.mysql.client_capabilities & Capability::CLIENT_DEPRECATE_EOF)) if (!(context.mysql.client_capabilities & Capability::CLIENT_DEPRECATE_EOF))
@ -52,22 +49,9 @@ void MySQLOutputFormat::initialize()
void MySQLOutputFormat::consume(Chunk chunk) void MySQLOutputFormat::consume(Chunk chunk)
{ {
initialize(); for (size_t i = 0; i < chunk.getNumRows(); i++)
auto & header = getPort(PortKind::Main).getHeader();
size_t rows = chunk.getNumRows();
auto & columns = chunk.getColumns();
for (size_t i = 0; i < rows; i++)
{ {
ResultsetRow row_packet; ProtocolText::ResultsetRow row_packet(data_types, chunk.getColumns(), i);
for (size_t col = 0; col < columns.size(); ++col)
{
WriteBufferFromOwnString ostr;
header.getByPosition(col).type->serializeAsText(*columns[col], i, ostr, format_settings);
row_packet.appendColumn(std::move(ostr.str()));
}
packet_sender.sendPacket(row_packet); packet_sender.sendPacket(row_packet);
} }
} }

View File

@ -37,6 +37,7 @@ private:
const Context & context; const Context & context;
MySQLProtocol::PacketSender packet_sender; MySQLProtocol::PacketSender packet_sender;
FormatSettings format_settings; FormatSettings format_settings;
DataTypes data_types;
}; };
} }

View File

@ -1,7 +1,7 @@
Columns: Columns:
a a
Column types: Column types:
a BINARY a BIGINT
Result: Result:
0 0
1 1
@ -10,7 +10,7 @@ name
a a
Column types: Column types:
name BINARY name BINARY
a BINARY a TINYINT
Result: Result:
tables 1 tables 1
Columns: Columns:
@ -18,6 +18,6 @@ a
b b
Column types: Column types:
a BINARY a BINARY
b BINARY b TINYINT
Result: Result:
тест 1 тест 1

View File

@ -110,6 +110,17 @@ def test_mysql_client(mysql_client, server_address):
def test_python_client(server_address): def test_python_client(server_address):
client = pymysql.connections.Connection(host=server_address, user='user_with_double_sha1', password='abacaba', database='default', port=server_port)
with pytest.raises(pymysql.InternalError) as exc_info:
client.query('select name from tables')
assert exc_info.value.args == (60, "Table default.tables doesn't exist.")
cursor = client.cursor(pymysql.cursors.DictCursor)
cursor.execute("select 1 as a, 'тест' as b")
assert cursor.fetchall() == [{'a': 1, 'b': 'тест'}]
with pytest.raises(pymysql.InternalError) as exc_info: with pytest.raises(pymysql.InternalError) as exc_info:
pymysql.connections.Connection(host=server_address, user='default', password='abacab', database='default', port=server_port) pymysql.connections.Connection(host=server_address, user='default', password='abacab', database='default', port=server_port)
@ -124,7 +135,7 @@ def test_python_client(server_address):
cursor = client.cursor(pymysql.cursors.DictCursor) cursor = client.cursor(pymysql.cursors.DictCursor)
cursor.execute("select 1 as a, 'тест' as b") cursor.execute("select 1 as a, 'тест' as b")
assert cursor.fetchall() == [{'a': '1', 'b': 'тест'}] assert cursor.fetchall() == [{'a': 1, 'b': 'тест'}]
client.select_db('system') client.select_db('system')
@ -140,11 +151,14 @@ def test_python_client(server_address):
cursor.execute("INSERT INTO table1 VALUES (1), (3)") cursor.execute("INSERT INTO table1 VALUES (1), (3)")
cursor.execute("INSERT INTO table1 VALUES (1), (4)") cursor.execute("INSERT INTO table1 VALUES (1), (4)")
cursor.execute("SELECT * FROM table1 ORDER BY a") cursor.execute("SELECT * FROM table1 ORDER BY a")
assert cursor.fetchall() == [{'a': '1'}, {'a': '1'}, {'a': '3'}, {'a': '4'}] assert cursor.fetchall() == [{'a': 1}, {'a': 1}, {'a': 3}, {'a': 4}]
def test_golang_client(server_address, golang_container): def test_golang_client(server_address, golang_container):
# type: (str, Container) -> None # type: (str, Container) -> None
with open(os.path.join(SCRIPT_DIR, 'clients', 'golang', '0.reference')) as fp:
reference = fp.read()
code, (stdout, stderr) = golang_container.exec_run('./main --host {host} --port {port} --user default --password 123 --database ' code, (stdout, stderr) = golang_container.exec_run('./main --host {host} --port {port} --user default --password 123 --database '
'abc'.format(host=server_address, port=server_port), demux=True) 'abc'.format(host=server_address, port=server_port), demux=True)
@ -155,9 +169,11 @@ def test_golang_client(server_address, golang_container):
'default'.format(host=server_address, port=server_port), demux=True) 'default'.format(host=server_address, port=server_port), demux=True)
assert code == 0 assert code == 0
assert stdout == reference
with open(os.path.join(SCRIPT_DIR, 'clients', 'golang', '0.reference')) as fp: code, (stdout, stderr) = golang_container.exec_run('./main --host {host} --port {port} --user user_with_double_sha1 --password abacaba --database '
reference = fp.read() 'default'.format(host=server_address, port=server_port), demux=True)
assert code == 0
assert stdout == reference assert stdout == reference
@ -171,6 +187,14 @@ def test_php_client(server_address, php_container):
assert code == 0 assert code == 0
assert stdout == 'tables\n' assert stdout == 'tables\n'
code, (stdout, stderr) = php_container.exec_run('php -f test.php {host} {port} user_with_double_sha1 abacaba'.format(host=server_address, port=server_port), demux=True)
assert code == 0
assert stdout == 'tables\n'
code, (stdout, stderr) = php_container.exec_run('php -f test_ssl.php {host} {port} user_with_double_sha1 abacaba'.format(host=server_address, port=server_port), demux=True)
assert code == 0
assert stdout == 'tables\n'
def test_mysqljs_client(server_address, nodejs_container): def test_mysqljs_client(server_address, nodejs_container):
code, (_, stderr) = nodejs_container.exec_run('node test.js {host} {port} default 123'.format(host=server_address, port=server_port), demux=True) code, (_, stderr) = nodejs_container.exec_run('node test.js {host} {port} default 123'.format(host=server_address, port=server_port), demux=True)