implemented ProtocolText

This commit is contained in:
Yuriy 2019-12-01 14:21:43 +03:00
parent afd8bced48
commit 3677d1dcfa
6 changed files with 151 additions and 39 deletions

View File

@ -100,4 +100,71 @@ size_t getLengthEncodedStringSize(const String & s)
return getLengthEncodedNumberSize(s.size()) + s.size();
}
ColumnDefinition getColumnDefinition(const String & column_name, const TypeIndex type_index)
{
ColumnType column_type;
int flags = 0;
switch (type_index)
{
case TypeIndex::UInt8:
column_type = ColumnType::MYSQL_TYPE_TINY;
flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG;
break;
case TypeIndex::UInt16:
column_type = ColumnType::MYSQL_TYPE_SHORT;
flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG;
break;
case TypeIndex::UInt32:
column_type = ColumnType::MYSQL_TYPE_LONG;
flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG;
break;
case TypeIndex::UInt64:
column_type = ColumnType::MYSQL_TYPE_LONGLONG;
flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG;
break;
case TypeIndex::Int8:
column_type = ColumnType::MYSQL_TYPE_TINY;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::Int16:
column_type = ColumnType::MYSQL_TYPE_SHORT;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::Int32:
column_type = ColumnType::MYSQL_TYPE_LONG;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::Int64:
column_type = ColumnType::MYSQL_TYPE_LONGLONG;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::Float32:
column_type = ColumnType::MYSQL_TYPE_FLOAT;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::Float64:
column_type = ColumnType::MYSQL_TYPE_TINY;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::Date:
column_type = ColumnType::MYSQL_TYPE_DATE;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::DateTime:
column_type = ColumnType::MYSQL_TYPE_DATETIME;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::String:
column_type = ColumnType::MYSQL_TYPE_STRING;
break;
case TypeIndex::FixedString:
column_type = ColumnType::MYSQL_TYPE_STRING;
break;
default:
column_type = ColumnType::MYSQL_TYPE_STRING;
break;
}
return ColumnDefinition(column_name, CharacterSet::binary, 0, column_type, flags, 0);
}
}

View File

@ -130,6 +130,14 @@ enum ColumnType
};
// https://dev.mysql.com/doc/dev/mysql-server/latest/group__group__cs__column__definition__flags.html
enum ColumnDefinitionFlags
{
UNSIGNED_FLAG = 32,
BINARY_FLAG = 128
};
class ProtocolError : public DB::Exception
{
public:
@ -824,19 +832,40 @@ protected:
}
};
ColumnDefinition getColumnDefinition(const String & column_name, const TypeIndex index);
namespace ProtocolText
{
class ResultsetRow : public WritePacket
{
std::vector<String> columns;
const Columns & columns;
int row_num;
size_t payload_size = 0;
std::vector<String> serialized;
public:
ResultsetRow() = default;
void appendColumn(String && value)
ResultsetRow(const DataTypes & data_types, const Columns & columns_, int row_num_)
: columns(columns_)
, row_num(row_num_)
{
payload_size += getLengthEncodedStringSize(value);
columns.emplace_back(std::move(value));
for (size_t i = 0; i < columns.size(); i++)
{
if (columns[i]->isNullAt(row_num))
{
payload_size += 1;
serialized.emplace_back("\xfb");
}
else
{
WriteBufferFromOwnString ostr;
data_types[i]->serializeAsText(*columns[i], row_num, ostr, FormatSettings());
payload_size += getLengthEncodedStringSize(ostr.str());
serialized.push_back(std::move(ostr.str()));
}
}
}
protected:
size_t getPayloadSize() const override
{
@ -845,11 +874,18 @@ protected:
void writePayloadImpl(WriteBuffer & buffer) const override
{
for (const String & column : columns)
writeLengthEncodedString(column, buffer);
for (size_t i = 0; i < columns.size(); i++)
{
if (columns[i]->isNullAt(row_num))
buffer.write(serialized[i].data(), 1);
else
writeLengthEncodedString(serialized[i], buffer);
}
}
};
}
namespace Authentication
{

View File

@ -28,18 +28,15 @@ void MySQLOutputFormat::initialize()
initialized = true;
auto & header = getPort(PortKind::Main).getHeader();
data_types = header.getDataTypes();
if (header.columns())
{
packet_sender.sendPacket(LengthEncodedNumber(header.columns()));
for (const ColumnWithTypeAndName & column : header.getColumnsWithTypeAndName())
{
ColumnDefinition column_definition(column.name, CharacterSet::binary, 0, ColumnType::MYSQL_TYPE_STRING,
0, 0);
packet_sender.sendPacket(column_definition);
for (size_t i = 0; i < header.columns(); i++) {
const auto & column_name = header.getColumnsWithTypeAndName()[i].name;
packet_sender.sendPacket(getColumnDefinition(column_name, data_types[i]->getTypeId()));
}
if (!(context.mysql.client_capabilities & Capability::CLIENT_DEPRECATE_EOF))
@ -52,22 +49,9 @@ void MySQLOutputFormat::initialize()
void MySQLOutputFormat::consume(Chunk chunk)
{
initialize();
auto & header = getPort(PortKind::Main).getHeader();
size_t rows = chunk.getNumRows();
auto & columns = chunk.getColumns();
for (size_t i = 0; i < rows; i++)
for (size_t i = 0; i < chunk.getNumRows(); i++)
{
ResultsetRow row_packet;
for (size_t col = 0; col < columns.size(); ++col)
{
WriteBufferFromOwnString ostr;
header.getByPosition(col).type->serializeAsText(*columns[col], i, ostr, format_settings);
row_packet.appendColumn(std::move(ostr.str()));
}
ProtocolText::ResultsetRow row_packet(data_types, chunk.getColumns(), i);
packet_sender.sendPacket(row_packet);
}
}

View File

@ -37,6 +37,7 @@ private:
const Context & context;
MySQLProtocol::PacketSender packet_sender;
FormatSettings format_settings;
DataTypes data_types;
};
}

View File

@ -1,7 +1,7 @@
Columns:
a
Column types:
a BINARY
a BIGINT
Result:
0
1
@ -10,7 +10,7 @@ name
a
Column types:
name BINARY
a BINARY
a TINYINT
Result:
tables 1
Columns:
@ -18,6 +18,6 @@ a
b
Column types:
a BINARY
b BINARY
b TINYINT
Result:
тест 1

View File

@ -110,6 +110,17 @@ def test_mysql_client(mysql_client, server_address):
def test_python_client(server_address):
client = pymysql.connections.Connection(host=server_address, user='user_with_double_sha1', password='abacaba', database='default', port=server_port)
with pytest.raises(pymysql.InternalError) as exc_info:
client.query('select name from tables')
assert exc_info.value.args == (60, "Table default.tables doesn't exist.")
cursor = client.cursor(pymysql.cursors.DictCursor)
cursor.execute("select 1 as a, 'тест' as b")
assert cursor.fetchall() == [{'a': 1, 'b': 'тест'}]
with pytest.raises(pymysql.InternalError) as exc_info:
pymysql.connections.Connection(host=server_address, user='default', password='abacab', database='default', port=server_port)
@ -124,7 +135,7 @@ def test_python_client(server_address):
cursor = client.cursor(pymysql.cursors.DictCursor)
cursor.execute("select 1 as a, 'тест' as b")
assert cursor.fetchall() == [{'a': '1', 'b': 'тест'}]
assert cursor.fetchall() == [{'a': 1, 'b': 'тест'}]
client.select_db('system')
@ -140,11 +151,14 @@ def test_python_client(server_address):
cursor.execute("INSERT INTO table1 VALUES (1), (3)")
cursor.execute("INSERT INTO table1 VALUES (1), (4)")
cursor.execute("SELECT * FROM table1 ORDER BY a")
assert cursor.fetchall() == [{'a': '1'}, {'a': '1'}, {'a': '3'}, {'a': '4'}]
assert cursor.fetchall() == [{'a': 1}, {'a': 1}, {'a': 3}, {'a': 4}]
def test_golang_client(server_address, golang_container):
# type: (str, Container) -> None
with open(os.path.join(SCRIPT_DIR, 'clients', 'golang', '0.reference')) as fp:
reference = fp.read()
code, (stdout, stderr) = golang_container.exec_run('./main --host {host} --port {port} --user default --password 123 --database '
'abc'.format(host=server_address, port=server_port), demux=True)
@ -155,10 +169,12 @@ def test_golang_client(server_address, golang_container):
'default'.format(host=server_address, port=server_port), demux=True)
assert code == 0
assert stdout == reference
with open(os.path.join(SCRIPT_DIR, 'clients', 'golang', '0.reference')) as fp:
reference = fp.read()
assert stdout == reference
code, (stdout, stderr) = golang_container.exec_run('./main --host {host} --port {port} --user user_with_double_sha1 --password abacaba --database '
'default'.format(host=server_address, port=server_port), demux=True)
assert code == 0
assert stdout == reference
def test_php_client(server_address, php_container):
@ -171,6 +187,14 @@ def test_php_client(server_address, php_container):
assert code == 0
assert stdout == 'tables\n'
code, (stdout, stderr) = php_container.exec_run('php -f test.php {host} {port} user_with_double_sha1 abacaba'.format(host=server_address, port=server_port), demux=True)
assert code == 0
assert stdout == 'tables\n'
code, (stdout, stderr) = php_container.exec_run('php -f test_ssl.php {host} {port} user_with_double_sha1 abacaba'.format(host=server_address, port=server_port), demux=True)
assert code == 0
assert stdout == 'tables\n'
def test_mysqljs_client(server_address, nodejs_container):
code, (_, stderr) = nodejs_container.exec_run('node test.js {host} {port} default 123'.format(host=server_address, port=server_port), demux=True)