Fix linter issues in sqllogic module

Mikhail f. Shiryaev 2024-02-27 17:39:40 +01:00
parent 711da9505e
commit 770d710474
5 changed files with 124 additions and 150 deletions

connection.py

@@ -1,18 +1,15 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import datetime
import logging
import pyodbc
import sqlite3
import traceback
import enum
import logging
import random
import sqlite3
import string
from contextlib import contextmanager
import pyodbc # pylint:disable=import-error; for style check
from exceptions import ProgramError
logger = logging.getLogger("connection")
logger.setLevel(logging.DEBUG)
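The hunk above settles the imports into the layout isort produces: an alphabetized standard-library block, then third-party packages (pyodbc, with an inline pylint suppression because the driver is absent from the style-check environment), then first-party modules. A minimal sketch of that grouping, stdlib-only and with illustrative names so it runs anywhere; with isort's defaults, plain imports precede from-imports inside each block:

import enum
import json
import logging
from contextlib import contextmanager

@contextmanager
def scope(name):
    yield name

with scope("demo") as s:
    print(json.dumps({"scope": s, "debug": logging.DEBUG, "enum": enum.Enum.__name__}))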
@@ -22,9 +19,7 @@ class OdbcConnectingArgs:
self._kwargs = kwargs
def __str__(self):
conn_str = ";".join(
["{}={}".format(x, y) for x, y in self._kwargs.items() if y]
)
conn_str = ";".join([f"{x}={y}" for x, y in self._kwargs.items() if y])
return conn_str
def update_database(self, database):
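The ";".join rewrite is one of many .format() to f-string conversions in this commit (pylint's consider-using-f-string hint). The two spellings are equivalent, but the f-string is compiled inline and skips a method call; values here are illustrative:

x, y = "host", "localhost"
print("{}={}".format(x, y))  # the old spelling
print(f"{x}={y}")            # same output, interpolation compiled inline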
@@ -49,6 +44,7 @@ class OdbcConnectingArgs:
for kv in conn_str.split(";"):
if kv:
k, v = kv.split("=", 1)
# pylint:disable-next=protected-access
args._kwargs[k] = v
return args
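The # pylint:disable-next=protected-access comment added here scopes the suppression to exactly the next line, unlike a bare disable comment that stays active until re-enabled. The same pattern in a self-contained sketch (class and function names are illustrative):

class Box:
    def __init__(self):
        self._items = {}

    def __len__(self):
        return len(self._items)

def box_from_pairs(pairs):
    box = Box()
    for key, value in pairs:
        # pylint:disable-next=protected-access
        box._items[key] = value  # only this line is exempt from the check
    return box

print(len(box_from_pairs([("a", 1), ("b", 2)])))  # 2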
@@ -82,7 +78,7 @@ class KnownDBMS(str, enum.Enum):
clickhouse = "ClickHouse"
class ConnectionWrap(object):
class ConnectionWrap:
def __init__(self, connection=None, factory=None, factory_kwargs=None):
self._factory = factory
self._factory_kwargs = factory_kwargs
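Dropping the explicit object base fixes pylint's useless-object-inheritance check: every Python 3 class is new-style, so the two spellings are identical, as a quick check confirms:

class WithBase(object):  # the form pylint flags
    pass

class WithoutBase:  # equivalent in Python 3
    pass

print(WithBase.__mro__[-1] is object, WithoutBase.__mro__[-1] is object)  # True True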
@@ -126,7 +122,7 @@ class ConnectionWrap(object):
f"SELECT name FROM system.tables WHERE database='{self.DATABASE_NAME}'"
)
elif self.DBMS_NAME == KnownDBMS.sqlite.value:
list_query = f"SELECT name FROM sqlite_master WHERE type='table'"
list_query = "SELECT name FROM sqlite_master WHERE type='table'"
else:
logger.warning(
"unable to drop all tables for unknown database: %s", self.DBMS_NAME
@@ -154,7 +150,7 @@ class ConnectionWrap(object):
self._use_database(database)
logger.info(
"currentDatabase : %s",
execute_request(f"SELECT currentDatabase()", self).get_result(),
execute_request("SELECT currentDatabase()", self).get_result(),
)
@contextmanager
@@ -174,7 +170,7 @@ class ConnectionWrap(object):
def __exit__(self, *args):
if hasattr(self._connection, "close"):
return self._connection.close()
self._connection.close()
def setup_connection(engine, conn_str=None, make_debug_request=True):
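Dropping the return in __exit__ is more than style. A truthy value returned from __exit__ tells Python to suppress the in-flight exception, so forwarding an arbitrary close() result could silently swallow errors; the bare call leaves __exit__ returning None, and exceptions propagate. A runnable sketch of the hazard (class name is illustrative):

class Swallows:
    def __enter__(self):
        return self

    def __exit__(self, *args):
        return True  # truthy: Python discards the in-flight exception

with Swallows():
    raise RuntimeError("never seen")
print("the RuntimeError above was silently suppressed")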
@@ -263,7 +259,7 @@ class ExecResult:
def assert_no_exception(self):
if self.has_exception():
raise ProgramError(
f"request doesn't have a result set, it has the exception",
"request doesn't have a result set, it has the exception",
parent=self._exception,
)

exceptions.py

@@ -1,8 +1,6 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from enum import Enum
class Error(Exception):
def __init__(
@@ -45,16 +43,8 @@ class Error(Exception):
@property
def reason(self):
return ", ".join(
(
str(x)
for x in [
super().__str__(),
"details: {}".format(self._details) if self._details else "",
]
if x
)
)
details = f"details: {self._details}" if self._details else ""
return ", ".join((str(x) for x in [super().__str__(), details] if x))
def set_details(self, file=None, name=None, pos=None, request=None, details=None):
if file is not None:
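The rewritten reason property keeps the same idiom, build each optional fragment and join only the non-empty ones, just without the nested generator. The idea in isolation (function name is illustrative; filter(None, parts) would work equally well):

def build_reason(base, details=""):
    parts = [base, f"details: {details}" if details else ""]
    return ", ".join(p for p in parts if p)

print(build_reason("request failed"))            # request failed
print(build_reason("request failed", "row 12"))  # request failed, details: row 12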
@@ -88,16 +78,8 @@ class ErrorWithParent(Error):
@property
def reason(self):
return ", ".join(
(
str(x)
for x in [
super().reason,
"exception: {}".format(str(self._parent)) if self._parent else "",
]
if x
)
)
exception = f"exception: {self._parent}" if self._parent else ""
return ", ".join((str(x) for x in [super().reason, exception] if x))
class ProgramError(ErrorWithParent):

runner.py

@@ -2,20 +2,25 @@
# -*- coding: utf-8 -*-
import argparse
import enum
import os
import logging
import csv
import enum
import json
import logging
import multiprocessing
import os
from functools import reduce
from deepdiff import DeepDiff
from connection import setup_connection, Engines, default_clickhouse_odbc_conn_str
from test_runner import TestRunner, Status, RequestType
# isort: off
from deepdiff import DeepDiff # pylint:disable=import-error; for style check
# isort: on
LEVEL_NAMES = [x.lower() for x in logging._nameToLevel.keys() if x != logging.NOTSET]
from connection import Engines, default_clickhouse_odbc_conn_str, setup_connection
from test_runner import RequestType, Status, TestRunner
LEVEL_NAMES = [ # pylint:disable-next=protected-access
l.lower() for l, n in logging._nameToLevel.items() if n != logging.NOTSET
]
def setup_logger(args):
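The LEVEL_NAMES change is a real bug fix, not just formatting: the old comprehension compared each level name (a string such as "NOTSET") against logging.NOTSET (the integer 0), which can never match, so "notset" slipped into the list. The new code compares the numeric value instead. Demonstrated:

import logging

old = [x.lower() for x in logging._nameToLevel if x != logging.NOTSET]
new = [l.lower() for l, n in logging._nameToLevel.items() if n != logging.NOTSET]
print("notset" in old, "notset" in new)  # True False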
@@ -41,7 +46,7 @@ def __write_check_status(status_row, out_dir):
if len(status_row) > 140:
status_row = status_row[0:135] + "..."
check_status_path = os.path.join(out_dir, "check_status.tsv")
with open(check_status_path, "a") as stream:
with open(check_status_path, "a", encoding="utf-8") as stream:
writer = csv.writer(stream, delimiter="\t", lineterminator="\n")
writer.writerow(status_row)
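The added encoding="utf-8" answers pylint's unspecified-encoding warning: without it, open() falls back to the locale's preferred encoding, so the same code can read or write different bytes on different machines. The pattern on its own (file name as in the hunk above):

import locale

print("implicit default here:", locale.getpreferredencoding(False))
with open("check_status.tsv", "a", encoding="utf-8") as stream:  # deterministic everywhere
    stream.write("success\tOK\n")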
@@ -60,7 +65,7 @@ def __write_test_result(
):
all_stages = reports.keys()
test_results_path = os.path.join(out_dir, "test_results.tsv")
with open(test_results_path, "a") as stream:
with open(test_results_path, "a", encoding="utf-8") as stream:
writer = csv.writer(stream, delimiter="\t", lineterminator="\n")
for stage in all_stages:
report = reports[stage]
@@ -182,7 +187,7 @@ def mode_check_statements(parser):
input_dir, f"check statements:: not a dir {input_dir}"
)
reports = dict()
reports = {}
out_stages_dir = os.path.join(out_dir, f"{args.mode}-stages")
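Replacing dict() with the {} literal follows pylint's use-dict-literal hint: the literal compiles to a single bytecode operation, while dict() costs a global name lookup plus a call and can be shadowed. The gap is easy to measure:

import timeit

print(timeit.timeit("{}"))      # dict literal
print(timeit.timeit("dict()"))  # name lookup + call, consistently slower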
@@ -242,7 +247,7 @@ def mode_check_complete(parser):
input_dir, f"check statements:: not a dir {input_dir}"
)
reports = dict()
reports = {}
out_stages_dir = os.path.join(out_dir, f"{args.mode}-stages")
@@ -286,9 +291,9 @@ def make_actual_report(reports):
return {stage: report.get_map() for stage, report in reports.items()}
def write_actual_report(actial, out_dir):
with open(os.path.join(out_dir, "actual_report.json"), "w") as f:
f.write(json.dumps(actial))
def write_actual_report(actual, out_dir):
with open(os.path.join(out_dir, "actual_report.json"), "w", encoding="utf-8") as f:
f.write(json.dumps(actual))
def read_canonic_report(input_dir):
@@ -296,13 +301,15 @@ def read_canonic_report(input_dir):
if not os.path.exists(file):
return {}
with open(os.path.join(input_dir, "canonic_report.json"), "r") as f:
with open(
os.path.join(input_dir, "canonic_report.json"), "r", encoding="utf-8"
) as f:
data = f.read()
return json.loads(data)
def write_canonic_report(canonic, out_dir):
with open(os.path.join(out_dir, "canonic_report.json"), "w") as f:
with open(os.path.join(out_dir, "canonic_report.json"), "w", encoding="utf-8") as f:
f.write(json.dumps(canonic))
@@ -370,7 +377,7 @@ def mode_self_test(parser):
if not os.path.isdir(out_dir):
raise NotADirectoryError(out_dir, f"self test: not a dir {out_dir}")
reports = dict()
reports = {}
out_stages_dir = os.path.join(out_dir, f"{args.mode}-stages")

test_parser.py

@@ -2,24 +2,27 @@
# -*- coding: utf-8 -*-
import logging
import os
from itertools import chain
from enum import Enum
from hashlib import md5
from functools import reduce
from hashlib import md5
from itertools import chain
# isort: off
# pylint:disable=import-error; for style check
import sqlglot
from sqlglot.expressions import PrimaryKeyColumnConstraint, ColumnDef
from sqlglot.expressions import ColumnDef, PrimaryKeyColumnConstraint
# pylint:enable=import-error; for style check
# isort: on
from exceptions import (
Error,
ProgramError,
ErrorWithParent,
DataResultDiffer,
Error,
ErrorWithParent,
ProgramError,
QueryExecutionError,
)
logger = logging.getLogger("parser")
logger.setLevel(logging.DEBUG)
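The # isort: off / # isort: on pair is an isort action comment: the fenced region is left untouched when imports are re-sorted, which here pins the third-party sqlglot imports (and their pylint disables) in place. A stdlib-only sketch of the mechanism:

import json

# isort: off
import sys  # would normally sort before json; the fence keeps it here
# isort: on

print(json.dumps({"argv0": sys.argv[0]}))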
@@ -248,6 +251,7 @@ class FileBlockBase:
)
block.with_result(result)
return block
raise ValueError(f"Unknown block_type {block_type}")
def dump_to(self, output):
if output is None:
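The added raise closes an implicit fall-through: when some branches return a block and the end of the function stays reachable, Python returns None silently and pylint flags inconsistent-return-statements. A simplified stand-in for the factory shows the shape:

def make_block(block_type):
    if block_type == "comments":
        return "comments block"
    if block_type == "control":
        return "control block"
    raise ValueError(f"Unknown block_type {block_type}")  # no silent None

print(make_block("control"))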
@@ -258,9 +262,6 @@ class FileBlockBase:
class FileBlockComments(FileBlockBase):
def __init__(self, parser, start, end):
super().__init__(parser, start, end)
def get_block_type(self):
return BlockType.comments
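The deleted __init__ only forwarded its arguments to super().__init__, which pylint reports as useless parent delegation; without it, the parent constructor is inherited unchanged. In miniature (names are illustrative):

class Base:
    def __init__(self, start, end):
        self.start, self.end = start, end

class Child(Base):  # no __init__ needed; Base's is inherited as-is
    pass

print(Child(1, 2).end)  # 2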
@@ -469,20 +470,18 @@ class QueryResult:
(
str(x)
for x in [
"rows: {}".format(self.rows) if self.rows else "",
"values_count: {}".format(self.values_count)
if self.values_count
else "",
"data_hash: {}".format(self.data_hash) if self.data_hash else "",
"exception: {}".format(self.exception) if self.exception else "",
"hash_threshold: {}".format(self.hash_threshold)
f"rows: {self.rows}" if self.rows else "",
f"values_count: {self.values_count}" if self.values_count else "",
f"data_hash: {self.data_hash}" if self.data_hash else "",
f"exception: {self.exception}" if self.exception else "",
f"hash_threshold: {self.hash_threshold}"
if self.hash_threshold
else "",
]
if x
)
)
return "QueryResult({})".format(params)
return f"QueryResult({params})"
def __iter__(self):
if self.rows is not None:
@@ -491,12 +490,10 @@ class QueryResult:
if self.values_count <= self.hash_threshold:
return iter(self.rows)
if self.data_hash is not None:
return iter(
[["{} values hashing to {}".format(self.values_count, self.data_hash)]]
)
return iter([[f"{self.values_count} values hashing to {self.data_hash}"]])
if self.exception is not None:
return iter([["exception: {}".format(self.exception)]])
raise ProgramError("Query result is empty", details="{}".format(self.__str__()))
return iter([[f"exception: {self.exception}"]])
raise ProgramError("Query result is empty", details=str(self))
@staticmethod
def __value_count(rows):
@@ -528,7 +525,7 @@ class QueryResult:
for row in rows:
res_row = []
for c, t in zip(row, types):
logger.debug(f"Builging row. c:{c} t:{t}")
logger.debug("Builging row. c:%s t:%s", c, t)
if c is None:
res_row.append("NULL")
continue
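Switching logger.debug(f"...") to %-style arguments is pylint's logging-fstring-interpolation fix: an f-string is rendered before the logging call regardless of level, while lazy arguments are only formatted if the record is actually emitted. Demonstrated:

import logging

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("demo")
rows = list(range(1000))
log.debug(f"built {rows}")   # interpolates the whole list, then discards it
log.debug("built %s", rows)  # dropped before any formatting happens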
@@ -541,7 +538,7 @@ class QueryResult:
elif t == "I":
try:
res_row.append(str(int(c)))
except ValueError as ex:
except ValueError:
# raise QueryExecutionError(
# f"Got non-integer result '{c}' for I type."
# )
@@ -549,7 +546,7 @@ class QueryResult:
except OverflowError as ex:
raise QueryExecutionError(
f"Got overflowed result '{c}' for I type."
)
) from ex
elif t == "R":
res_row.append(f"{c:.3f}")
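The appended from ex satisfies pylint's raise-missing-from: chaining stores the original error on __cause__, so the traceback shows the low-level failure alongside the higher-level one. Self-contained (function name is illustrative; int() of infinity really does raise OverflowError):

def to_int(value):
    try:
        return int(value)
    except OverflowError as ex:
        raise ValueError(f"Got overflowed result '{value}'") from ex

try:
    to_int(float("inf"))
except ValueError as err:
    print("cause:", repr(err.__cause__))  # the original OverflowError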
@@ -567,6 +564,7 @@ class QueryResult:
values = list(chain(*rows))
values.sort()
return [values] if values else []
return []
@staticmethod
def __calculate_hash(rows):
@ -595,9 +593,9 @@ class QueryResult:
# do not print details to the test file
# but print original exception
if isinstance(e, ErrorWithParent):
message = "{}, original is: {}".format(e, e.get_parent())
message = f"{e}, original is: {e.get_parent()}"
else:
message = "{}".format(e)
message = str(e)
return QueryResult(exception=message)
@@ -616,7 +614,6 @@ class QueryResult:
"canonic and actual results have different exceptions",
details=f"canonic: {canonic.exception}, actual: {actual.exception}",
)
else:
# exceptions are the same
return
elif canonic.exception is not None:
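Removing else after a branch that raises or returns is pylint's no-else-return / no-else-raise cleanup: once the first branch has left the function, the else contributes nothing but indentation. The pattern in miniature:

def check(canonic, actual):
    if canonic != actual:
        raise AssertionError("results differ")
    # no else needed: reaching this line already means they matched
    return "results match"

print(check(3, 3))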
@@ -639,9 +636,8 @@ class QueryResult:
if canonic.values_count != actual.values_count:
raise DataResultDiffer(
"canonic and actual results have different value count",
details="canonic values count {}, actual {}".format(
canonic.values_count, actual.values_count
),
details=f"canonic values count {canonic.values_count}, "
f"actual {actual.values_count}",
)
if canonic.data_hash != actual.data_hash:
raise DataResultDiffer(
@@ -653,9 +649,8 @@ class QueryResult:
if canonic.values_count != actual.values_count:
raise DataResultDiffer(
"canonic and actual results have different value count",
details="canonic values count {}, actual {}".format(
canonic.values_count, actual.values_count
),
details=f"canonic values count {canonic.values_count}, "
f"actual {actual.values_count}",
)
if canonic.rows != actual.rows:
raise DataResultDiffer(
@@ -665,5 +660,5 @@ class QueryResult:
raise ProgramError(
"Unable to compare results",
details="actual {}, canonic {}".format(actual, canonic),
details=f"actual {actual}, canonic {canonic}",
)

test_runner.py

@@ -1,25 +1,23 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import enum
import logging
import os
import traceback
import io
import json
import logging
import os
import test_parser
from connection import execute_request
from exceptions import (
DataResultDiffer,
Error,
ProgramError,
DataResultDiffer,
StatementExecutionError,
StatementSuccess,
QueryExecutionError,
QuerySuccess,
SchemeResultDiffer,
StatementExecutionError,
StatementSuccess,
)
from connection import execute_request
logger = logging.getLogger("parser")
logger.setLevel(logging.DEBUG)
@@ -55,6 +53,7 @@ class Status(str, enum.Enum):
class TestStatus:
def __init__(self):
self.name = None
self.status = None
self.file = None
self.position = None
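The added self.name = None pre-declares an attribute that was previously first assigned outside the constructor, which pylint reports as attribute-defined-outside-init. Declaring every attribute in __init__ keeps the instance shape predictable; sketched with a stand-in class:

class StatusSketch:
    def __init__(self):
        self.name = None  # declared up front, assigned for real later

    def set_name(self, name):
        self.name = name

status = StatusSketch()
status.set_name("test/select5.test")
print(status.name)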
@@ -155,7 +154,7 @@ class SimpleStats:
self.success += 1
def get_map(self):
result = dict()
result = {}
result["success"] = self.success
result["fail"] = self.fail
return result
@@ -187,7 +186,7 @@ class Stats:
choose.update(status)
def get_map(self):
result = dict()
result = {}
result["statements"] = self.statements.get_map()
result["queries"] = self.queries.get_map()
result["total"] = self.total.get_map()
@@ -205,7 +204,7 @@ class OneReport:
self.test_name = test_name
self.test_file = test_file
self.stats = Stats()
self.requests = dict() # type: dict(int, TestStatus)
self.requests = {}
def update(self, status):
if not isinstance(status, TestStatus):
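Besides swapping dict() for {}, the hunk above drops the # type: dict(int, TestStatus) comment, which was never valid type-comment syntax (that would be Dict[int, TestStatus]). If the hint is worth keeping, a modern inline variable annotation does the job; sketched with stand-in classes:

class TestStatus:
    pass

class OneReportSketch:
    def __init__(self):
        self.requests: dict[int, TestStatus] = {}  # annotation only, not evaluated at runtime

print(OneReportSketch().requests)  # {}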
@@ -218,11 +217,11 @@ class OneReport:
return str(self.get_map())
def get_map(self):
result = dict()
result = {}
result["test_name"] = self.test_name
result["test_file"] = self.test_file
result["stats"] = self.stats.get_map()
result["requests"] = dict()
result["requests"] = {}
requests = result["requests"]
for pos, status in self.requests.items():
requests[pos] = status.get_map()
@@ -233,7 +232,7 @@ class Report:
def __init__(self, dbms_name, input_dir=None):
self.dbms_name = dbms_name
self.stats = Stats()
self.tests = dict() # type: dict(str, OneReport)
self.tests = {}
self.input_dir = input_dir
self.output_dir = None
@@ -256,7 +255,7 @@ class Report:
self.output_dir = res_dir
def get_map(self):
result = dict()
result = {}
result["dbms_name"] = self.dbms_name
result["stats"] = self.stats.get_map()
result["input_dir"] = self.input_dir
@@ -264,7 +263,7 @@ class Report:
result["input_dir"] = self.input_dir
if self.output_dir is not None:
result["output_dir"] = self.output_dir
result["tests"] = dict()
result["tests"] = {}
tests = result["tests"]
for test_name, one_report in self.tests.items():
tests.update({test_name: one_report.get_map()})
@@ -297,8 +296,8 @@ class Report:
def write_report(self, report_dir):
report_path = os.path.join(report_dir, "report.json")
logger.info(f"create file {report_path}")
with open(report_path, "w") as stream:
logger.info("create file %s", report_path)
with open(report_path, "w", encoding="utf-8") as stream:
stream.write(json.dumps(self.get_map(), indent=4))
@@ -434,16 +433,13 @@ class TestRunner:
details=f"expected error: {expected_error}",
parent=exec_res.get_exception(),
)
else:
clogger.debug("errors matched")
raise QuerySuccess()
else:
clogger.debug("missed error")
raise QueryExecutionError(
"query is expected to fail with error",
details="expected error: {}".format(expected_error),
details=f"expected error: {expected_error}",
)
else:
clogger.debug("success is expected")
if exec_res.has_exception():
clogger.debug("had error")
@@ -460,7 +456,6 @@ class TestRunner:
test_parser.QueryResult.assert_eq(canonic, actual)
block.with_result(actual)
raise QuerySuccess()
else:
clogger.debug("completion mode")
raise QueryExecutionError(
"query execution failed with an exception",
@@ -476,9 +471,8 @@ class TestRunner:
if canonic_columns_count != actual_columns_count:
raise SchemeResultDiffer(
"canonic and actual columns count differ",
details="expected columns {}, actual columns {}".format(
canonic_columns_count, actual_columns_count
),
details=f"expected columns {canonic_columns_count}, "
f"actual columns {actual_columns_count}",
)
actual = test_parser.QueryResult.make_it(
@@ -528,7 +522,7 @@ class TestRunner:
self.report = Report(self.dbms_name, self._input_dir)
if self.results is None:
self.results = dict()
self.results = {}
if self.dbms_name == "ClickHouse" and test_name in [
"test/select5.test",
@@ -536,7 +530,7 @@ class TestRunner:
"test/evidence/slt_lang_replace.test",
"test/evidence/slt_lang_droptrigger.test",
]:
logger.info(f"Let's skip test %s for ClickHouse", test_name)
logger.info("Let's skip test %s for ClickHouse", test_name)
return
with self.connection.with_one_test_scope():
@@ -565,7 +559,7 @@ class TestRunner:
test_name = os.path.relpath(test_file, start=self._input_dir)
logger.debug("open file %s", test_name)
with open(test_file, "r") as stream:
with open(test_file, "r", encoding="utf-8") as stream:
self.run_one_test(stream, test_name, test_file)
def run_all_tests_from_dir(self, input_dir):
@@ -582,10 +576,10 @@ class TestRunner:
for test_name, stream in self.results.items():
test_file = os.path.join(dir_path, test_name)
logger.info(f"create file {test_file}")
logger.info("create file %s", test_file)
result_dir = os.path.dirname(test_file)
os.makedirs(result_dir, exist_ok=True)
with open(test_file, "w") as output:
with open(test_file, "w", encoding="utf-8") as output:
output.write(stream.getvalue())
def write_report(self, report_dir):