ClickHouse/tests/ci/github_helper.py
2023-03-02 11:49:36 +01:00

226 lines
9.0 KiB
Python

#!/usr/bin/env python
"""Helper for GitHub API requests"""
import logging
from datetime import date, datetime, timedelta
from pathlib import Path
from os import path as p
from time import sleep
from typing import List, Optional, Tuple, Union
import github
# explicit reimport
# pylint: disable=useless-import-alias
from github.AuthenticatedUser import AuthenticatedUser
from github.GithubException import (
RateLimitExceededException as RateLimitExceededException,
)
from github.Issue import Issue as Issue
from github.NamedUser import NamedUser as NamedUser
from github.PullRequest import PullRequest as PullRequest
from github.Repository import Repository as Repository
# pylint: enable=useless-import-alias
CACHE_PATH = p.join(p.dirname(p.realpath(__file__)), "gh_cache")
logger = logging.getLogger(__name__)
PullRequests = List[PullRequest]
Issues = List[Issue]
class GitHub(github.Github):
def __init__(self, *args, create_cache_dir=True, **kwargs):
# Define meta attribute and apply setter logic
self._cache_path = Path(CACHE_PATH)
if create_cache_dir:
self.cache_path = self.cache_path
if not kwargs.get("per_page"):
kwargs["per_page"] = 100
# And set Path
super().__init__(*args, **kwargs)
self._retries = 0
# pylint: disable=signature-differs
def search_issues(self, *args, **kwargs) -> Issues: # type: ignore
"""Wrapper around search method with throttling and splitting by date.
We split only by the first"""
splittable = False
for arg, value in kwargs.items():
if arg in ["closed", "created", "merged", "updated"]:
if hasattr(value, "__iter__") and not isinstance(value, str):
assert [True for v in value if isinstance(v, (date, datetime))]
assert len(value) == 2
kwargs[arg] = f"{value[0].isoformat()}..{value[1].isoformat()}"
if not splittable:
# We split only by the first met splittable argument
preserved_arg = arg
preserved_value = value
middle_value = value[0] + (value[1] - value[0]) / 2
splittable = middle_value not in value
continue
assert isinstance(value, (date, datetime, str))
inter_result = [] # type: Issues
for i in range(self.retries):
try:
logger.debug("Search issues, args=%s, kwargs=%s", args, kwargs)
result = super().search_issues(*args, **kwargs)
if result.totalCount == 1000 and splittable:
# The hard limit is 1000. If it's splittable, then we make
# two subrequests requests with less time frames
logger.debug(
"The search result contain exactly 1000 results, "
"splitting %s=%s by middle point %s",
preserved_arg,
kwargs[preserved_arg],
middle_value,
)
kwargs[preserved_arg] = [preserved_value[0], middle_value]
inter_result.extend(self.search_issues(*args, **kwargs))
if isinstance(middle_value, date):
# When middle_value is a date, 2022-01-01..2022-01-03
# is split to 2022-01-01..2022-01-02 and
# 2022-01-02..2022-01-03, so we have results for
# 2022-01-02 twicely. We split it to
# 2022-01-01..2022-01-02 and 2022-01-03..2022-01-03.
# 2022-01-01..2022-01-02 aren't split, see splittable
middle_value += timedelta(days=1)
kwargs[preserved_arg] = [middle_value, preserved_value[1]]
inter_result.extend(self.search_issues(*args, **kwargs))
return inter_result
inter_result.extend(result)
return inter_result
except RateLimitExceededException as e:
if i == self.retries - 1:
exception = e
self.sleep_on_rate_limit()
raise exception
# pylint: enable=signature-differs
def get_pulls_from_search(self, *args, **kwargs) -> PullRequests: # type: ignore
"""The search api returns actually issues, so we need to fetch PullRequests"""
issues = self.search_issues(*args, **kwargs)
repos = {}
prs = [] # type: PullRequests
for issue in issues:
# See https://github.com/PyGithub/PyGithub/issues/2202,
# obj._rawData doesn't spend additional API requests
# pylint: disable=protected-access
repo_url = issue._rawData["repository_url"] # type: ignore
if repo_url not in repos:
repos[repo_url] = issue.repository
prs.append(
self.get_pull_cached(repos[repo_url], issue.number, issue.updated_at)
)
return prs
def sleep_on_rate_limit(self):
for limit, data in self.get_rate_limit().raw_data.items():
if data["remaining"] == 0:
sleep_time = data["reset"] - int(datetime.now().timestamp()) + 1
if sleep_time > 0:
logger.warning(
"Faced rate limit for '%s' requests type, sleeping %s",
limit,
sleep_time,
)
sleep(sleep_time)
return
def get_pull_cached(
self, repo: Repository, number: int, obj_updated_at: Optional[datetime] = None
) -> PullRequest:
cache_file = self.cache_path / f"pr-{number}.pickle"
if cache_file.is_file():
is_updated, cached_pr = self._is_cache_updated(cache_file, obj_updated_at)
if is_updated:
logger.debug("Getting PR #%s from cache", number)
return cached_pr # type: ignore
logger.debug("Getting PR #%s from API", number)
for i in range(self.retries):
try:
pr = repo.get_pull(number)
break
except RateLimitExceededException:
if i == self.retries - 1:
raise
self.sleep_on_rate_limit()
logger.debug("Caching PR #%s from API in %s", number, cache_file)
with open(cache_file, "wb") as prfd:
self.dump(pr, prfd) # type: ignore
return pr
def get_user_cached(
self, login: str, obj_updated_at: Optional[datetime] = None
) -> Union[AuthenticatedUser, NamedUser]:
cache_file = self.cache_path / f"user-{login}.pickle"
if cache_file.is_file():
is_updated, cached_user = self._is_cache_updated(cache_file, obj_updated_at)
if is_updated:
logger.debug("Getting user %s from cache", login)
return cached_user # type: ignore
logger.debug("Getting PR #%s from API", login)
for i in range(self.retries):
try:
user = self.get_user(login)
break
except RateLimitExceededException:
if i == self.retries - 1:
raise
self.sleep_on_rate_limit()
logger.debug("Caching user %s from API in %s", login, cache_file)
with open(cache_file, "wb") as prfd:
self.dump(user, prfd) # type: ignore
return user
def _get_cached(self, path: Path): # type: ignore
with open(path, "rb") as ob_fd:
return self.load(ob_fd) # type: ignore
def _is_cache_updated(
self, cache_file: Path, obj_updated_at: Optional[datetime]
) -> Tuple[bool, object]:
cached_obj = self._get_cached(cache_file)
# We don't want the cache_updated being always old,
# for example in cases when the user is not updated for ages
cache_updated = max(
datetime.fromtimestamp(cache_file.stat().st_mtime), cached_obj.updated_at
)
if obj_updated_at is None:
# When we don't know about the object is updated or not,
# we update it once per hour
obj_updated_at = datetime.now() - timedelta(hours=1)
if obj_updated_at <= cache_updated:
return True, cached_obj
return False, cached_obj
@property
def cache_path(self) -> Path:
return self._cache_path
@cache_path.setter
def cache_path(self, value: str) -> None:
self._cache_path = Path(value)
if self._cache_path.exists():
assert self._cache_path.is_dir()
else:
self._cache_path.mkdir(parents=True)
@property
def retries(self):
if self._retries == 0:
self._retries = 3
return self._retries
@retries.setter
def retries(self, value: int) -> None:
assert isinstance(value, int)
self._retries = value