support of sessions and default_database in MySQL wire protocol

This commit is contained in:
Yuriy 2019-06-16 18:12:37 +03:00
parent 73f282049e
commit 2e29ea7b2e
2 changed files with 9 additions and 5 deletions

View File

@ -48,6 +48,7 @@ MySQLHandler::MySQLHandler(IServer & server_, const Poco::Net::StreamSocket & so
void MySQLHandler::run()
{
connection_context = server.context();
connection_context.setSessionContext(connection_context);
connection_context.setDefaultFormat("MySQLWire");
in = std::make_shared<ReadBufferFromPocoSocket>(socket());
@ -306,7 +307,7 @@ void MySQLHandler::authenticate(const HandshakeResponse & handshake_response, co
try
{
connection_context.setUser(handshake_response.username, password, socket().address(), "");
connection_context.setCurrentDatabase(handshake_response.database);
if (!handshake_response.database.empty()) connection_context.setCurrentDatabase(handshake_response.database);
connection_context.setCurrentQueryId("");
LOG_ERROR(log, "Authentication for user " << handshake_response.username << " succeeded.");
}

View File

@ -51,7 +51,7 @@ def test_mysql_client(mysql_client, server_address):
-e "SELECT 'тест' as b;"
'''.format(host=server_address, port=server_port), demux=True)
assert stdout == 'a\n1\nb\nтест\n'
assert stdout == '\n'.join(['a', '1', 'b', 'тест', ''])
code, (stdout, stderr) = mysql_client.exec_run('''
mysql --protocol tcp -h {host} -P {port} default -u default --password=abc -e "select 1 as a;"
@ -75,14 +75,17 @@ def test_mysql_client(mysql_client, server_address):
mysql --protocol tcp -h {host} -P {port} default -u default --password=123
-e "CREATE DATABASE x;"
-e "USE x;"
-e "CREATE TABLE table1 (a UInt32) ENGINE = Memory;"
-e "CREATE TABLE table1 (column UInt32) ENGINE = Memory;"
-e "INSERT INTO table1 VALUES (0), (1), (5);"
-e "INSERT INTO table1 VALUES (0), (1), (5);"
-e "SELECT * FROM table1 ORDER BY a;"
-e "SELECT * FROM table1 ORDER BY column;"
-e "DROP DATABASE x;"
-e "CREATE TEMPORARY TABLE tmp (tmp_column UInt32);"
-e "INSERT INTO tmp VALUES (0), (1);"
-e "SELECT * FROM tmp ORDER BY tmp_column;"
'''.format(host=server_address, port=server_port), demux=True)
assert stdout == 'a\n0\n0\n1\n1\n5\n5\n'
assert stdout == '\n'.join(['column', '0', '0', '1', '1', '5', '5', 'tmp_column', '0', '1', ''])
def test_python_client(server_address):