ClickHouse/tests/queries/0_stateless/helpers/uexpect.py
2020-10-02 19:54:07 +03:00

207 lines
6.2 KiB
Python

# Copyright (c) 2019 Vitaliy Zakaznikov
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pty
import time
import sys
import re
from threading import Thread, Event
from subprocess import Popen
from queue import Queue, Empty
class TimeoutError(Exception):
def __init__(self, timeout):
self.timeout = timeout
def __str__(self):
return 'Timeout %.3fs' % float(self.timeout)
class ExpectTimeoutError(Exception):
def __init__(self, pattern, timeout, buffer):
self.pattern = pattern
self.timeout = timeout
self.buffer = buffer
def __str__(self):
s = 'Timeout %.3fs ' % float(self.timeout)
if self.pattern:
s += 'for %s ' % repr(self.pattern.pattern)
if self.buffer:
s += 'buffer %s' % repr(self.buffer[:])
#s += ' or \'%s\'' % ','.join(['%x' % ord(c) for c in self.buffer[:]])
return s
class IO(object):
class EOF(object):
pass
class Timeout(object):
pass
EOF = EOF
TIMEOUT = Timeout
class Logger(object):
def __init__(self, logger, prefix=''):
self._logger = logger
self._prefix = prefix
def write(self, data):
self._logger.write(('\n' + data).replace('\n','\n' + self._prefix))
def flush(self):
self._logger.flush()
def __init__(self, process, master, queue, reader):
self.process = process
self.master = master
self.queue = queue
self.buffer = None
self.before = None
self.after = None
self.match = None
self.pattern = None
self.reader = reader
self._timeout = None
self._logger = None
self._eol = ''
def __enter__(self):
return self
def __exit__(self, type, value, traceback):
self.close()
def logger(self, logger=None, prefix=''):
if logger:
self._logger = self.Logger(logger, prefix=prefix)
return self._logger
def timeout(self, timeout=None):
if timeout:
self._timeout = timeout
return self._timeout
def eol(self, eol=None):
if eol:
self._eol = eol
return self._eol
def close(self, force=True):
self.reader['kill_event'].set()
os.system('pkill -TERM -P %d' % self.process.pid)
if force:
self.process.kill()
else:
self.process.terminate()
os.close(self.master)
if self._logger:
self._logger.write('\n')
self._logger.flush()
def send(self, data, eol=None):
if eol is None:
eol = self._eol
return self.write(data + eol)
def write(self, data):
return os.write(self.master, data.encode())
def expect(self, pattern, timeout=None, escape=False):
self.match = None
self.before = None
self.after = None
if escape:
pattern = re.escape(pattern)
pattern = re.compile(pattern)
if timeout is None:
timeout = self._timeout
timeleft = timeout
while True:
start_time = time.time()
if self.buffer is not None:
self.match = pattern.search(self.buffer, 0)
if self.match is not None:
self.after = self.buffer[self.match.start():self.match.end()]
self.before = self.buffer[:self.match.start()]
self.buffer = self.buffer[self.match.end():]
break
if timeleft < 0:
break
try:
data = self.read(timeout=timeleft, raise_exception=True)
except TimeoutError:
if self._logger:
self._logger.write((self.buffer or '') + '\n')
self._logger.flush()
exception = ExpectTimeoutError(pattern, timeout, self.buffer)
self.buffer = None
raise exception
timeleft -= (time.time() - start_time)
if data:
self.buffer = (self.buffer + data) if self.buffer else data
if self._logger:
self._logger.write((self.before or '') + (self.after or ''))
self._logger.flush()
if self.match is None:
exception = ExpectTimeoutError(pattern, timeout, self.buffer)
self.buffer = None
raise exception
return self.match
def read(self, timeout=0, raise_exception=False):
data = ''
timeleft = timeout
try:
while timeleft >= 0 :
start_time = time.time()
data += self.queue.get(timeout=timeleft)
if data:
break
timeleft -= (time.time() - start_time)
except Empty:
if data:
return data
if raise_exception:
raise TimeoutError(timeout)
pass
if not data and raise_exception:
raise TimeoutError(timeout)
return data
def spawn(command):
master, slave = pty.openpty()
process = Popen(command, preexec_fn=os.setsid, stdout=slave, stdin=slave, stderr=slave, bufsize=1)
os.close(slave)
queue = Queue()
reader_kill_event = Event()
thread = Thread(target=reader, args=(process, master, queue, reader_kill_event))
thread.daemon = True
thread.start()
return IO(process, master, queue, reader={'thread':thread, 'kill_event':reader_kill_event})
def reader(process, out, queue, kill_event):
while True:
try:
data = os.read(out, 65536).decode(errors='replace')
queue.put(data)
except:
if kill_event.is_set():
break
raise