mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-12-15 02:41:59 +00:00
223 lines
6.3 KiB
Python
223 lines
6.3 KiB
Python
# Copyright (c) 2019 Vitaliy Zakaznikov
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import codecs
import errno
import os
import pty
import re
import time
from queue import Empty, Queue
from subprocess import Popen
from threading import Event, Thread
|
|
|
|
|
|
class TimeoutError(Exception):
    """Raised by ``IO.read`` when no data arrives before the timeout expires.

    Note: inside this module the name shadows the builtin ``TimeoutError``.

    Attributes:
        timeout: the timeout value (seconds) that was exceeded.
    """

    def __init__(self, timeout):
        self.timeout = timeout

    def __str__(self):
        seconds = float(self.timeout)
        return "Timeout %.3fs" % seconds
|
|
|
|
|
|
class ExpectTimeoutError(Exception):
    """Raised by ``IO.expect`` when the pattern is not seen before the timeout.

    Attributes:
        pattern: the compiled regular expression that was being searched for.
        timeout: the timeout value (seconds) that was exceeded.
        buffer: the unmatched output collected so far (may be ``None``).
    """

    def __init__(self, pattern, timeout, buffer):
        self.pattern = pattern
        self.timeout = timeout
        self.buffer = buffer

    def __str__(self):
        pieces = ["Timeout %.3fs " % float(self.timeout)]
        if self.pattern:
            pieces.append("for %s " % repr(self.pattern.pattern))
        if self.buffer:
            text = self.buffer[:]
            # Also dump the raw code points in hex so control characters are visible.
            hex_codes = ",".join("%x" % ord(ch) for ch in text)
            pieces.append("buffer %s " % repr(text))
            pieces.append("or '%s'" % hex_codes)
        return "".join(pieces)
|
|
|
|
|
|
class IO(object):
    """Interactive, pexpect-style wrapper around a child process on a pty.

    Text chunks produced by the child arrive through ``queue``. ``expect``
    accumulates them in ``buffer`` and searches it with a regular expression;
    on a match, ``before``/``after``/``match`` describe the hit and ``buffer``
    keeps only the text that followed it.
    """

    # Marker classes; not referenced elsewhere in this class.
    class EOF(object):
        pass

    class Timeout(object):
        pass

    EOF = EOF
    TIMEOUT = Timeout

    class Logger(object):
        """File-like adapter that writes to ``logger`` with every line prefixed."""

        def __init__(self, logger, prefix=""):
            self._logger = logger
            self._prefix = prefix

        def write(self, data):
            # Each write starts on a fresh line, and every line gets the prefix.
            self._logger.write(("\n" + data).replace("\n", "\n" + self._prefix))

        def flush(self):
            self._logger.flush()

    def __init__(self, process, master, queue, reader):
        """Wrap an already-started child.

        Args:
            process: the ``Popen`` object of the child.
            master: file descriptor of the pty master end (written by ``write``).
            queue: queue from which ``read`` takes decoded output chunks.
            reader: dict holding at least a ``"kill_event"`` event, which
                ``close`` sets before closing ``master``.
        """
        self.process = process
        self.master = master
        self.queue = queue
        self.buffer = None
        self.before = None
        self.after = None
        self.match = None
        self.pattern = None
        self.reader = reader
        self._timeout = None
        self._logger = None
        self._eol = ""
        # Guards close() so that a second call does not close the fd twice.
        self._closed = False

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.close()

    def logger(self, logger=None, prefix=""):
        """Get, or set when ``logger`` is given, the prefixed output logger."""
        if logger is not None:
            self._logger = self.Logger(logger, prefix=prefix)
        return self._logger

    def timeout(self, timeout=None):
        """Get, or set when not ``None`` (including ``0``), the default expect timeout."""
        if timeout is not None:
            self._timeout = timeout
        return self._timeout

    def eol(self, eol=None):
        """Get, or set when not ``None`` (including ``""``), the suffix appended by ``send``."""
        if eol is not None:
            self._eol = eol
        return self._eol

    def close(self, force=True):
        """Stop the child and close the pty master; later calls are no-ops.

        Args:
            force: kill the child (``True``) or merely terminate it (``False``).
        """
        if self._closed:
            return
        self._closed = True
        # Set first so the reader thread treats the imminent read error as expected.
        self.reader["kill_event"].set()
        # Signal the child's own children as well.
        os.system("pkill -TERM -P %d" % self.process.pid)
        if force:
            self.process.kill()
        else:
            self.process.terminate()
        os.close(self.master)
        if self._logger:
            self._logger.write("\n")
            self._logger.flush()

    def send(self, data, eol=None):
        """Write ``data`` followed by ``eol`` (default: the configured eol)."""
        if eol is None:
            eol = self._eol
        return self.write(data + eol)

    def write(self, data):
        """Encode ``data`` and write it to the pty master; returns bytes written."""
        return os.write(self.master, data.encode())

    def expect(self, pattern, timeout=None, escape=False):
        """Wait until ``pattern`` appears in the child's output.

        Args:
            pattern: regular expression (string or compiled).
            timeout: seconds to wait; defaults to the value set via ``timeout()``.
                NOTE(review): a ``None`` timeout is not supported by the
                comparison below.
            escape: treat ``pattern`` as a literal string.

        Returns:
            The ``re.Match`` object; ``before``/``after`` are also updated.

        Raises:
            ExpectTimeoutError: if no match is found in time. ``buffer`` is
                reset to ``None``.
        """
        self.match = None
        self.before = None
        self.after = None
        if escape:
            pattern = re.escape(pattern)
        pattern = re.compile(pattern)
        if timeout is None:
            timeout = self._timeout
        timeleft = timeout
        while True:
            # Monotonic clock: wall-clock jumps must not distort the deadline.
            start_time = time.monotonic()
            if self.buffer is not None:
                self.match = pattern.search(self.buffer, 0)
                if self.match is not None:
                    self.after = self.buffer[self.match.start() : self.match.end()]
                    self.before = self.buffer[: self.match.start()]
                    self.buffer = self.buffer[self.match.end() :]
                    break
            if timeleft < 0:
                break
            try:
                data = self.read(timeout=timeleft, raise_exception=True)
            except TimeoutError:
                if self._logger:
                    self._logger.write((self.buffer or "") + "\n")
                    self._logger.flush()
                exception = ExpectTimeoutError(pattern, timeout, self.buffer)
                self.buffer = None
                raise exception
            timeleft -= time.monotonic() - start_time
            if data:
                self.buffer = (self.buffer + data) if self.buffer else data
        if self._logger:
            self._logger.write((self.before or "") + (self.after or ""))
            self._logger.flush()
        if self.match is None:
            exception = ExpectTimeoutError(pattern, timeout, self.buffer)
            self.buffer = None
            raise exception
        return self.match

    def read(self, timeout=0, raise_exception=False):
        """Return the next non-empty chunk from ``queue``, or ``""`` on timeout.

        Args:
            timeout: seconds to wait for data.
            raise_exception: raise ``TimeoutError`` instead of returning ``""``.
        """
        data = ""
        timeleft = timeout
        while timeleft >= 0 and not data:
            start_time = time.monotonic()
            try:
                data += self.queue.get(timeout=timeleft)
            except Empty:
                break
            timeleft -= time.monotonic() - start_time
        if not data and raise_exception:
            raise TimeoutError(timeout)
        return data
|
|
|
|
|
|
def spawn(command):
    """Start ``command`` attached to a new pseudo-terminal and wrap it in ``IO``.

    The child's stdin, stdout and stderr are all the pty slave, and it runs in a
    new session. A daemon thread running ``reader`` copies the master's output
    into a queue.

    Args:
        command: program and arguments, as accepted by ``subprocess.Popen``.

    Returns:
        An ``IO`` holding the process, the master fd, the output queue and the
        reader bookkeeping dict (``"thread"`` and ``"kill_event"``).

    Raises:
        OSError: if the child cannot be started; the pty fds are closed first.
    """
    master, slave = pty.openpty()
    try:
        process = Popen(
            command,
            # Documented, thread-safe replacement for preexec_fn=os.setsid.
            start_new_session=True,
            stdout=slave,
            stdin=slave,
            stderr=slave,
            bufsize=1,
        )
    except BaseException:
        # Do not leak the master fd when the child cannot be started.
        os.close(master)
        raise
    finally:
        # The child holds its own copy of the slave end; the parent never needs it.
        os.close(slave)

    queue = Queue()
    reader_kill_event = Event()
    thread = Thread(target=reader, args=(process, master, queue, reader_kill_event))
    thread.daemon = True
    thread.start()

    return IO(
        process,
        master,
        queue,
        reader={"thread": thread, "kill_event": reader_kill_event},
    )
|
|
|
|
|
|
def reader(process, out, queue, kill_event):
    """Copy output from file descriptor ``out`` into ``queue`` as text chunks.

    Bytes are decoded as UTF-8 with an incremental decoder, so multi-byte
    characters split across reads are reassembled; undecodable bytes become
    U+FFFD. Only non-empty strings are queued.

    The loop ends on EOF (an empty read), on EIO (how a Linux pty master reports
    that the slave side is closed), or on any ``OSError`` once ``kill_event``
    is set. Any other ``OSError`` is re-raised.

    ``process`` is accepted for interface compatibility but not used here.
    """
    decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
    while True:
        try:
            # TODO: there are some issues with 1<<16 buffer size
            chunk = os.read(out, 1 << 17)
        except OSError as exc:
            if kill_event.is_set():
                break
            if exc.errno == errno.EIO:
                break
            raise
        if not chunk:
            # EOF: stop instead of spinning and queueing empty strings forever.
            break
        text = decoder.decode(chunk)
        if text:
            queue.put(text)
    # Flush any incomplete trailing sequence (becomes U+FFFD).
    tail = decoder.decode(b"", final=True)
    if tail:
        queue.put(tail)
|