ClickHouse/tests/queries/0_stateless/helpers/uexpect.py
2024-09-27 10:19:49 +00:00

223 lines
6.3 KiB
Python

# Copyright (c) 2019 Vitaliy Zakaznikov
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pty
import re
import sys
import time
from queue import Empty, Queue
from subprocess import Popen
from threading import Event, Thread
class TimeoutError(Exception):
    """Raised by ``IO.read`` when no output arrives before the timeout expires.

    Note: this module-level class shadows the builtin ``TimeoutError``.
    """

    def __init__(self, timeout):
        # The timeout value (seconds) that elapsed without output.
        self.timeout = timeout

    def __str__(self):
        seconds = float(self.timeout)
        return "Timeout %.3fs" % (seconds,)
class ExpectTimeoutError(Exception):
    """Raised by ``IO.expect`` when the pattern is not seen before the timeout.

    Attributes:
        pattern: the compiled regular expression that was being searched for.
        timeout: the timeout (seconds) that was in effect.
        buffer: output received but not consumed when the wait gave up.
    """

    def __init__(self, pattern, timeout, buffer):
        self.pattern = pattern
        self.timeout = timeout
        self.buffer = buffer

    def __str__(self):
        # Assemble the message piecewise; pattern and buffer are optional.
        parts = ["Timeout %.3fs " % float(self.timeout)]
        if self.pattern:
            parts.append("for %s " % repr(self.pattern.pattern))
        if self.buffer:
            parts.append("buffer %s" % repr(self.buffer))
        return "".join(parts)
class IO(object):
    """Interactive handle on a child process attached to a pseudo-terminal.

    Instances are created by ``spawn``. Child output arrives as ``str`` chunks
    on ``queue`` (filled by a background reader thread); ``expect`` and
    ``read`` consume it. ``send``/``write`` write to the pty master fd.

    After a successful ``expect``:
      * ``match``  -- the ``re.Match`` object,
      * ``before`` -- buffered text preceding the match,
      * ``after``  -- the matched text,
      * ``buffer`` -- text received after the match, not yet consumed.
    """

    # Marker classes exposed as IO.EOF / IO.TIMEOUT. They are not referenced
    # elsewhere in this module (presumably pexpect-style sentinels for callers).
    class EOF(object):
        pass

    class Timeout(object):
        pass

    EOF = EOF
    TIMEOUT = Timeout

    class Logger(object):
        """File-like wrapper that prefixes every logged line with ``prefix``."""

        def __init__(self, logger, prefix=""):
            self._logger = logger
            self._prefix = prefix

        def write(self, data):
            # Each write starts on a fresh line and every line gets the prefix.
            self._logger.write(("\n" + data).replace("\n", "\n" + self._prefix))

        def flush(self):
            self._logger.flush()

    def __init__(self, process, master, queue, reader):
        """
        :param process: the ``subprocess.Popen`` of the child.
        :param master: pty master file descriptor.
        :param queue: queue of decoded output chunks.
        :param reader: dict with the reader ``"thread"`` and its ``"kill_event"``.
        """
        self.process = process
        self.master = master
        self.queue = queue
        self.buffer = None
        self.before = None
        self.after = None
        self.match = None
        self.pattern = None
        self.reader = reader
        self._timeout = None
        self._logger = None
        self._eol = ""
        # Makes close() idempotent: an explicit close() followed by __exit__
        # must not os.close() the master fd twice (the number may be reused).
        self._closed = False

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.close()

    def logger(self, logger=None, prefix=""):
        """Install a logger when *logger* is given; return the current one."""
        if logger:
            self._logger = self.Logger(logger, prefix=prefix)
        return self._logger

    def timeout(self, timeout=None):
        """Set the default expect timeout (seconds) unless None; return it.

        ``timeout(0)`` is honoured (previously it was silently ignored).
        """
        if timeout is not None:
            self._timeout = timeout
        return self._timeout

    def eol(self, eol=None):
        """Set the string appended by ``send`` unless None; return it."""
        if eol is not None:
            self._eol = eol
        return self._eol

    def close(self, force=True):
        """Stop the reader thread, kill the child and close the pty master.

        :param force: use ``Popen.kill`` when true, ``Popen.terminate`` otherwise.

        Calling close() again after the first call does nothing.
        """
        if self._closed:
            return
        self._closed = True
        # Set first so the reader thread treats the read error caused by
        # closing the master fd as a requested shutdown.
        self.reader["kill_event"].set()
        # Signal processes whose parent is the spawned child (pkill -P <pid>).
        os.system("pkill -TERM -P %d" % self.process.pid)
        if force:
            self.process.kill()
        else:
            self.process.terminate()
        os.close(self.master)
        if self._logger:
            self._logger.write("\n")
            self._logger.flush()

    def send(self, data, eol=None):
        """Write *data* followed by *eol* (default: the configured ``eol()``)."""
        if eol is None:
            eol = self._eol
        return self.write(data + eol)

    def write(self, data):
        """Encode *data* (UTF-8) and write it to the pty; return bytes written."""
        return os.write(self.master, data.encode())

    def expect(self, pattern, timeout=None, escape=False):
        """Wait until *pattern* matches the buffered child output.

        :param pattern: regular expression (string or compiled pattern).
        :param timeout: seconds to wait; defaults to ``self.timeout()``. If
            that is also None, wait indefinitely (previously this crashed with
            a TypeError comparing None to 0).
        :param escape: match *pattern* literally.
        :returns: the ``re.Match`` (also stored in ``self.match``).
        :raises ExpectTimeoutError: if no match is found in time; the
            unconsumed buffer is discarded.
        """
        self.match = None
        self.before = None
        self.after = None
        if escape:
            pattern = re.escape(pattern)
        pattern = re.compile(pattern)
        if timeout is None:
            timeout = self._timeout
        timeleft = timeout
        while True:
            start_time = time.monotonic()
            if self.buffer is not None:
                self.match = pattern.search(self.buffer, 0)
                if self.match is not None:
                    self.after = self.match.group(0)
                    self.before = self.buffer[: self.match.start()]
                    self.buffer = self.buffer[self.match.end() :]
                    break
            if timeleft is not None and timeleft < 0:
                break
            try:
                data = self.read(timeout=timeleft, raise_exception=True)
            except TimeoutError:
                if self._logger:
                    self._logger.write((self.buffer or "") + "\n")
                    self._logger.flush()
                exception = ExpectTimeoutError(pattern, timeout, self.buffer)
                self.buffer = None
                raise exception
            if timeleft is not None:
                # Monotonic clock: immune to wall-clock adjustments.
                timeleft -= time.monotonic() - start_time
            if data:
                self.buffer = (self.buffer + data) if self.buffer else data
        if self._logger:
            self._logger.write((self.before or "") + (self.after or ""))
            self._logger.flush()
        if self.match is None:
            exception = ExpectTimeoutError(pattern, timeout, self.buffer)
            self.buffer = None
            raise exception
        return self.match

    def read(self, timeout=0, raise_exception=False):
        """Return the next non-empty chunk of child output.

        :param timeout: seconds to wait; ``0`` polls without blocking and
            ``None`` waits indefinitely.
        :param raise_exception: raise ``TimeoutError`` instead of returning
            ``""`` when nothing arrives in time.
        """
        timeleft = timeout
        while timeleft is None or timeleft >= 0:
            start_time = time.monotonic()
            try:
                data = self.queue.get(timeout=timeleft)
            except Empty:
                break
            if data:
                return data
            # An empty chunk carries no output; keep waiting for the rest.
            if timeleft is not None:
                timeleft -= time.monotonic() - start_time
        if raise_exception:
            raise TimeoutError(timeout)
        return ""
def spawn(command):
    """Run *command* on a new pseudo-terminal and return an ``IO`` handle.

    The child's stdin, stdout and stderr all use the pty slave, and the child
    is started in a new session (setsid). A daemon thread running ``reader``
    forwards output from the pty master to the handle's queue.

    :param command: program and arguments, as accepted by ``subprocess.Popen``.
    """
    master, slave = pty.openpty()
    try:
        # start_new_session=True calls setsid() in the child, exactly like the
        # former preexec_fn=os.setsid. However, preexec_fn is documented as
        # unsafe when other threads exist, and reader threads from earlier
        # spawn() calls may still be running.
        process = Popen(
            command,
            start_new_session=True,
            stdout=slave,
            stdin=slave,
            stderr=slave,
            bufsize=1,
        )
    except BaseException:
        # Do not leak the master fd if the child could not be started.
        os.close(master)
        raise
    finally:
        # The parent no longer needs the slave end; the child has its own copy.
        os.close(slave)
    queue = Queue()
    reader_kill_event = Event()
    thread = Thread(target=reader, args=(process, master, queue, reader_kill_event))
    thread.daemon = True
    thread.start()
    return IO(
        process,
        master,
        queue,
        reader={"thread": thread, "kill_event": reader_kill_event},
    )
def reader(process, out, queue, kill_event):
    """Loop forwarding bytes read from fd *out* to *queue* as ``str`` chunks.

    Bytes are decoded as UTF-8 with an incremental decoder. A multi-byte
    character split across two ``os.read`` calls is therefore decoded
    correctly instead of becoming replacement characters; invalid bytes still
    become U+FFFD.

    Returns on end-of-file (``os.read`` returns ``b""``) instead of spinning
    on empty reads. It also returns on a read error once *kill_event* is set
    (``IO.close`` sets it before closing the fd). A read error without
    *kill_event* set is re-raised.

    *process* is not used here; it is kept for signature compatibility.
    """
    import codecs

    decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
    while True:
        try:
            chunk = os.read(out, 65536)
        except OSError:
            if kill_event.is_set():
                break
            raise
        if not chunk:
            # EOF: flush any incomplete trailing byte sequence, then stop.
            tail = decoder.decode(b"", final=True)
            if tail:
                queue.put(tail)
            break
        text = decoder.decode(chunk)
        # A chunk holding only part of a character decodes to "" -- skip it.
        if text:
            queue.put(text)