mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-10 01:25:21 +00:00
286 lines
8.3 KiB
Python
286 lines
8.3 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
import datetime
|
|
import logging
|
|
import pyodbc
|
|
import sqlite3
|
|
import traceback
|
|
import enum
|
|
import random
|
|
import string
|
|
from contextlib import contextmanager
|
|
|
|
from exceptions import ProgramError
|
|
|
|
|
|
logger = logging.getLogger("connection")
|
|
logger.setLevel(logging.DEBUG)
|
|
|
|
|
|
class OdbcConnectingArgs:
|
|
def __init__(self, **kwargs):
|
|
self._kwargs = kwargs
|
|
|
|
def __str__(self):
|
|
conn_str = ";".join(
|
|
["{}={}".format(x, y) for x, y in self._kwargs.items() if y]
|
|
)
|
|
return conn_str
|
|
|
|
def update_database(self, database):
|
|
self._kwargs["Database"] = database
|
|
|
|
@staticmethod
|
|
def create_from_kw(
|
|
dsn="", server="localhost", user="default", database="default", **kwargs
|
|
):
|
|
conn_args = {
|
|
"DSN": dsn,
|
|
"Server": server,
|
|
"User": user,
|
|
"Database": database,
|
|
}
|
|
conn_args.update(kwargs)
|
|
return OdbcConnectingArgs(**conn_args)
|
|
|
|
@staticmethod
|
|
def create_from_connection_string(conn_str):
|
|
args = OdbcConnectingArgs()
|
|
for kv in conn_str.split(";"):
|
|
if kv:
|
|
k, v = kv.split("=", 1)
|
|
args._kwargs[k] = v
|
|
return args
|
|
|
|
|
|
def _random_str(length=8):
|
|
alphabet = string.ascii_lowercase + string.digits
|
|
return "".join(random.SystemRandom().choice(alphabet) for _ in range(length))
|
|
|
|
|
|
def default_clickhouse_odbc_conn_str():
|
|
return str(
|
|
OdbcConnectingArgs.create_from_kw(
|
|
dsn="ClickHouse DSN (ANSI)",
|
|
)
|
|
)
|
|
|
|
|
|
class Engines(enum.Enum):
|
|
SQLITE = enum.auto()
|
|
ODBC = enum.auto()
|
|
|
|
@staticmethod
|
|
def list():
|
|
return list(map(lambda c: c.name.lower(), Engines))
|
|
|
|
|
|
class KnownDBMS(str, enum.Enum):
|
|
sqlite = "sqlite"
|
|
clickhouse = "ClickHouse"
|
|
|
|
|
|
class ConnectionWrap(object):
|
|
def __init__(self, connection=None, factory=None, factory_kwargs=None):
|
|
self._factory = factory
|
|
self._factory_kwargs = factory_kwargs
|
|
self._connection = connection
|
|
|
|
self.DBMS_NAME = None
|
|
self.DATABASE_NAME = None
|
|
self.USER_NAME = None
|
|
|
|
@staticmethod
|
|
def create(connection):
|
|
return ConnectionWrap(connection=connection)
|
|
|
|
@staticmethod
|
|
def create_form_factory(factory, factory_kwargs):
|
|
return ConnectionWrap(
|
|
factory=factory, factory_kwargs=factory_kwargs
|
|
).reconnect()
|
|
|
|
def can_reconnect(self):
|
|
return self._factory is not None
|
|
|
|
def reconnect(self):
|
|
if self._connection is not None:
|
|
self._connection.close()
|
|
self._connection = self._factory(self._factory_kwargs)
|
|
return self
|
|
|
|
def assert_can_reconnect(self):
|
|
assert self.can_reconnect(), f"no reconnect for: {self.DBMS_NAME}"
|
|
|
|
def __getattr__(self, item):
|
|
return getattr(self._connection, item)
|
|
|
|
def __enter__(self):
|
|
return self
|
|
|
|
def drop_all_tables(self):
|
|
if self.DBMS_NAME == KnownDBMS.clickhouse.value:
|
|
list_query = (
|
|
f"SELECT name FROM system.tables WHERE database='{self.DATABASE_NAME}'"
|
|
)
|
|
elif self.DBMS_NAME == KnownDBMS.sqlite.value:
|
|
list_query = f"SELECT name FROM sqlite_master WHERE type='table'"
|
|
else:
|
|
logger.warning(
|
|
"unable to drop all tables for unknown database: %s", self.DBMS_NAME
|
|
)
|
|
return
|
|
|
|
list_result = execute_request(list_query, self)
|
|
logger.info("tables will be dropped: %s", list_result.get_result())
|
|
for table_name in list_result.get_result():
|
|
table_name = table_name[0]
|
|
execute_request(f"DROP TABLE {table_name}", self).assert_no_exception()
|
|
logger.debug("success drop table: %s", table_name)
|
|
|
|
def _use_database(self, database="default"):
|
|
if self.DBMS_NAME == KnownDBMS.clickhouse.value:
|
|
logger.info("use test database: %s", database)
|
|
self._factory_kwargs.update_database(database)
|
|
self.reconnect()
|
|
self.DATABASE_NAME = database
|
|
|
|
def use_random_database(self):
|
|
if self.DBMS_NAME == KnownDBMS.clickhouse.value:
|
|
database = f"test_{_random_str()}"
|
|
execute_request(f"CREATE DATABASE {database}", self).assert_no_exception()
|
|
self._use_database(database)
|
|
logger.info(
|
|
"currentDatabase : %s",
|
|
execute_request(f"SELECT currentDatabase()", self).get_result(),
|
|
)
|
|
|
|
@contextmanager
|
|
def with_one_test_scope(self):
|
|
try:
|
|
yield self
|
|
finally:
|
|
self.drop_all_tables()
|
|
|
|
@contextmanager
|
|
def with_test_database_scope(self):
|
|
self.use_random_database()
|
|
try:
|
|
yield self
|
|
finally:
|
|
self._use_database()
|
|
|
|
def __exit__(self, *args):
|
|
if hasattr(self._connection, "close"):
|
|
return self._connection.close()
|
|
|
|
|
|
def setup_connection(engine, conn_str=None, make_debug_request=True):
|
|
connection = None
|
|
|
|
if isinstance(engine, str):
|
|
engine = Engines[engine.upper()]
|
|
|
|
if engine == Engines.ODBC:
|
|
if conn_str is None:
|
|
raise ProgramError("conn_str has to be set up for ODBC connection")
|
|
|
|
logger.debug("Drivers: %s", pyodbc.drivers())
|
|
logger.debug("DataSources: %s", pyodbc.dataSources())
|
|
logger.debug("Connection string: %s", conn_str)
|
|
|
|
conn_args = OdbcConnectingArgs.create_from_connection_string(conn_str)
|
|
|
|
connection = ConnectionWrap.create_form_factory(
|
|
factory=lambda args: pyodbc.connect(str(args)),
|
|
factory_kwargs=conn_args,
|
|
)
|
|
connection.add_output_converter(pyodbc.SQL_UNKNOWN_TYPE, lambda x: None)
|
|
|
|
connection.DBMS_NAME = connection.getinfo(pyodbc.SQL_DBMS_NAME)
|
|
connection.DATABASE_NAME = connection.getinfo(pyodbc.SQL_DATABASE_NAME)
|
|
connection.USER_NAME = connection.getinfo(pyodbc.SQL_USER_NAME)
|
|
|
|
elif engine == Engines.SQLITE:
|
|
conn_str = conn_str if conn_str is not None else ":memory:"
|
|
connection = ConnectionWrap.create(sqlite3.connect(conn_str))
|
|
|
|
connection.DBMS_NAME = "sqlite"
|
|
connection.DATABASE_NAME = "main"
|
|
connection.USER_NAME = "default"
|
|
|
|
logger.info(
|
|
"Connection info: DBMS name %s, database %s, user %s",
|
|
connection.DBMS_NAME,
|
|
connection.DATABASE_NAME,
|
|
connection.USER_NAME,
|
|
)
|
|
|
|
if make_debug_request:
|
|
request = "SELECT 1"
|
|
logger.debug("Make debug request to the connection: %s", request)
|
|
result = execute_request(request, connection)
|
|
logger.debug("Debug request returned: %s", result.get_result())
|
|
|
|
logger.debug("Connection is ok")
|
|
return connection
|
|
|
|
|
|
class ExecResult:
|
|
def __init__(self):
|
|
self._exception = None
|
|
self._result = None
|
|
self._description = None
|
|
|
|
def as_exception(self, exc):
|
|
self._exception = exc
|
|
return self
|
|
|
|
def get_result(self):
|
|
self.assert_no_exception()
|
|
return self._result
|
|
|
|
def get_description(self):
|
|
self.assert_no_exception()
|
|
return self._description
|
|
|
|
def as_ok(self, rows=None, description=None):
|
|
if rows is None:
|
|
self._result = True
|
|
return self
|
|
self._result = rows
|
|
self._description = description
|
|
return self
|
|
|
|
def get_exception(self):
|
|
return self._exception
|
|
|
|
def has_exception(self):
|
|
return self._exception is not None
|
|
|
|
def assert_no_exception(self):
|
|
if self.has_exception():
|
|
raise ProgramError(
|
|
f"request doesn't have a result set, it has the exception",
|
|
parent=self._exception,
|
|
)
|
|
|
|
|
|
def execute_request(request, connection):
|
|
cursor = connection.cursor()
|
|
try:
|
|
cursor.execute(request)
|
|
if cursor.description:
|
|
logging.debug("request has a description %s", cursor.description)
|
|
rows = cursor.fetchall()
|
|
connection.commit()
|
|
return ExecResult().as_ok(rows=rows, description=cursor.description)
|
|
else:
|
|
logging.debug("request doesn't have a description")
|
|
connection.commit()
|
|
return ExecResult().as_ok()
|
|
except (pyodbc.Error, sqlite3.DatabaseError) as err:
|
|
return ExecResult().as_exception(err)
|
|
finally:
|
|
cursor.close()
|