2022-08-02 18:44:49 +02:00

213 lines
8.4 KiB

#!/usr/bin/env python
"""Helper for GitHub API requests"""
import logging
from datetime import date, datetime, timedelta
from pathlib import Path
from os import path as p
from time import sleep
from typing import List, Optional, Tuple
import github
from github.GithubException import RateLimitExceededException
from github.Issue import Issue
from github.NamedUser import NamedUser
from github.PullRequest import PullRequest
from github.Repository import Repository
# Default directory for cached API objects: "gh_cache" next to this file.
CACHE_PATH = p.join(p.dirname(p.realpath(__file__)), "gh_cache")
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Type aliases used in the annotations of the GitHub helper class.
PullRequests = List[PullRequest]
Issues = List[Issue]
class GitHub(github.Github):
    """``github.Github`` client with rate-limit retries and an on-disk cache.

    Requests that raise ``RateLimitExceededException`` are retried up to
    ``retries`` times, sleeping between the attempts until the exhausted
    limit is reset. Pull requests and users requested through the
    ``get_*_cached`` methods are stored as files in ``cache_path``.
    """

    def __init__(self, *args, **kwargs):
        """Pass every argument to ``github.Github`` and set the defaults."""
        # Define meta attribute
        self._cache_path = Path(CACHE_PATH)
        # And set Path
        super().__init__(*args, **kwargs)
        # Zero means "not configured"; the ``retries`` property replaces it
        # with the default on the first read
        self._retries = 0

    # pylint: disable=signature-differs
    def search_issues(self, *args, **kwargs) -> Issues:  # type: ignore
        """Wrapper around search method with throttling and splitting by date.

        The ``closed``, ``created``, ``merged`` and ``updated`` qualifiers
        accept either a single ``date``/``datetime``/``str`` or a pair of
        ``date``/``datetime`` objects. A pair is converted to the
        ``start..end`` range. When the search returns exactly 1000 results
        (the hard limit) the first pair that still can be halved is split by
        its middle point and both halves are searched recursively.

        Returns:
            The list of found issues.

        Raises:
            RateLimitExceededException: the rate limit is still exceeded on
                the last of ``retries`` attempts.
            RuntimeError: ``retries`` is not positive, so nothing was tried.
        """
        splittable = False
        preserved_arg = None
        preserved_value = None
        middle_value = None
        for arg, value in kwargs.items():
            if arg not in ("closed", "created", "merged", "updated"):
                continue
            if isinstance(value, str) or not hasattr(value, "__iter__"):
                assert isinstance(value, (date, datetime, str))
                continue
            assert len(value) == 2
            assert all(isinstance(v, (date, datetime)) for v in value)
            # Only the value of an existing key is replaced, so it is safe
            # to do while iterating over the dictionary
            kwargs[arg] = f"{value[0].isoformat()}..{value[1].isoformat()}"
            if not splittable:
                # We split only by the first met splittable argument
                preserved_arg = arg
                preserved_value = value
                middle_value = value[0] + (value[1] - value[0]) / 2
                # A range with the middle equal to one of its ends can't be
                # reduced further
                splittable = middle_value not in value

        retries = self.retries
        for i in range(retries):
            # A fresh list per attempt: a retry must not duplicate the items
            # collected before the rate limit was hit
            found = []  # type: Issues
            try:
                logger.debug("Search issues, args=%s, kwargs=%s", args, kwargs)
                result = super().search_issues(*args, **kwargs)
                if result.totalCount == 1000 and splittable:
                    # The hard limit is 1000. If it's splittable, then we make
                    # two subrequests with smaller time frames
                    logger.debug(
                        "The search result contain exactly 1000 results, "
                        "splitting %s=%s by middle point %s",
                        preserved_arg,
                        kwargs[preserved_arg],
                        middle_value,
                    )
                    kwargs[preserved_arg] = [preserved_value[0], middle_value]
                    found.extend(self.search_issues(*args, **kwargs))
                    second_start = middle_value
                    if isinstance(second_start, date) and not isinstance(
                        second_start, datetime
                    ):
                        # When middle_value is a date, 2022-01-01..2022-01-03
                        # is split to 2022-01-01..2022-01-02 and
                        # 2022-01-02..2022-01-03, so we would have results
                        # for 2022-01-02 twice. We split it to
                        # 2022-01-01..2022-01-02 and 2022-01-03..2022-01-03.
                        # 2022-01-01..2022-01-02 isn't split, see splittable.
                        # datetime is a subclass of date and must not be
                        # shifted, otherwise a whole day would be skipped
                        second_start += timedelta(days=1)
                    kwargs[preserved_arg] = [second_start, preserved_value[1]]
                    found.extend(self.search_issues(*args, **kwargs))
                    return found

                found.extend(result)
                return found
            except RateLimitExceededException:
                if i == retries - 1:
                    raise
                self.sleep_on_rate_limit()

        raise RuntimeError(f"No search attempt was made, retries={retries}")

    # pylint: enable=signature-differs
    def get_pulls_from_search(self, *args, **kwargs) -> PullRequests:
        """The search api returns actually issues, so we need to fetch PullRequests

        The arguments are passed to ``search_issues`` as is. Every found
        issue is resolved to a pull request through ``get_pull_cached``.
        """
        issues = self.search_issues(*args, **kwargs)
        repos = {}
        prs = []  # type: PullRequests
        for issue in issues:
            # obj._rawData doesn't spend additional API requests
            # pylint: disable=protected-access
            repo_url = issue._rawData["repository_url"]  # type: ignore
            if repo_url not in repos:
                # Remember the repository object to resolve it once per URL
                repos[repo_url] = issue.repository
            prs.append(
                self.get_pull_cached(repos[repo_url], issue.number, issue.updated_at)
            )
        return prs

    def sleep_on_rate_limit(self):
        """Sleep until every rate limit with no remaining requests is reset."""
        for limit, data in self.get_rate_limit().raw_data.items():
            if data["remaining"] == 0:
                # "reset" is compared with the current epoch time; one more
                # second is added to wake up after the reset
                sleep_time = data["reset"] - int(datetime.now().timestamp()) + 1
                if sleep_time > 0:
                    logger.warning(
                        "Faced rate limit for '%s' requests type, sleeping %s",
                        limit,
                        sleep_time,
                    )
                    sleep(sleep_time)

    def get_pull_cached(
        self, repo: Repository, number: int, obj_updated_at: Optional[datetime] = None
    ) -> PullRequest:
        """Return the pull request ``number`` of ``repo`` using the file cache.

        The cached object is returned when ``_is_cache_updated`` considers it
        fresh for ``obj_updated_at``. Otherwise the pull request is requested
        from the API, with retries on the rate limit, and the cache file is
        rewritten.

        Raises:
            RateLimitExceededException: the rate limit is still exceeded on
                the last of ``retries`` attempts.
        """
        cache_file = self.cache_path / f"pr-{number}.pickle"

        if cache_file.is_file():
            is_updated, cached_pr = self._is_cache_updated(cache_file, obj_updated_at)
            if is_updated:
                logger.debug("Getting PR #%s from cache", number)
                return cached_pr  # type: ignore
        logger.debug("Getting PR #%s from API", number)
        for i in range(self.retries):
            try:
                pr = repo.get_pull(number)
                break
            except RateLimitExceededException:
                if i == self.retries - 1:
                    raise
                self.sleep_on_rate_limit()
        logger.debug("Caching PR #%s from API in %s", number, cache_file)
        with open(cache_file, "wb") as prfd:
            self.dump(pr, prfd)  # type: ignore
        return pr

    def get_user_cached(
        self, login: str, obj_updated_at: Optional[datetime] = None
    ) -> NamedUser:
        """Return the user ``login`` using the file cache.

        The cached object is returned when ``_is_cache_updated`` considers it
        fresh for ``obj_updated_at``. Otherwise the user is requested from
        the API, with retries on the rate limit, and the cache file is
        rewritten.

        Raises:
            RateLimitExceededException: the rate limit is still exceeded on
                the last of ``retries`` attempts.
        """
        cache_file = self.cache_path / f"user-{login}.pickle"

        if cache_file.is_file():
            is_updated, cached_user = self._is_cache_updated(cache_file, obj_updated_at)
            if is_updated:
                logger.debug("Getting user %s from cache", login)
                return cached_user  # type: ignore
        logger.debug("Getting user %s from API", login)
        for i in range(self.retries):
            try:
                user = self.get_user(login)
                break
            except RateLimitExceededException:
                if i == self.retries - 1:
                    raise
                self.sleep_on_rate_limit()
        logger.debug("Caching user %s from API in %s", login, cache_file)
        with open(cache_file, "wb") as prfd:
            self.dump(user, prfd)  # type: ignore
        return user

    def _get_cached(self, path: Path):
        """Load and return the object stored in the cache file ``path``."""
        # NOTE(review): judging by the file suffix the cache files are
        # pickles; loading them is only safe from a trusted cache directory
        with open(path, "rb") as ob_fd:
            return self.load(ob_fd)  # type: ignore

    def _is_cache_updated(
        self, cache_file: Path, obj_updated_at: Optional[datetime]
    ) -> Tuple[bool, object]:
        """Check if the object in ``cache_file`` is fresh enough.

        Returns:
            ``(True, cached_obj)`` when ``obj_updated_at`` is not later than
            the cache update time, ``(False, cached_obj)`` otherwise.
        """
        cached_obj = self._get_cached(cache_file)
        # We don't want the cache_updated being always old,
        # for example in cases when the user is not updated for ages
        # NOTE(review): assumes the file time and `updated_at` are comparable
        # naive datetimes - confirm for the used PyGithub version
        cache_updated = max(
            datetime.fromtimestamp(cache_file.stat().st_mtime), cached_obj.updated_at
        )
        if obj_updated_at is None:
            # When we don't know about the object is updated or not,
            # we update it once per hour
            obj_updated_at = datetime.now() - timedelta(hours=1)
        if obj_updated_at <= cache_updated:
            return True, cached_obj
        return False, cached_obj

    @property
    def cache_path(self):
        """Directory where the cached objects are stored."""
        return self._cache_path

    @cache_path.setter
    def cache_path(self, value: str):
        self._cache_path = Path(value)
        if self._cache_path.exists():
            assert self._cache_path.is_dir()
        else:
            # Create the directory to make the following cache writes work
            self._cache_path.mkdir(parents=True)

    @property
    def retries(self):
        """Number of attempts for a request, 3 unless set to another value."""
        if self._retries == 0:
            self._retries = 3
        return self._retries

    @retries.setter
    def retries(self, value: int):
        self._retries = value