mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-12-13 18:02:24 +00:00
211 lines
6.3 KiB
Python
211 lines
6.3 KiB
Python
# Copyright (c) 2019 Vitaliy Zakaznikov
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import os
|
|
import pty
|
|
import re
|
|
import time
|
|
from queue import Queue, Empty
|
|
from subprocess import Popen
|
|
from threading import Thread, Event
|
|
|
|
|
|
class TimeoutError(Exception):
    """Raised by ``IO.read`` when no output arrives within the timeout.

    Note: this intentionally shadows the builtin ``TimeoutError`` inside this
    module.
    """

    def __init__(self, timeout):
        # Pass the argument to Exception so ``args`` is populated; without this
        # the exception cannot be pickled/unpickled (``cls()`` would be called
        # without the required ``timeout``).
        super().__init__(timeout)
        self.timeout = timeout

    def __str__(self):
        return 'Timeout %.3fs' % float(self.timeout)
|
|
|
|
|
|
class ExpectTimeoutError(Exception):
    """Raised by ``IO.expect`` when the pattern was not found in time.

    :param pattern: the pattern that was searched for (a compiled regex as
        passed by ``IO.expect``; a plain string is accepted as well).
    :param timeout: the timeout, in seconds, that expired.
    :param buffer: the unmatched output collected so far (may be None).
    """

    def __init__(self, pattern, timeout, buffer):
        # Populate ``args`` so the exception survives pickling.
        super().__init__(pattern, timeout, buffer)
        self.pattern = pattern
        self.timeout = timeout
        self.buffer = buffer

    def __str__(self):
        s = 'Timeout %.3fs ' % float(self.timeout)
        if self.pattern:
            # Show the source text of a compiled regex, or the string itself.
            s += 'for %s ' % repr(getattr(self.pattern, 'pattern', self.pattern))
        if self.buffer:
            s += 'buffer %s ' % repr(self.buffer)
            # Hex code points make invisible/control characters visible.
            s += 'or \'%s\'' % ','.join('%x' % ord(c) for c in self.buffer)
        return s
|
|
|
|
|
|
class IO(object):
    """Expect-style handle for a child process attached to a pseudo-terminal.

    Output of the child arrives as text chunks on ``queue`` (in this module,
    ``spawn`` starts a thread running ``reader`` to fill it) and is
    accumulated in ``buffer`` until ``expect`` finds a match. Input is sent
    by writing to the pty ``master`` file descriptor.
    """

    class EOF(object):
        pass

    class Timeout(object):
        pass

    # Marker classes exposed as ``IO.EOF`` / ``IO.TIMEOUT``; not used by the
    # methods of this class itself.
    EOF = EOF
    TIMEOUT = Timeout

    class Logger(object):
        """File-like adapter that prefixes every logged line with ``prefix``."""

        def __init__(self, logger, prefix=''):
            self._logger = logger
            self._prefix = prefix

        def write(self, data):
            # Every write starts on a fresh line, and every newline (including
            # that leading one) is followed by the prefix.
            self._logger.write(('\n' + data).replace('\n', '\n' + self._prefix))

        def flush(self):
            self._logger.flush()

    def __init__(self, process, master, queue, reader):
        self.process = process  # child process (a Popen object in spawn())
        self.master = master    # pty master file descriptor
        self.queue = queue      # text chunks of child output
        self.buffer = None      # output received but not yet consumed by expect()
        self.before = None      # text preceding the last match
        self.after = None       # text of the last match
        self.match = None       # match object of the last successful expect()
        self.pattern = None
        self.reader = reader    # dict holding the reader 'thread' and its 'kill_event'
        self._timeout = None    # default timeout for expect()
        self._logger = None     # optional IO.Logger mirroring the session
        self._eol = ''          # appended by send() when no eol is given
        self._closed = False    # guards against closing the pty fd twice

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        self.close()

    def logger(self, logger=None, prefix=''):
        """Return the session logger, first installing ``logger`` if given.

        :param logger: file-like object (``write``/``flush``) to mirror the
            session to; every line is prefixed with ``prefix``.
        """
        if logger is not None:
            self._logger = self.Logger(logger, prefix=prefix)
        return self._logger

    def timeout(self, timeout=None):
        """Return the default ``expect`` timeout, setting it first if given.

        Any value other than None (including 0) is stored.
        """
        if timeout is not None:
            self._timeout = timeout
        return self._timeout

    def eol(self, eol=None):
        """Return the line terminator used by ``send``, setting it if given.

        Any value other than None (including '') is stored.
        """
        if eol is not None:
            self._eol = eol
        return self._eol

    def close(self, force=True):
        """Stop the reader, kill the child (and its children), close the pty.

        :param force: send SIGKILL to the child if true, SIGTERM otherwise.

        Only the first call has an effect, so an explicit ``close()`` inside a
        ``with`` block is safe.
        """
        if self._closed:
            return
        self._closed = True
        # Set first so the reader treats the upcoming error on the closed
        # master fd as a shutdown rather than a failure.
        self.reader['kill_event'].set()
        # Terminate processes started by the child before the child itself.
        os.system('pkill -TERM -P %d' % self.process.pid)
        if force:
            self.process.kill()
        else:
            self.process.terminate()
        os.close(self.master)
        if self._logger:
            self._logger.write('\n')
            self._logger.flush()

    def send(self, data, eol=None):
        """Write ``data`` followed by ``eol`` (default: the configured eol)."""
        if eol is None:
            eol = self._eol
        return self.write(data + eol)

    def write(self, data):
        """Write the string ``data`` to the child; return the bytes written."""
        payload = data.encode()
        view = memoryview(payload)
        written = 0
        # os.write may accept only part of the buffer; keep going until all
        # of it has been handed to the pty.
        while written < len(payload):
            written += os.write(self.master, view[written:])
        return written

    def _expect_timeout(self, pattern, timeout):
        """Log and discard the unmatched buffer; return the error to raise."""
        if self._logger:
            self._logger.write((self.buffer or '') + '\n')
            self._logger.flush()
        exception = ExpectTimeoutError(pattern, timeout, self.buffer)
        self.buffer = None
        return exception

    def expect(self, pattern, timeout=None, escape=False):
        """Wait until ``pattern`` appears in the child's output.

        :param pattern: regular expression to search for.
        :param timeout: seconds to wait; defaults to the value set with
            ``timeout()``. It must resolve to a number (None is not supported).
        :param escape: if true, ``pattern`` is matched literally.
        :return: the match object. ``before`` is then the text preceding the
            match, ``after`` the matched text, and ``buffer`` the remainder.
        :raises ExpectTimeoutError: if no match was found in time; the
            unmatched buffer is logged and discarded.
        """
        self.match = None
        self.before = None
        self.after = None
        if escape:
            pattern = re.escape(pattern)
        pattern = re.compile(pattern)
        if timeout is None:
            timeout = self._timeout
        timeleft = timeout
        while True:
            # Monotonic clock: elapsed time must not jump with wall-clock changes.
            start_time = time.monotonic()
            if self.buffer is not None:
                self.match = pattern.search(self.buffer, 0)
                if self.match is not None:
                    self.after = self.match.group(0)
                    self.before = self.buffer[:self.match.start()]
                    self.buffer = self.buffer[self.match.end():]
                    break
            if timeleft < 0:
                break
            try:
                data = self.read(timeout=timeleft, raise_exception=True)
            except TimeoutError:
                raise self._expect_timeout(pattern, timeout)
            timeleft -= time.monotonic() - start_time
            if data:
                self.buffer = (self.buffer + data) if self.buffer else data
        if self.match is None:
            raise self._expect_timeout(pattern, timeout)
        if self._logger:
            self._logger.write(self.before + self.after)
            self._logger.flush()
        return self.match

    def read(self, timeout=0, raise_exception=False):
        """Return the next non-empty chunk of child output, or ''.

        :param timeout: seconds to wait for output; with 0, ``Queue.get``
            gives up immediately if nothing is queued.
        :param raise_exception: raise ``TimeoutError`` instead of returning
            '' when nothing arrived in time.
        """
        data = ''
        timeleft = timeout
        try:
            while timeleft >= 0:
                start_time = time.monotonic()
                data += self.queue.get(timeout=timeleft)
                if data:
                    break
                timeleft -= time.monotonic() - start_time
        except Empty:
            pass
        if not data and raise_exception:
            raise TimeoutError(timeout)
        return data
|
|
|
|
|
|
def spawn(command):
    """Start ``command`` on a new pseudo-terminal and return an ``IO`` handle.

    The child uses the pty slave as stdin, stdout and stderr and runs in a new
    session. A daemon thread running ``reader`` forwards everything the child
    writes to the pty master into a queue consumed by the returned ``IO``.

    :param command: passed unchanged to ``subprocess.Popen``.
    """
    master, slave = pty.openpty()
    try:
        # start_new_session=True calls setsid() in the child exactly like
        # preexec_fn=os.setsid did, but unlike preexec_fn it is safe while
        # other threads (the reader threads of earlier spawns) are running.
        process = Popen(command, start_new_session=True,
                        stdout=slave, stdin=slave, stderr=slave, bufsize=1)
    except BaseException:
        # Do not leak the master fd when the process cannot be started.
        os.close(master)
        raise
    finally:
        # The child has inherited the slave end (or was never started);
        # the parent does not need it either way.
        os.close(slave)

    queue = Queue()
    reader_kill_event = Event()
    thread = Thread(target=reader, args=(process, master, queue, reader_kill_event))
    # Daemon so a forgotten close() does not keep the interpreter alive.
    thread.daemon = True
    thread.start()

    return IO(process, master, queue, reader={'thread': thread, 'kill_event': reader_kill_event})
|
|
|
|
|
|
def reader(process, out, queue, kill_event, chunk_size=1 << 17):
    """Pump bytes from file descriptor ``out`` into ``queue`` as text.

    Runs until the end of output, or until a read fails after ``kill_event``
    has been set (``IO.close`` sets it before closing the fd).

    :param process: not used here; kept for interface compatibility.
    :param out: readable file descriptor (the pty master in ``spawn``).
    :param queue: receives decoded, non-empty ``str`` chunks.
    :param kill_event: once set, read errors end the loop quietly.
    :param chunk_size: maximum number of bytes requested per ``os.read``.
    :raises OSError: on an unexpected read error while not shutting down.
    """
    import codecs
    import errno

    # An incremental decoder keeps a multi-byte UTF-8 sequence that is split
    # across two reads intact instead of replacing each half with U+FFFD.
    decoder = codecs.getincrementaldecoder('utf-8')(errors='replace')
    while True:
        try:
            # TODO: there are some issues with 1<<16 buffer size
            chunk = os.read(out, chunk_size)
        except OSError as exc:
            if kill_event.is_set():
                break
            # On Linux, reading a pty master typically fails with EIO once all
            # slave descriptors are closed (e.g. the child exited); treat
            # that as end of output rather than an error.
            if exc.errno == errno.EIO:
                chunk = b''
            else:
                raise
        # An empty chunk means end of output: flush any incomplete trailing
        # sequence (as U+FFFD) and stop instead of spinning forever.
        text = decoder.decode(chunk, final=not chunk)
        if text:
            queue.put(text)
        if not chunk:
            break
|