mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-12-11 00:42:29 +00:00
665 lines
21 KiB
Python
Executable File
665 lines
21 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
|
|
import logging
|
|
from enum import Enum
|
|
from functools import reduce
|
|
from hashlib import md5
|
|
from itertools import chain
|
|
|
|
# isort: off
|
|
# pylint:disable=import-error; for style check
|
|
import sqlglot
|
|
from sqlglot.expressions import ColumnDef, PrimaryKeyColumnConstraint
|
|
|
|
# pylint:enable=import-error; for style check
|
|
# isort: on
|
|
|
|
from exceptions import (
|
|
DataResultDiffer,
|
|
Error,
|
|
ErrorWithParent,
|
|
ProgramError,
|
|
QueryExecutionError,
|
|
)
|
|
|
|
# Module-level logger registered under the name "parser"; DEBUG is enabled on
# the logger itself (handler configuration is not visible in this file).
logger = logging.getLogger("parser")
logger.setLevel(logging.DEBUG)

# First tokens of conditional lines; check_conditions() looks these keys up to
# decide whether a record applies to a given DBMS.
CONDITION_SKIP = "skipif"
CONDITION_ONLY = "onlyif"
|
|
|
|
|
|
# TODO replace assertions with raise exception
|
|
class TestFileFormatException(Error):
    """Error for requests a test file's content cannot satisfy.

    Raised in this module when the column types of a query that is expected
    to fail are requested.
    """
|
|
|
|
|
|
class FileAndPos:
    """A file identifier paired with a position, rendered as ``file:pos``."""

    def __init__(self, file=None, pos=None):
        self.file, self.pos = file, pos

    def __str__(self):
        return "{}:{}".format(self.file, self.pos)
|
|
|
|
|
|
def check_conditions(conditions, dbms_name):
    """Tell whether a record guarded by ``conditions`` applies to ``dbms_name``.

    Args:
        conditions: iterable of two-item records ``(key, value)``, where
            ``key`` is ``CONDITION_SKIP`` or ``CONDITION_ONLY`` and ``value``
            is a DBMS name.
        dbms_name: name of the DBMS the record is about to run against.

    Returns:
        ``False`` when ``dbms_name`` is listed under ``CONDITION_SKIP``, or
        when ``CONDITION_ONLY`` records exist and none of them names
        ``dbms_name``; ``True`` otherwise.

    Raises:
        ValueError: if a record does not unpack into exactly two items.
    """
    rules = {}
    for key, val in conditions:
        # Group values per key. Membership has to be checked against
        # ``rules`` (not ``conditions``), otherwise every record would reset
        # the list and only the last value per key would be kept.
        rules.setdefault(key, []).append(val)

    if dbms_name in rules.get(CONDITION_SKIP, ()):
        return False
    if CONDITION_ONLY in rules and dbms_name not in rules[CONDITION_ONLY]:
        return False
    return True
|
|
|
|
|
|
class BlockType(Enum):
    """Kinds of blocks a test file is split into.

    Members and values: ``comments`` = 1, ``control`` = 2, ``statement`` = 3,
    ``query`` = 4.
    """

    comments, control, statement, query = range(1, 5)
|
|
|
|
|
|
# Characters that mark a token (and thus the line) as a comment.
COMMENT_TOKENS = ["#"]
# Token that separates a query's request from its expected result.
RESULT_SEPARATION_LINE = "----"
# First tokens recognised as control lines.
CONTROL_TOKENS = ["halt", "hash-threshold"]

# First tokens recognised as conditional lines.
CONDITIONS_TOKENS = [CONDITION_SKIP, CONDITION_ONLY]
# First token of a statement record and of a query record.
STATEMENT_TOKEN = "statement"
QUERY_TOKEN = "query"


# Python value type -> one-letter column type code.
ACCEPTABLE_TYPES = {str: "T", int: "I", float: "R"}
|
|
|
|
|
|
def _is_comment_line(tokens):
    """Check whether the first token starts with a comment character.

    Falsy ``tokens`` are returned unchanged.
    """
    if not tokens:
        return tokens
    first_char = tokens[0][0]
    return first_char in COMMENT_TOKENS
|
|
|
|
|
|
def _is_separation_line(tokens):
    """Check whether the first token is the result separator.

    Falsy ``tokens`` are returned unchanged.
    """
    if not tokens:
        return tokens
    return tokens[0] == RESULT_SEPARATION_LINE
|
|
|
|
|
|
def _is_control_line(tokens):
    """Check whether the first token is one of ``CONTROL_TOKENS``.

    Falsy ``tokens`` are returned unchanged.
    """
    if not tokens:
        return tokens
    return tokens[0] in CONTROL_TOKENS
|
|
|
|
|
|
def _is_conditional_line(tokens):
    """Check whether the first token is one of ``CONDITIONS_TOKENS``.

    Falsy ``tokens`` are returned unchanged.
    """
    if not tokens:
        return tokens
    return tokens[0] in CONDITIONS_TOKENS
|
|
|
|
|
|
def _is_statement_line(tokens):
    """Check whether the first token is ``STATEMENT_TOKEN``.

    Falsy ``tokens`` are returned unchanged.
    """
    if not tokens:
        return tokens
    return tokens[0] == STATEMENT_TOKEN
|
|
|
|
|
|
def _is_query_line(tokens):
    """Check whether the first token is ``QUERY_TOKEN``.

    Falsy ``tokens`` are returned unchanged.
    """
    if not tokens:
        return tokens
    return tokens[0] == QUERY_TOKEN
|
|
|
|
|
|
class FileBlockBase:
    """Base class for the blocks a test file is split into.

    A block covers lines ``start`` (inclusive) to ``end`` (exclusive) of the
    file held by ``parser``. ``parse_block`` is the factory that classifies
    such a range and returns an instance of the matching subclass.
    """

    def __init__(self, parser, start, end):
        self._parser, self._start, self._end = parser, start, end

    def get_block_type(self):
        """Return the block's ``BlockType``; the base class yields ``None``."""
        return None

    def get_pos(self):
        """Return the 1-based number of the block's first line."""
        return 1 + self._start

    @staticmethod
    def __parse_request(test_file, start, end):
        """Collect the request text that starts at line ``start``.

        Scanning stops at ``end``, at a line without tokens, or at the
        result separator, whichever comes first.

        Returns:
            ``(request, stop)``: the tokens of the scanned lines joined with
            single spaces, and the index of the line where scanning stopped.
        """
        stop = start
        while stop < end:
            line_tokens = test_file.get_tokens(stop)
            if not line_tokens or _is_separation_line(line_tokens):
                break
            stop += 1
        words = test_file.get_tokens_from_lines(start, stop)
        logger.debug("slice request %s:%s end %s", start, stop, end)
        return " ".join(words), stop

    @staticmethod
    def __parse_result(test_file, start, end):
        """Collect the result rows that start at line ``start``.

        Scanning stops at ``end`` or at the first line without tokens.

        Returns:
            ``(rows, stop)``: the per-line token lists of the scanned lines,
            and the index of the line where scanning stopped.
        """
        stop = start
        while stop < end and test_file.get_tokens(stop):
            stop += 1
        logger.debug("slice result %s:%s end %s", start, stop, end)
        return test_file.get_tokens(start, stop), stop

    @staticmethod
    def convert_request(sql):
        """Rewrite a request read as SQLite SQL into ClickHouse SQL.

        * ``CREATE TABLE ...`` is transpiled and gets an
          ``ENGINE = MergeTree() ORDER BY`` clause, ordered by the column
          carrying a primary-key constraint or by ``tuple()`` if none exists.
        * A request mentioning ``SELECT``, ``CAST`` and ``NULL`` has the
          first cast found wrapped into ``NULLABLE`` when it casts ``NULL``
          to a type without nested types.
        * Anything else, and requests sqlglot cannot parse in the second
          case, are returned unchanged.
        """
        if sql.startswith("CREATE TABLE"):
            transpiled = sqlglot.transpile(sql, read="sqlite", write="clickhouse")[0]
            pk_constraint = sqlglot.parse_one(transpiled, read="clickhouse").find(
                PrimaryKeyColumnConstraint
            )
            if pk_constraint is None:
                order_by = "tuple()"
            else:
                order_by = str(pk_constraint.find_ancestor(ColumnDef).args["this"])
            return transpiled + " ENGINE = MergeTree() ORDER BY " + order_by

        if not ("SELECT" in sql and "CAST" in sql and "NULL" in sql):
            return sql

        # convert `CAST (NULL as INTEGER)` to `CAST (NULL as Nullable(Int32))`
        try:
            ast = sqlglot.parse_one(sql, read="sqlite")
        except sqlglot.errors.ParseError as err:
            logger.info("cannot parse %s , error is %s", sql, err)
            return sql

        cast = ast.find(sqlglot.expressions.Cast)
        if cast is None or cast.name != "NULL":
            return sql
        if "nested" in cast.to.args and cast.to.args["nested"]:
            return sql

        cast.args["to"] = sqlglot.expressions.DataType.build(
            "NULLABLE", expressions=[cast.to]
        )
        return ast.sql("clickhouse")

    @staticmethod
    def parse_block(parser, start, end):
        """Classify lines ``start`` to ``end`` (exclusive) and build the block.

        Lines are inspected in order. Conditional lines are collected; a
        control, statement or query line fixes the block type. Statement and
        query requests are passed through ``convert_request`` when
        ``parser.dbms_name`` is ``"ClickHouse"``.

        Returns:
            A ``FileBlockComments``, ``FileBlockControl``,
            ``FileBlockStatement`` or ``FileBlockQuery`` instance.

        Raises:
            AssertionError: if the lines do not follow the expected layout.
            ValueError: if the resulting block type is not a known one.
        """
        file_pos = FileAndPos(parser.get_test_name(), start + 1)
        logger.debug("%s start %s end %s", file_pos, start, end)

        kind = BlockType.comments
        conditions, controls = [], []
        statement = query = result_line = None
        request, result = [], []

        cursor = start
        while cursor < end:
            tokens = parser.get_tokens(cursor)

            if _is_comment_line(tokens):
                pass
            elif _is_conditional_line(tokens):
                conditions.append(tokens)

            elif _is_control_line(tokens):
                assert kind in (BlockType.comments, BlockType.control)
                kind = BlockType.control
                controls.append(tokens)

            elif _is_statement_line(tokens):
                assert kind == BlockType.comments
                kind = BlockType.statement
                statement = tokens
                request, stop = FileBlockBase.__parse_request(parser, cursor + 1, end)
                if parser.dbms_name == "ClickHouse":
                    request = FileBlockBase.convert_request(request)
                # a statement's request must run to the end of the block
                assert stop == end
                cursor = stop

            elif _is_query_line(tokens):
                assert kind == BlockType.comments
                kind = BlockType.query
                query = tokens
                request, stop = FileBlockBase.__parse_request(parser, cursor + 1, end)
                if parser.dbms_name == "ClickHouse":
                    request = FileBlockBase.convert_request(request)
                result_line = cursor = stop
                if cursor == end:
                    # the query has no result section
                    break
                tokens = parser.get_tokens(cursor)
                assert _is_separation_line(tokens), f"last_line {stop}, end {end}"
                result, stop = FileBlockBase.__parse_result(parser, cursor + 1, end)
                assert stop == end
                cursor = stop
            cursor += 1

        if kind == BlockType.comments:
            return FileBlockComments(parser, start, end)

        if kind == BlockType.control:
            return FileBlockControl(parser, start, end, conditions, controls)

        if kind == BlockType.statement:
            return FileBlockStatement(
                parser, start, end, conditions, statement, request
            )

        if kind == BlockType.query:
            query_block = FileBlockQuery(
                parser, start, end, conditions, query, request, result_line
            )
            query_block.with_result(result)
            return query_block

        raise ValueError(f"Unknown block_type {kind}")

    def dump_to(self, output):
        """Write the block's lines and one trailing newline to ``output``.

        Does nothing when ``output`` is ``None``.
        """
        if output is None:
            return
        for text in map(self._parser.get_line, range(self._start, self._end)):
            output.write(text)
        output.write("\n")
|
|
|
|
|
|
class FileBlockComments(FileBlockBase):
    """Block in which no control, statement or query line was found."""

    def get_block_type(self):
        """Return ``BlockType.comments``."""
        kind = BlockType.comments
        return kind
|
|
|
|
|
|
class FileBlockControl(FileBlockBase):
    """Block holding control lines together with its conditional lines."""

    def __init__(self, parser, start, end, conditions, control):
        super().__init__(parser, start, end)
        self.conditions, self.control = conditions, control

    def get_block_type(self):
        """Return ``BlockType.control``."""
        kind = BlockType.control
        return kind

    def get_conditions(self):
        """Return the conditions given at construction."""
        return self.conditions
|
|
|
|
|
|
class FileBlockStatement(FileBlockBase):
    """Block holding a statement record.

    ``statement`` is the token list of the record's header line and
    ``request`` the text to execute.
    """

    def __init__(self, parser, start, end, conditions, statement, request):
        super().__init__(parser, start, end)
        self.conditions, self.statement, self.request = (
            conditions,
            statement,
            request,
        )

    def get_block_type(self):
        """Return ``BlockType.statement``."""
        kind = BlockType.statement
        return kind

    def get_request(self):
        """Return the request given at construction."""
        return self.request

    def get_conditions(self):
        """Return the conditions given at construction."""
        return self.conditions

    def get_statement(self):
        """Return the header-line tokens given at construction."""
        return self.statement

    def expected_error(self):
        """Return ``True`` when the second header token is ``"error"``."""
        outcome = self.statement[1]
        return outcome == "error"
|
|
|
|
|
|
class FileBlockQuery(FileBlockBase):
    """Block holding a query record and, optionally, its result.

    ``query`` is the token list of the record's header line: its second
    token is either the column type string or ``"error"``, its third token
    is read as the sort mode. ``result_line`` is the index of the line where
    the request ended.
    """

    def __init__(self, parser, start, end, conditions, query, request, result_line):
        super().__init__(parser, start, end)
        self.conditions, self.query, self.request = conditions, query, request
        self.result_line = result_line
        self.result = None

    def get_block_type(self):
        """Return ``BlockType.query``."""
        kind = BlockType.query
        return kind

    def get_request(self):
        """Return the request given at construction."""
        return self.request

    def get_conditions(self):
        """Return the conditions given at construction."""
        return self.conditions

    def get_query(self):
        """Return the header-line tokens given at construction."""
        return self.query

    def expected_error(self):
        """Return the expected error text, or ``None`` if none is expected.

        The text is the header tokens after ``"error"``, joined with spaces
        and lower-cased.
        """
        if self.query[1] != "error":
            return None
        return " ".join(self.query[2:]).lower()

    def get_types(self):
        """Return the column type string of the query.

        Raises:
            TestFileFormatException: if the query is expected to fail.
        """
        if self.query[1] != "error":
            return self.query[1]
        raise TestFileFormatException(
            "the query is expected to fail, there are no types"
        )

    def get_sort_mode(self):
        """Return the third header token."""
        return self.query[2]

    def get_result(self):
        """Return the stored result (``None`` until one is set)."""
        return self.result

    def with_result(self, result):
        """Store ``result`` on the block."""
        self.result = result

    def dump_to(self, output):
        """Write the record to ``output``; do nothing if it is ``None``.

        The lines up to ``result_line`` are copied from the parser. When a
        result is stored, the separator and one space-joined line per result
        row follow. A newline closes the block.
        """
        if output is None:
            return

        for index in range(self._start, self.result_line):
            output.write(self._parser.get_line(index))

        if self.result is not None:
            logger.debug("dump result %s", self.result)
            output.write("----\n")
            for row in self.result:
                output.write(" ".join(row) + "\n")

        output.write("\n")
|
|
|
|
|
|
class TestFileParser:
    """Reads a test file from a stream and yields its blocks.

    Lines are split on whitespace into tokens; everything from a stand-alone
    ``#`` token onwards is dropped from a line's tokens. Lines without any
    raw token delimit the blocks.
    """

    CONTROL_TOKENS = ["halt", "hash-threshold"]
    CONDITIONS_TOKENS = [CONDITION_SKIP, CONDITION_ONLY]
    STATEMENT_TOKEN = "statement"
    QUERY_TOKEN = "query"
    COMMENT_TOKEN = "#"

    DEFAULT_HASH_THRESHOLD = 8

    def __init__(self, stream, test_name, test_file, dbms_name):
        self._stream, self._test_name, self._test_file = stream, test_name, test_file
        self.dbms_name = dbms_name

        # filled by __load_file(), cleared by __unload_file()
        self._lines, self._raw_tokens = [], []
        self._tokens, self._empty_lines = [], []

    def get_test_name(self):
        """Return the test name given at construction."""
        return self._test_name

    def get_test_file(self):
        """Return the test file, falling back to the test name if unset."""
        if self._test_file is None:
            return self._test_name
        return self._test_file

    def get_line(self, line):
        """Return the raw text of the line with index ``line``."""
        return self._lines[line]

    def get_tokens(self, start, end=None):
        """Return the tokens of line ``start``, or of lines ``start:end``.

        Without ``end`` a single token list is returned, otherwise a list of
        token lists.
        """
        if end is None:
            return self._tokens[start]
        return self._tokens[start:end]

    def get_tokens_from_lines(self, start, end):
        """Return the tokens of lines ``start:end`` as one flat list."""
        return list(chain.from_iterable(self._tokens[start:end]))

    def __load_file(self):
        """Read the stream and build the per-line token tables."""
        self._lines = self._stream.readlines()

        self._raw_tokens = [text.split() for text in self._lines]
        assert len(self._lines) == len(self._raw_tokens)

        # cut each line's tokens at the first stand-alone comment token
        marker = self.COMMENT_TOKEN
        self._tokens = [
            words[: words.index(marker)] if marker in words else words
            for words in self._raw_tokens
        ]

        self._empty_lines = [
            index for index, words in enumerate(self._raw_tokens) if not words
        ]

        logger.debug(
            "Test file %s loaded rows %s, empty rows %s",
            self.get_test_file(),
            len(self._lines),
            len(self._empty_lines),
        )

    def __unload_file(self):
        """Drop the stream, the names and all loaded line data."""
        self._test_file = self._test_name = self._stream = None
        self._lines, self._raw_tokens = [], []
        self._tokens, self._empty_lines = [], []

    def _iterate_blocks(self):
        """Yield one parsed block per run of lines between empty lines."""
        block_start = 0
        for empty_index in self._empty_lines:
            if empty_index != block_start:
                yield FileBlockBase.parse_block(self, block_start, empty_index)
            block_start = empty_index + 1

        total = len(self._lines)
        if block_start != total:
            yield FileBlockBase.parse_block(self, block_start, total)

    def test_blocks(self):
        """Load the file and yield its blocks; unload when iteration ends."""
        try:
            self.__load_file()
            yield from self._iterate_blocks()
        finally:
            self.__unload_file()
|
|
|
|
|
|
class QueryResult:
    """Outcome of a query: rows, a hash of the rows, or an exception text.

    Attributes:
        rows: list of rows, each a list of string values, or ``None``.
        values_count: number of values over all rows, or ``None``.
        data_hash: MD5 hex digest of the concatenated values, or ``None``;
            computed from ``rows`` on construction when rows are given.
        exception: exception text, or ``None``.
        hash_threshold: when non-zero, iteration yields the hash summary
            instead of the rows once ``values_count`` exceeds it.
    """

    def __init__(
        self,
        rows=None,
        values_count=None,
        data_hash=None,
        exception=None,
        hash_threshold=0,
    ):
        self.rows = rows
        self.values_count = values_count
        self.data_hash = data_hash
        self.exception = exception
        self.hash_threshold = hash_threshold
        self.hash_it()
        logger.debug("created QueryResult %s", str(self))

    def __str__(self):
        """Render the truthy attributes as ``QueryResult(name: value, ...)``."""
        parts = []
        if self.rows:
            parts.append(f"rows: {self.rows}")
        if self.values_count:
            parts.append(f"values_count: {self.values_count}")
        if self.data_hash:
            parts.append(f"data_hash: {self.data_hash}")
        if self.exception:
            parts.append(f"exception: {self.exception}")
        if self.hash_threshold:
            parts.append(f"hash_threshold: {self.hash_threshold}")
        return f"QueryResult({', '.join(parts)})"

    def __iter__(self):
        """Iterate over the rows in the form written to a test file.

        Yields the rows themselves when they are present and either
        ``hash_threshold`` is 0 or ``values_count`` does not exceed it;
        otherwise a single ``"N values hashing to HASH"`` row when a hash is
        known; otherwise a single ``"exception: ..."`` row.

        Raises:
            ProgramError: if there are neither rows, a hash nor an exception.
        """
        if self.rows is not None:
            if self.hash_threshold == 0:
                return iter(self.rows)
            if self.values_count <= self.hash_threshold:
                return iter(self.rows)
        if self.data_hash is not None:
            return iter([[f"{self.values_count} values hashing to {self.data_hash}"]])
        if self.exception is not None:
            return iter([[f"exception: {self.exception}"]])
        raise ProgramError("Query result is empty", details=str(self))

    @staticmethod
    def __value_count(rows):
        """Return the total number of values over all rows."""
        return sum(len(row) for row in rows)

    @staticmethod
    def parse_it(rows, hash_threshold):
        """Build a result from rows of tokens read from a test file.

        A single row starting with ``exception:`` becomes an exception
        result; a single five-token row ``N values hashing to HASH`` becomes
        a hash-only result; anything else is kept as data rows.
        """
        logger.debug("parse result len: %s rows: %s", len(rows), rows)
        if len(rows) == 1:
            only_row = rows[0]
            logger.debug("one row is %s", rows)
            if len(only_row) > 0 and only_row[0] == "exception:":
                # Use the module logger: logging.debug() targets the root
                # logger and may implicitly call basicConfig().
                logger.debug("as exception")
                message = " ".join(only_row[1:])
                return QueryResult(exception=message)
            if len(only_row) == 5 and " ".join(only_row[1:4]) == "values hashing to":
                logger.debug("as hashed data")
                values_count = int(only_row[0])
                data_hash = only_row[4]
                return QueryResult(data_hash=data_hash, values_count=values_count)
        logger.debug("as data")
        values_count = QueryResult.__value_count(rows)
        return QueryResult(
            rows=rows, values_count=values_count, hash_threshold=hash_threshold
        )

    @staticmethod
    def __result_as_strings(rows, types):
        """Convert each value to text according to its column type letter.

        ``None`` becomes ``NULL``. ``T`` uses ``str`` with ``(empty)`` for
        the empty string, ``I`` uses ``int`` (``0`` when the value is not an
        integer), ``R`` formats with three decimals. Values are paired with
        ``types`` positionally via ``zip``.

        Raises:
            QueryExecutionError: if an ``I`` value overflows ``int``.
        """
        res = []
        for row in rows:
            res_row = []
            for c, t in zip(row, types):
                logger.debug("Builging row. c:%s t:%s", c, t)
                if c is None:
                    res_row.append("NULL")
                    continue

                if t == "T":
                    res_row.append("(empty)" if c == "" else str(c))
                elif t == "I":
                    try:
                        res_row.append(str(int(c)))
                    except ValueError:
                        # non-integer values are deliberately rendered as 0
                        res_row.append(str(int(0)))
                    except OverflowError as ex:
                        raise QueryExecutionError(
                            f"Got overflowed result '{c}' for I type."
                        ) from ex
                elif t == "R":
                    res_row.append(f"{c:.3f}")

            res.append(res_row)
        return res

    @staticmethod
    def __sort_result(rows, sort_mode):
        """Order ``rows`` according to ``sort_mode``.

        ``nosort`` returns the rows as they are, ``rowsort`` sorts the rows,
        ``valuesort`` returns all values sorted in a single row (or no rows
        when there are no values). Any other mode returns an empty list.
        """
        if sort_mode == "nosort":
            return rows
        if sort_mode == "rowsort":
            return sorted(rows)
        if sort_mode == "valuesort":
            values = sorted(chain.from_iterable(rows))
            return [values] if values else []
        return []

    @staticmethod
    def __calculate_hash(rows):
        """Return the MD5 hex digest of all values, ASCII-encoded, in order."""
        md5_hash = md5()
        for value in chain.from_iterable(rows):
            md5_hash.update(value.encode("ascii"))
        return str(md5_hash.hexdigest())

    @staticmethod
    def make_it(rows, types, sort_mode, hash_threshold):
        """Build a result from raw rows returned by a query.

        Values are counted on the raw rows, then converted to text by
        ``types`` and ordered by ``sort_mode``.
        """
        values_count = QueryResult.__value_count(rows)
        as_string = QueryResult.__result_as_strings(rows, types)
        as_sorted = QueryResult.__sort_result(as_string, sort_mode)
        return QueryResult(
            rows=as_sorted, values_count=values_count, hash_threshold=hash_threshold
        )

    def hash_it(self):
        """Compute ``data_hash`` from ``rows`` if it is missing; return self."""
        if self.rows is not None and self.data_hash is None:
            self.data_hash = QueryResult.__calculate_hash(self.rows)
        return self

    @staticmethod
    def as_exception(e):
        """Build an exception result from the exception ``e``."""
        # do not print details to the test file
        # but print original exception
        if isinstance(e, ErrorWithParent):
            message = f"{e}, original is: {e.get_parent()}"
        else:
            message = str(e)

        return QueryResult(exception=message)

    @staticmethod
    def assert_eq(canonic, actual):
        """Check that ``actual`` matches the expected result ``canonic``.

        Returns ``None`` when the results agree.

        Raises:
            ProgramError: if an argument is not a ``QueryResult``, if
                ``actual`` lacks a hash while ``canonic`` has one, or if the
                results cannot be compared.
            DataResultDiffer: if exceptions, value counts, hashes or rows
                differ.
        """
        if not isinstance(canonic, QueryResult):
            raise ProgramError("NotImplemented")

        if not isinstance(actual, QueryResult):
            raise ProgramError("NotImplemented")

        if canonic.exception is not None or actual.exception is not None:
            if canonic.exception is None:
                raise DataResultDiffer(
                    "actual result has exception and canonic result doesn't",
                    details=f"actual: {actual.exception}",
                )
            if actual.exception is None:
                raise DataResultDiffer(
                    "canonic result has exception and actual result doesn't",
                    details=f"canonic: {canonic.exception}",
                )
            if canonic.exception != actual.exception:
                raise DataResultDiffer(
                    "canonic and actual results have different exceptions",
                    details=f"canonic: {canonic.exception}, actual: {actual.exception}",
                )
            # exceptions are the same
            return

        canonic.hash_it()
        actual.hash_it()

        if canonic.data_hash is not None:
            if actual.data_hash is None:
                raise ProgramError("actual result has to have hash for data")
            if canonic.values_count != actual.values_count:
                raise DataResultDiffer(
                    "canonic and actual results have different value count",
                    details=f"canonic values count {canonic.values_count}, "
                    f"actual {actual.values_count}",
                )
            if canonic.data_hash != actual.data_hash:
                raise DataResultDiffer(
                    "canonic and actual results have different hashes"
                )
            return

        # NOTE(review): hash_it() above always sets data_hash when rows are
        # present, so this row-by-row comparison is currently never reached.
        if canonic.rows is not None and actual.rows is not None:
            if canonic.values_count != actual.values_count:
                raise DataResultDiffer(
                    "canonic and actual results have different value count",
                    details=f"canonic values count {canonic.values_count}, "
                    f"actual {actual.values_count}",
                )
            if canonic.rows != actual.rows:
                raise DataResultDiffer(
                    "canonic and actual results have different values"
                )
            return

        raise ProgramError(
            "Unable to compare results",
            details=f"actual {actual}, canonic {canonic}",
        )
|