#!/usr/bin/env python
"""Helper for GitHub API requests"""
import logging
from datetime import date, datetime, timedelta
from pathlib import Path
from os import path as p
from time import sleep
from typing import Callable, List, Optional, Tuple, TypeVar, Union

import github

# explicit reimport
# pylint: disable=useless-import-alias
from github.AuthenticatedUser import AuthenticatedUser
from github.GithubException import (
    RateLimitExceededException as RateLimitExceededException,
)
from github.Issue import Issue as Issue
from github.NamedUser import NamedUser as NamedUser
from github.PullRequest import PullRequest as PullRequest
from github.Repository import Repository as Repository

# pylint: enable=useless-import-alias

CACHE_PATH = p.join(p.dirname(p.realpath(__file__)), "gh_cache")

logger = logging.getLogger(__name__)

PullRequests = List[PullRequest]
Issues = List[Issue]

# Search qualifiers whose value may be given as a (start, end) pair of dates
DATE_RANGE_ARGS = ("closed", "created", "merged", "updated")
# The search API gives access to at most this many results of one query
SEARCH_RESULTS_LIMIT = 1000

_T = TypeVar("_T")


class GitHub(github.Github):
    """``github.Github`` with rate-limit retries, splitting of too-large
    date-range searches, and an on-disk pickle cache of PRs and users."""

    def __init__(self, *args, create_cache_dir=True, **kwargs):
        # Set the private attribute first, then (optionally) re-assign it
        # through the `cache_path` setter, which creates the directory
        self._cache_path = Path(CACHE_PATH)
        if create_cache_dir:
            self.cache_path = self.cache_path
        if not kwargs.get("per_page"):
            kwargs["per_page"] = 100
        super().__init__(*args, **kwargs)
        # 0 means "use the default", see the `retries` property
        self._retries = 0

    def _call_with_retries(self, func: Callable[[], _T]) -> _T:
        """Call ``func`` up to ``self.retries`` times and return its result.

        After a ``RateLimitExceededException`` it sleeps until the limit
        resets (see ``sleep_on_rate_limit``) and tries again. The exception of
        the last attempt is re-raised right away, without sleeping.
        """
        retries = self.retries
        for attempt in range(retries):
            try:
                return func()
            except RateLimitExceededException:
                if attempt == retries - 1:
                    raise
                self.sleep_on_rate_limit()
        raise ValueError(f"The number of retries must be positive, got {retries}")

    # pylint: disable=signature-differs
    def search_issues(self, *args, **kwargs) -> Issues:  # type: ignore
        """Wrapper around search method with throttling and splitting by date.

        The qualifiers ``closed``, ``created``, ``merged`` and ``updated`` may
        be passed as a two-element sequence of ``date``/``datetime`` objects;
        they are converted to the ``start..end`` range syntax. The search API
        returns at most 1000 results, so when a query hits that limit, the
        range of the FIRST splittable argument is split at its middle point and
        both halves are searched recursively.

        Rate-limited requests are retried up to ``self.retries`` times.
        """
        splittable = False
        preserved_arg, preserved_value, middle_value = "", None, None
        for arg, value in kwargs.items():
            if arg in DATE_RANGE_ARGS:
                if hasattr(value, "__iter__") and not isinstance(value, str):
                    assert all(isinstance(v, (date, datetime)) for v in value)
                    assert len(value) == 2
                    # Replacing the value of an existing key doesn't resize the
                    # dict, so it is safe while iterating over it
                    kwargs[arg] = f"{value[0].isoformat()}..{value[1].isoformat()}"
                    if not splittable:
                        # We split only by the first met splittable argument
                        preserved_arg = arg
                        preserved_value = value
                        middle_value = value[0] + (value[1] - value[0]) / 2
                        # `date + timedelta` ignores the sub-day part, so the
                        # middle of a one- or two-day date range equals one of
                        # its ends, and such a range can't be split further
                        splittable = middle_value not in value
                    continue
                assert isinstance(value, (date, datetime, str))

        parent_search = super().search_issues

        def search_once() -> Issues:
            # Every attempt collects a fresh list and leaves `kwargs` intact, so
            # a retry after a rate limit neither duplicates issues nor searches
            # with a corrupted query
            logger.debug("Search issues, args=%s, kwargs=%s", args, kwargs)
            result = parent_search(*args, **kwargs)
            # Check `splittable` first: `totalCount` costs an extra API request
            if not (splittable and result.totalCount >= SEARCH_RESULTS_LIMIT):
                return list(result)
            # The hard limit is 1000. As the query is splittable, make two
            # subrequests with shorter time frames instead
            logger.debug(
                "The search result contain exactly 1000 results, "
                "splitting %s=%s by middle point %s",
                preserved_arg,
                kwargs[preserved_arg],
                middle_value,
            )
            issues = self.search_issues(
                *args, **{**kwargs, preserved_arg: [preserved_value[0], middle_value]}
            )
            second_start = middle_value
            if not isinstance(middle_value, datetime):
                # Date ranges are inclusive on both ends: 2022-01-01..2022-01-03
                # is split into 2022-01-01..2022-01-02 and 2022-01-03..2022-01-03
                # so 2022-01-02 is not searched twice. (`datetime` is a
                # subclass of `date`, hence the explicit check: shifting a
                # datetime by a whole day would skip 24 hours of results.)
                second_start = middle_value + timedelta(days=1)
            issues.extend(
                self.search_issues(
                    *args,
                    **{**kwargs, preserved_arg: [second_start, preserved_value[1]]},
                )
            )
            return issues

        return self._call_with_retries(search_once)

    # pylint: enable=signature-differs
    def get_pulls_from_search(self, *args, **kwargs) -> PullRequests:  # type: ignore
        """Search issues and return the corresponding pull requests.

        The search API actually returns issues, so each of them is turned into
        a ``PullRequest`` through ``get_pull_cached``. The repository object is
        fetched once per ``repository_url``.
        """
        issues = self.search_issues(*args, **kwargs)
        repos = {}
        prs = []  # type: PullRequests
        for issue in issues:
            # See https://github.com/PyGithub/PyGithub/issues/2202,
            # obj._rawData doesn't spend additional API requests
            # pylint: disable=protected-access
            repo_url = issue._rawData["repository_url"]  # type: ignore
            if repo_url not in repos:
                repos[repo_url] = issue.repository
            prs.append(
                self.get_pull_cached(repos[repo_url], issue.number, issue.updated_at)
            )
        return prs

    def sleep_on_rate_limit(self):
        """Sleep until the reset of the first exhausted rate limit, if any.

        Only the first limit type with ``remaining == 0`` is waited for.
        """
        for limit, data in self.get_rate_limit().raw_data.items():
            if data["remaining"] == 0:
                sleep_time = data["reset"] - int(datetime.now().timestamp()) + 1
                if sleep_time > 0:
                    logger.warning(
                        "Faced rate limit for '%s' requests type, sleeping %s",
                        limit,
                        sleep_time,
                    )
                    sleep(sleep_time)
                return

    def get_pull_cached(
        self, repo: Repository, number: int, obj_updated_at: Optional[datetime] = None
    ) -> PullRequest:
        """Return PR ``number`` of ``repo``, from the cache if it's fresh enough.

        A PR fetched from the API is pickled into the cache directory. See
        ``_is_cache_updated`` for how freshness is decided.
        """
        # NOTE(review): the file name depends only on the PR number, so PRs
        # with the same number from different repositories share one cache
        # file — confirm this is only used with a single repository
        cache_file = self.cache_path / f"pr-{number}.pickle"
        if cache_file.is_file():
            is_updated, cached_pr = self._is_cache_updated(cache_file, obj_updated_at)
            if is_updated:
                logger.debug("Getting PR #%s from cache", number)
                return cached_pr  # type: ignore
        logger.debug("Getting PR #%s from API", number)
        pr = self._call_with_retries(lambda: repo.get_pull(number))
        logger.debug("Caching PR #%s from API in %s", number, cache_file)
        with open(cache_file, "wb") as prfd:
            self.dump(pr, prfd)  # type: ignore
        return pr

    def get_user_cached(
        self, login: str, obj_updated_at: Optional[datetime] = None
    ) -> Union[AuthenticatedUser, NamedUser]:
        """Return user ``login``, from the cache if it's fresh enough.

        A user fetched from the API is pickled into the cache directory. See
        ``_is_cache_updated`` for how freshness is decided.
        """
        cache_file = self.cache_path / f"user-{login}.pickle"
        if cache_file.is_file():
            is_updated, cached_user = self._is_cache_updated(cache_file, obj_updated_at)
            if is_updated:
                logger.debug("Getting user %s from cache", login)
                return cached_user  # type: ignore
        logger.debug("Getting user %s from API", login)
        user = self._call_with_retries(lambda: self.get_user(login))
        logger.debug("Caching user %s from API in %s", login, cache_file)
        with open(cache_file, "wb") as prfd:
            self.dump(user, prfd)  # type: ignore
        return user

    def _get_cached(self, path: Path):  # type: ignore
        """Load a pickled object from ``path``.

        Unpickling executes arbitrary code: the cache directory must only
        contain files written by this class.
        """
        with open(path, "rb") as ob_fd:
            return self.load(ob_fd)  # type: ignore

    def _is_cache_updated(
        self, cache_file: Path, obj_updated_at: Optional[datetime]
    ) -> Tuple[bool, object]:
        """Load ``cache_file`` and tell whether it is up to date.

        The cache is up to date when ``obj_updated_at`` is not newer than the
        later of the file's mtime and the cached object's ``updated_at``. With
        unknown ``obj_updated_at``, a cache younger than one hour is accepted.

        Returns a ``(is_up_to_date, cached_object)`` tuple.
        """
        cached_obj = self._get_cached(cache_file)
        # We don't want the cache_updated being always old,
        # for example in cases when the user is not updated for ages
        cache_updated = max(
            datetime.fromtimestamp(cache_file.stat().st_mtime), cached_obj.updated_at
        )
        if obj_updated_at is None:
            # When we don't know about the object is updated or not,
            # we update it once per hour
            obj_updated_at = datetime.now() - timedelta(hours=1)
        if obj_updated_at <= cache_updated:
            return True, cached_obj
        return False, cached_obj

    @property
    def cache_path(self) -> Path:
        """Directory holding the pickled PRs and users"""
        return self._cache_path

    @cache_path.setter
    def cache_path(self, value: Union[str, Path]) -> None:
        self._cache_path = Path(value)
        # Creates missing parents; raises FileExistsError when the path exists
        # but is not a directory. Unlike exists()+mkdir() it has no race.
        self._cache_path.mkdir(parents=True, exist_ok=True)

    @property
    def retries(self) -> int:
        """Number of attempts for rate-limited requests; 0 means the default 3"""
        if self._retries == 0:
            self._retries = 3
        return self._retries

    @retries.setter
    def retries(self, value: int) -> None:
        if not isinstance(value, int):
            raise TypeError(f"retries must be an int, got {type(value).__name__}")
        if value < 0:
            raise ValueError(f"retries must not be negative, got {value}")
        self._retries = value