Add caching for GitHub PR objects

Mikhail f. Shiryaev 2022-05-24 20:18:27 +02:00
parent 4b28ea92ca
commit 0e494c9ee7
2 changed files with 44 additions and 12 deletions
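
In short: merged PRs come back from the search API as Issue objects, worker threads pull them off a queue, and each PullRequest is fetched through a pickle cache in a gh_cache directory next to the script, keyed by PR number. A rough usage sketch of the pieces this diff introduces (the token, repo, dates, and jobs count are placeholders):

    GitHub = Github("<token>", per_page=100, pool_size=4)
    repo = GitHub.get_repo("<owner>/<repo>")
    query = "type:pr repo:<owner>/<repo> is:merged merged:2022-04-01..2022-05-01"
    issues = list(GitHub.search_issues(query=query, sort="created"))
    descriptions = get_descriptions(repo, issues, 4)  # workers call get_pull_cached() per issue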


@@ -1,3 +1,4 @@
 *.md
 *.txt
 *.json
+gh_cache


@@ -3,8 +3,10 @@
 import argparse
 import logging
 import os.path as p
+import os
 import re
-from datetime import date, timedelta
+from datetime import date, datetime, timedelta
 from queue import Empty, Queue
 from subprocess import CalledProcessError, DEVNULL
 from threading import Thread
@@ -13,6 +15,7 @@ from typing import Dict, List, Optional, TextIO
 from fuzzywuzzy.fuzz import ratio  # type: ignore
 from github import Github
 from github.NamedUser import NamedUser
+from github.Issue import Issue
 from github.PullRequest import PullRequest
 from github.Repository import Repository
 from git_helper import is_shallow, git_runner as runner
@@ -34,6 +37,8 @@ categories_preferred_order = (
 FROM_REF = ""
 TO_REF = ""
 SHA_IN_CHANGELOG = []  # type: List[str]
+GitHub = Github()
+CACHE_PATH = p.join(p.dirname(p.realpath(__file__)), "gh_cache")


 class Description:
@@ -87,10 +92,10 @@ class Worker(Thread):
     def run(self):
         while not self.queue.empty():
             try:
-                number = self.queue.get()
+                issue = self.queue.get()  # type: Issue
             except Empty:
                 break  # possible race condition, just continue
-            api_pr = self.repo.get_pull(number)
+            api_pr = get_pull_cached(self.repo, issue.number, issue.updated_at)
             in_changelog = False
             merge_commit = api_pr.merge_commit_sha
             try:
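
For context, run() is the standard drain-a-Queue worker loop; the Empty branch handles the race where a sibling worker empties the queue between the empty() check and the get(). A self-contained sketch of the pattern (process() is a stand-in for the PR handling above; note that Empty is only raised by a non-blocking or timed get()):

    from queue import Empty, Queue
    from threading import Thread

    def process(item):  # stand-in for fetching and categorizing one PR
        print("processing", item)

    class DrainWorker(Thread):
        def __init__(self, queue):
            super().__init__()
            self.queue = queue

        def run(self):
            while not self.queue.empty():
                try:
                    item = self.queue.get(timeout=1)
                except Empty:
                    break  # a sibling worker drained the queue first
                process(item)
                self.queue.task_done()

    q = Queue()
    for n in (1, 2, 3):
        q.put(n)
    workers = [DrainWorker(q) for _ in range(2)]
    for w in workers:
        w.start()
    q.join()  # returns once every put() has a matching task_done()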
@@ -109,13 +114,31 @@
             self.queue.task_done()


+def get_pull_cached(
+    repo: Repository, number: int, updated_at: Optional[datetime] = None
+) -> PullRequest:
+    pr_cache_file = p.join(CACHE_PATH, f"{number}.pickle")
+    if updated_at is None:
+        updated_at = datetime.now() - timedelta(hours=1)  # i.e. trust a cache file written within the last hour
+    if p.isfile(pr_cache_file):
+        cache_updated = datetime.fromtimestamp(p.getmtime(pr_cache_file))
+        if cache_updated > updated_at:
+            with open(pr_cache_file, "rb") as prfd:
+                return GitHub.load(prfd)  # type: ignore
+    pr = repo.get_pull(number)
+    with open(pr_cache_file, "wb") as prfd:
+        GitHub.dump(pr, prfd)  # type: ignore
+    return pr
+
+
 def get_descriptions(
-    repo: Repository, numbers: List[int], jobs: int
+    repo: Repository, issues: List[Issue], jobs: int
 ) -> Dict[str, List[Description]]:
     workers = []  # type: List[Worker]
-    queue = Queue()  # type: Queue # (!?!?!?!??!)
-    for number in numbers:
-        queue.put(number)
+    queue = Queue()  # type: Queue[Issue]
+    for issue in issues:
+        queue.put(issue)
     for _ in range(jobs):
         worker = Worker(queue, repo)
         worker.start()
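
get_pull_cached() leans on PyGithub's pickle support: Github.dump() writes out the object's class and raw API data, and Github.load() re-creates the object bound to the client it is called on, so a revived PullRequest can still lazily fetch attributes through the authenticated GitHub global. A hypothetical round trip (12345 is a placeholder PR number; main() must already have installed an authenticated client and created CACHE_PATH):

    repo = GitHub.get_repo("<owner>/<repo>")
    pr = get_pull_cached(repo, 12345)                 # API fetch, then pickled to gh_cache/12345.pickle
    pr = get_pull_cached(repo, 12345, pr.updated_at)  # cache file is newer than updated_at: pickle wins

Since the cache key is the PR number alone, a gh_cache directory is only valid for one repository at a time.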
@@ -200,7 +223,10 @@ def generate_description(item: PullRequest, repo: Repository) -> Optional[Descri
     if item.head.ref.startswith("backport/"):
         branch_parts = item.head.ref.split("/")
         if len(branch_parts) == 3:
-            item = repo.get_pull(int(branch_parts[-1]))
+            try:
+                item = get_pull_cached(repo, int(branch_parts[-1]))
+            except Exception as e:
+                logging.warning("unable to get backported PR, exception: %s", e)
         else:
             logging.warning(
                 "The branch %s doesn't match backport template, using PR %s as is",
@@ -337,6 +363,10 @@ def main():
         format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d]:\n%(message)s",
         level=log_levels[min(args.verbose, 3)],
     )
+    # Create a cache directory
+    if not p.isdir(CACHE_PATH):
+        os.mkdir(CACHE_PATH, 0o700)
+
     # Get the full repo
     if is_shallow():
         logging.info("Unshallow repository")
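
The 0o700 mode keeps the pickled API responses private to the owner. If the check-then-create pair ever needed to be race-free (e.g. concurrent runs), an equivalent one-liner would be:

    os.makedirs(CACHE_PATH, mode=0o700, exist_ok=True)  # no isdir() pre-check needed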
@@ -359,15 +389,16 @@
         to_date = (date.fromisoformat(to_date) + timedelta(1)).isoformat()

     # Get all PRs for the given time frame
-    gh = Github(
+    global GitHub
+    GitHub = Github(
         args.gh_user_or_token, args.gh_password, per_page=100, pool_size=args.jobs
     )
     query = f"type:pr repo:{args.repo} is:merged merged:{from_date}..{to_date}"
-    repo = gh.get_repo(args.repo)
-    api_prs = gh.search_issues(query=query, sort="created")
+    repo = GitHub.get_repo(args.repo)
+    api_prs = GitHub.search_issues(query=query, sort="created")
     logging.info("Found %s PRs for the query: '%s'", api_prs.totalCount, query)

-    pr_numbers = [pr.number for pr in api_prs]
+    pr_numbers = list(api_prs)

     descriptions = get_descriptions(repo, pr_numbers, args.jobs)
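
search_issues() returns a lazily paginated list of Issue objects; list() forces every result page to be fetched up front, so the queue handed to the workers is complete before they start. Note that pr_numbers now holds Issue objects rather than ints (renaming it to issues would make that clearer), which is exactly what the reworked get_descriptions() expects:

    issues = list(api_prs)                   # eagerly walks all result pages
    issues[0].number, issues[0].updated_at   # the two fields the workers consume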