Rewrite cherry_pick.py to PyGithub API

Mikhail f. Shiryaev 2022-07-14 20:57:03 +02:00
parent da97a22465
commit 909e871c48
No known key found for this signature in database
GPG Key ID: 4B02ED204C7D93F4
10 changed files with 632 additions and 1257 deletions

View File

@ -1,70 +1,487 @@
#!/usr/bin/env python3
"""
A plan:
- Receive the GH objects cache from S3; ignore it if that fails
- Get all open release PRs
- Get all pull-requests merged since the date of the merge-base for the oldest
release PR, with the label pr-must-backport or a version-specific label like
v21.8-must-backport, but without pr-backported
- Iterate over the fetched PRs:
- for pr-must-backport:
- check whether all backport PRs are created. If yes,
set the pr-backported label
- If not, create either cherrypick PRs or merge the cherrypick (in the same
stage, if mergeable?) and create backport PRs
- If successful, set the pr-backported label on the PR
- for version-specific labels:
- the same: check, cherry-pick, backport
Cherry-pick stage:
- From time to time the cherry-pick fails if it was already done manually. In that
case we should check whether it's even needed, and mark the release as done somehow.
"""
import argparse
import logging
import os
import subprocess
import sys
from contextlib import contextmanager
from datetime import date, timedelta
from subprocess import CalledProcessError
from typing import List, Optional
from env_helper import TEMP_PATH
from get_robot_token import get_best_robot_token
from git_helper import git_runner, is_shallow
from github_helper import (
GitHub,
PullRequest,
PullRequests,
Repository,
)
from github.Label import Label
from ssh import SSHKey
Labels = List[Label]
class labels:
LABEL_MUST_BACKPORT = "pr-must-backport"
LABEL_BACKPORT = "pr-backport"
LABEL_BACKPORTED = "pr-backported"
LABEL_CHERRYPICK = "pr-cherrypick"
LABEL_DO_NOT_TEST = "do not test"
class ReleaseBranch:
CHERRYPICK_DESCRIPTION = """This pull-request is a first step of an automated \
backporting.
It contains changes like after calling a local command `git cherry-pick`.
If you intend to continue backporting this changes, then resolve all conflicts if any.
Otherwise, if you do not want to backport them, then just close this pull-request.
The check results does not matter at this step - you can safely ignore them.
Also this pull-request will be merged automatically as it reaches the mergeable state, \
but you always can merge it manually.
"""
BACKPORT_DESCRIPTION = """This pull-request is a last step of an automated \
backporting.
Treat it as a standard pull-request: look at the checks and resolve conflicts.
Merge it only if you intend to backport changes to the target branch, otherwise just \
close it.
"""
REMOTE = ""
def __init__(self, name: str, pr: PullRequest):
self.name = name
self.pr = pr
self.cherrypick_branch = f"cherrypick/{name}/{pr.merge_commit_sha}"
self.backport_branch = f"backport/{name}/{pr.number}"
self.cherrypick_pr = None # type: Optional[PullRequest]
self.backport_pr = None # type: Optional[PullRequest]
self._backported = None # type: Optional[bool]
self.git_prefix = ( # All commits to cherrypick are done as robot-clickhouse
"git -c user.email=robot-clickhouse@clickhouse.com "
"-c user.name=robot-clickhouse -c commit.gpgsign=false"
)
def pop_prs(self, prs: PullRequests):
to_pop = [] # type: List[int]
for i, pr in enumerate(prs):
if self.name not in pr.head.ref:
continue
if pr.head.ref.startswith(f"cherrypick/{self.name}"):
self.cherrypick_pr = pr
to_pop.append(i)
elif pr.head.ref.startswith(f"backport/{self.name}"):
self.backport_pr = pr
to_pop.append(i)
else:
logging.error(
"PR #%s doesn't head ref starting with known suffix",
pr.number,
)
for i in reversed(to_pop):
# Going from the tail, so that the greater indexes are popped first
prs.pop(i)
def process(self, dry_run: bool):
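# The stages, in order: if the cherry-pick PR doesn't exist yet, create it;
# if it exists and is mergeable, merge it right away; once it's merged,
# create the backport PR; if it was closed, consider the backport discarded;
# if it has conflicts, leave it to a human and retry on the next run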
if self.backported:
return
if not self.cherrypick_pr:
if dry_run:
logging.info(
"DRY RUN: Would create cherrypick PR for #%s", self.pr.number
)
return
self.create_cherrypick()
if self.backported:
return
if self.cherrypick_pr is not None:
# Try to merge cherrypick instantly
if self.cherrypick_pr.mergeable and self.cherrypick_pr.state != "closed":
self.cherrypick_pr.merge()
# The PR needs an update, since PR.merge doesn't update the object
self.cherrypick_pr.update()
if self.cherrypick_pr.merged:
if dry_run:
logging.info(
"DRY RUN: Would create backport PR for #%s", self.pr.number
)
return
self.create_backport()
return
elif self.cherrypick_pr.state == "closed":
logging.info(
"The cherrypick PR #%s for PR #%s is discarded",
self.cherrypick_pr.number,
self.pr.number,
)
self._backported = True
return
logging.info(
"Cherrypick PR #%s for PR #%s have conflicts and unable to be merged",
self.cherrypick_pr.number,
self.pr.number,
)
def create_cherrypick(self):
# First, create backport branch:
# Check out the release branch, discarding every local change
git_runner(f"{self.git_prefix} checkout -f {self.name}")
# Create or reset backport branch
git_runner(f"{self.git_prefix} checkout -B {self.backport_branch}")
# Merge all changes from the PR's first parent commit without applying
# anything ('ours' strategy); it makes the later merge look like a cherry-pick
first_parent = git_runner(f"git rev-parse {self.pr.merge_commit_sha}^1")
git_runner(f"{self.git_prefix} merge -s ours --no-edit {first_parent}")
# Second step, create cherrypick branch
git_runner(
f"{self.git_prefix} branch -f "
f"{self.cherrypick_branch} {self.pr.merge_commit_sha}"
)
# Check if there are actually any changes between the branches. If not,
# no other actions are required.
try:
output = git_runner(
f"{self.git_prefix} merge --no-commit --no-ff {self.cherrypick_branch}"
)
# 'up-to-date', 'up to date', who knows what else
if output.startswith("Already up") and output.endswith("date."):
# The changes are already in the release branch, we are done here
logging.info(
"Release branch %s already contain changes from %s",
self.name,
self.pr.number,
)
self._backported = True
return
except CalledProcessError:
# Most probably there are conflicts; they'll be resolved in the PR
git_runner(f"{self.git_prefix} reset --merge")
else:
# There are changes to apply, so continue
git_runner(f"{self.git_prefix} reset --merge")
for branch in [self.cherrypick_branch, self.backport_branch]:
git_runner(f"{self.git_prefix} push -f {self.REMOTE} {branch}:{branch}")
self.cherrypick_pr = self.pr.base.repo.create_pull(
title=f"Cherry pick #{self.pr.number} to {self.name}: {self.pr.title}",
body=f"Original pull-request #{self.pr.number}\n\n"
f"{self.CHERRYPICK_DESCRIPTION}",
base=self.backport_branch,
head=self.cherrypick_branch,
)
self.cherrypick_pr.add_to_labels(labels.LABEL_CHERRYPICK)
self.cherrypick_pr.add_to_labels(labels.LABEL_DO_NOT_TEST)
self.cherrypick_pr.add_to_assignees(self.pr.assignee)
self.cherrypick_pr.add_to_assignees(self.pr.user)
def create_backport(self):
git_runner(f"{self.git_prefix} checkout -f {self.backport_branch}")
git_runner(
f"{self.git_prefix} pull --ff-only {self.REMOTE} {self.backport_branch}"
)
merge_base = git_runner(
f"{self.git_prefix} merge-base "
f"{self.REMOTE}/{self.name} {self.backport_branch}"
)
git_runner(f"{self.git_prefix} reset --soft {merge_base}")
title = f"Backport #{self.pr.number} to {self.name}: {self.pr.title}"
git_runner(f"{self.git_prefix} commit -a --allow-empty -F -", input=title)
git_runner(
f"{self.git_prefix} push -f {self.REMOTE} "
f"{self.backport_branch}:{self.backport_branch}"
)
self.backport_pr = self.pr.base.repo.create_pull(
title=title,
body=f"Original pull-request #{self.pr.number}\n"
f"Cherry-pick pull-request #{self.cherrypick_pr.number}\n\n"
f"{self.BACKPORT_DESCRIPTION}",
base=self.name,
head=self.backport_branch,
)
self.backport_pr.add_to_labels(labels.LABEL_BACKPORT)
self.backport_pr.add_to_assignees(self.pr.assignee)
self.backport_pr.add_to_assignees(self.pr.user)
@property
def backported(self) -> bool:
if self._backported is not None:
return self._backported
return self.backport_pr is not None
def __repr__(self):
return self.name
class Backport:
def __init__(self, gh: GitHub, repo: str, dry_run: bool):
self.gh = gh
self._repo_name = repo
self.dry_run = dry_run
self._repo = None # type: Optional[Repository]
self._remote = ""
self._query = f"type:pr repo:{repo}"
self.release_prs = [] # type: PullRequests
self.release_branches = [] # type: List[str]
self.labels_to_backport = [] # type: List[str]
self.prs_for_backport = [] # type: PullRequests
self.error = False
@property
def remote(self) -> str:
if not self._remote:
# lines of "origin git@github.com:ClickHouse/ClickHouse.git (fetch)"
remotes = git_runner("git remote -v").split("\n")
# We need the first word from the first matching result
self._remote = tuple(
remote.split(maxsplit=1)[0]
for remote in remotes
if f"github.com/{self._repo_name}" in remote # ssh
or f"github.com:{self._repo_name}" in remote # https
)[0]
git_runner(f"git fetch {self._remote}")
ReleaseBranch.REMOTE = self._remote
return self._remote
def receive_release_prs(self):
logging.info("Getting release PRs")
self.release_prs = self.gh.get_pulls_from_search(
query=f"{self._query} is:open",
sort="created",
order="asc",
type="pr",
label="release",
)
self.release_branches = [pr.head.ref for pr in self.release_prs]
self.labels_to_backport = [
f"v{branch}-must-backport" for branch in self.release_branches
]
logging.info("Active releases: %s", ", ".join(self.release_branches))
def receive_prs_for_backport(self):
since_commit = git_runner(
f"git merge-base {self.remote}/{self.release_branches[0]} "
f"{self.remote}/{self.default_branch}"
)
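# The oldest release branch diverged from the default branch at this commit;
# PRs merged before its date can't be needed by any open release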
since_date = date.fromisoformat(
git_runner.run(f"git log -1 --format=format:%cs {since_commit}")
)
tomorrow = date.today() + timedelta(days=1)
logging.info("Receive PRs suppose to be backported")
self.prs_for_backport = self.gh.get_pulls_from_search(
query=f"{self._query} -label:pr-backported",
label=",".join(self.labels_to_backport + [labels.LABEL_MUST_BACKPORT]),
merged=[since_date, tomorrow],
)
logging.info(
"PRs to be backported:\n %s",
"\n ".join([pr.html_url for pr in self.prs_for_backport]),
)
def process_backports(self):
for pr in self.prs_for_backport:
self.process_pr(pr)
def process_pr(self, pr: PullRequest):
pr_labels = [label.name for label in pr.labels]
if labels.LABEL_MUST_BACKPORT in pr_labels:
branches = [
ReleaseBranch(br, pr) for br in self.release_branches
] # type: List[ReleaseBranch]
else:
branches = [
ReleaseBranch(br, pr)
for br in [
label.split("-", 1)[0][1:] # v21.8-must-backport
for label in pr_labels
if label in self.labels_to_backport
]
]
if not branches:
# This is definitely an error: there must be at least one branch.
# It also makes the whole program's exit code non-zero
logging.error(
"There are no branches to backport PR #%s, logical error", pr.number
)
self.error = True
return
logging.info(
" PR #%s is suppose to be backported to %s",
pr.number,
", ".join(map(str, branches)),
)
# All PRs for cherrypick and backport branches as heads
query_suffix = " ".join(
[
f"head:{branch.backport_branch} head:{branch.cherrypick_branch}"
for branch in branches
]
)
bp_cp_prs = self.gh.get_pulls_from_search(
query=f"{self._query} {query_suffix}",
)
for br in branches:
br.pop_prs(bp_cp_prs)
if bp_cp_prs:
# This is definitely an error: all PRs must be consumed by the
# branches via ReleaseBranch.pop_prs. It also makes the whole
# program's exit code non-zero
logging.error(
"The following PRs are not filtered by release branches:\n%s",
"\n".join(map(str, bp_cp_prs)),
)
self.error = True
return
if all(br.backported for br in branches):
# Let's check if the PR is already backported
self.mark_pr_backported(pr)
return
for br in branches:
try:
br.process(self.dry_run)
except Exception as e:
logging.error(
"During processing the PR #%s error occured: %s", pr.number, e
)
self.error = True
if all(br.backported for br in branches):
# And check it again after the run
self.mark_pr_backported(pr)
def mark_pr_backported(self, pr: PullRequest):
if self.dry_run:
logging.info("DRY RUN: would mark PR #%s as done", pr.number)
return
pr.add_to_labels(labels.LABEL_BACKPORTED)
logging.info(
"PR #%s is successfully labeled with `%s`",
pr.number,
labels.LABEL_BACKPORTED,
)
@staticmethod
def pr_labels(pr: PullRequest) -> List[str]:
return [label.name for label in pr.labels]
@property
def repo(self) -> Repository:
if self._repo is None:
try:
self._repo = self.release_prs[0].base.repo
except IndexError as exc:
raise Exception(
"`repo` is available only after the `receive_release_prs`"
) from exc
return self._repo
@property
def default_branch(self) -> str:
return self.repo.default_branch
def parse_args():
parser = argparse.ArgumentParser("Create cherry-pick and backport PRs")
parser.add_argument("--token", help="github token, if not set, used from smm")
parser.add_argument(
"--repo", default="ClickHouse/ClickHouse", help="repo owner/name"
)
parser.add_argument("--dry-run", action="store_true", help="do not create anything")
parser.add_argument(
"--debug-helpers",
action="store_true",
help="add debug logging for git_helper and github_helper",
)
return parser.parse_args()
@contextmanager
def clear_repo():
orig_ref = git_runner("git branch --show-current") or git_runner(
"git rev-parse HEAD"
)
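# on a detached HEAD there is no branch name, hence the fallback to the sha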
try:
yield
except (Exception, KeyboardInterrupt):
git_runner(f"git checkout -f {orig_ref}")
raise
else:
git_runner(f"git checkout -f {orig_ref}")
@contextmanager
def stash():
need_stash = bool(git_runner("git diff HEAD"))
if need_stash:
git_runner("git stash push --no-keep-index -m 'running cherry_pick.py'")
try:
with clear_repo():
yield
except (Exception, KeyboardInterrupt):
if need_stash:
git_runner("git stash pop")
raise
else:
if need_stash:
git_runner("git stash pop")
def main():
if not os.path.exists(TEMP_PATH):
os.makedirs(TEMP_PATH)
args = parse_args()
if args.debug_helpers:
logging.getLogger("github_helper").setLevel(logging.DEBUG)
logging.getLogger("git_helper").setLevel(logging.DEBUG)
token = args.token or get_best_robot_token()
gh = GitHub(token, per_page=100)
bp = Backport(gh, args.repo, args.dry_run)
bp.gh.cache_path = str(f"{TEMP_PATH}/gh_cache")
bp.receive_release_prs()
bp.receive_prs_for_backport()
bp.process_backports()
if bp.error:
logging.error("Finished successfully, but errors occured")
sys.exit(1)
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
assert not is_shallow()
with stash():
if os.getenv("ROBOT_CLICKHOUSE_SSH_KEY", ""):
with SSHKey("ROBOT_CLICKHOUSE_SSH_KEY"):
main()

View File

@ -1,2 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

View File

@ -1,190 +0,0 @@
# -*- coding: utf-8 -*-
import argparse
import logging
import os
import re
import sys
sys.path.append(os.path.dirname(__file__))
from cherrypick import CherryPick
from query import Query as RemoteRepo
from local import Repository as LocalRepo
class Backport:
def __init__(self, token, owner, name, team):
self._gh = RemoteRepo(
token, owner=owner, name=name, team=team, max_page_size=60, min_page_size=7
)
self._token = token
self.default_branch_name = self._gh.default_branch
self.ssh_url = self._gh.ssh_url
def getPullRequests(self, from_commit):
return self._gh.get_pull_requests(from_commit)
def getBranchesWithRelease(self):
branches = set()
for pull_request in self._gh.find_pull_requests("release"):
branches.add(pull_request["headRefName"])
return branches
def execute(self, repo, upstream, until_commit, run_cherrypick):
repo = LocalRepo(repo, upstream, self.default_branch_name)
all_branches = repo.get_release_branches() # [(branch_name, base_commit)]
release_branches = self.getBranchesWithRelease()
branches = []
# iterate over all branches to preserve their precedence.
for branch in all_branches:
if branch[0] in release_branches:
branches.append(branch)
if not branches:
logging.info("No release branches found!")
return
logging.info(
"Found release branches: %s", ", ".join([br[0] for br in branches])
)
if not until_commit:
until_commit = branches[0][1]
pull_requests = self.getPullRequests(until_commit)
backport_map = {}
pr_map = {pr["number"]: pr for pr in pull_requests}
RE_MUST_BACKPORT = re.compile(r"^v(\d+\.\d+)-must-backport$")
RE_NO_BACKPORT = re.compile(r"^v(\d+\.\d+)-no-backport$")
RE_BACKPORTED = re.compile(r"^v(\d+\.\d+)-backported$")
# pull-requests are sorted by ancestry from the most recent.
for pr in pull_requests:
while repo.comparator(branches[-1][1]) >= repo.comparator(
pr["mergeCommit"]["oid"]
):
logging.info(
"PR #%s is already inside %s. Dropping this branch for further PRs",
pr["number"],
branches[-1][0],
)
branches.pop()
logging.info("Processing PR #%s", pr["number"])
assert len(branches) != 0
branch_set = {branch[0] for branch in branches}
# First pass. Find all must-backports
for label in pr["labels"]["nodes"]:
if label["name"] == "pr-must-backport":
backport_map[pr["number"]] = branch_set.copy()
continue
matched = RE_MUST_BACKPORT.match(label["name"])
if matched:
if pr["number"] not in backport_map:
backport_map[pr["number"]] = set()
backport_map[pr["number"]].add(matched.group(1))
# Second pass. Find all no-backports
for label in pr["labels"]["nodes"]:
if label["name"] == "pr-no-backport" and pr["number"] in backport_map:
del backport_map[pr["number"]]
break
matched_no_backport = RE_NO_BACKPORT.match(label["name"])
matched_backported = RE_BACKPORTED.match(label["name"])
if (
matched_no_backport
and pr["number"] in backport_map
and matched_no_backport.group(1) in backport_map[pr["number"]]
):
backport_map[pr["number"]].remove(matched_no_backport.group(1))
logging.info(
"\tskipping %s because of forced no-backport",
matched_no_backport.group(1),
)
elif (
matched_backported
and pr["number"] in backport_map
and matched_backported.group(1) in backport_map[pr["number"]]
):
backport_map[pr["number"]].remove(matched_backported.group(1))
logging.info(
"\tskipping %s because it's already backported manually",
matched_backported.group(1),
)
for pr, branches in list(backport_map.items()):
statuses = []
for branch in branches:
branch_status = run_cherrypick(pr_map[pr], branch)
statuses.append(f"{branch}, and the status is: {branch_status}")
logging.info(
"PR #%s needs to be backported to:\n\t%s", pr, "\n\t".join(statuses)
)
# print API costs
logging.info("\nGitHub API total costs for backporting per query:")
for name, value in list(self._gh.api_costs.items()):
logging.info("%s : %s", name, value)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--token", type=str, required=True, help="token for Github access"
)
parser.add_argument(
"--repo",
type=str,
required=True,
help="path to full repository",
metavar="PATH",
)
parser.add_argument(
"--til", type=str, help="check PRs from HEAD til this commit", metavar="COMMIT"
)
parser.add_argument(
"--dry-run",
action="store_true",
help="do not create or merge any PRs",
default=False,
)
parser.add_argument(
"--verbose",
"-v",
action="store_true",
help="more verbose output",
default=False,
)
parser.add_argument(
"--upstream",
"-u",
type=str,
help="remote name of upstream in repository",
default="origin",
)
args = parser.parse_args()
if args.verbose:
logging.basicConfig(
format="%(message)s", stream=sys.stdout, level=logging.DEBUG
)
else:
logging.basicConfig(format="%(message)s", stream=sys.stdout, level=logging.INFO)
cherry_pick = CherryPick(
args.token, "ClickHouse", "ClickHouse", "core", 1, "master"
)
def cherrypick_run(pr_data, branch):
cherry_pick.update_pr_branch(pr_data, branch)
return cherry_pick.execute(args.repo, args.dry_run)
bp = Backport(args.token, "ClickHouse", "ClickHouse", "core")
bp.execute(args.repo, args.upstream, args.til, cherrypick_run)

View File

@ -1,319 +0,0 @@
# -*- coding: utf-8 -*-
"""
Backports changes from PR to release branch.
Requires multiple separate runs as part of the implementation.
First run should do the following:
1. Merge release branch with a first parent of merge-commit of PR (using 'ours' strategy). (branch: backport/{branch}/{pr})
2. Create temporary branch over merge-commit to use it for PR creation. (branch: cherrypick/{merge_commit})
3. Create PR from temporary branch to backport branch (emulating cherry-pick).
The second run checks whether the PR from the previous run is merged, or at least
mergeable; if it's not merged, it tries to merge it.
The third run creates a PR from the backport branch (with the previous PR merged)
into the release branch.
"""
import argparse
from enum import Enum
import logging
import os
import subprocess
import sys
sys.path.append(os.path.dirname(__file__))
from query import Query as RemoteRepo
class CherryPick:
class Status(Enum):
DISCARDED = "discarded"
NOT_INITIATED = "not started"
FIRST_MERGEABLE = "waiting for 1st stage"
FIRST_CONFLICTS = "conflicts on 1st stage"
SECOND_MERGEABLE = "waiting for 2nd stage"
SECOND_CONFLICTS = "conflicts on 2nd stage"
MERGED = "backported"
def _run(self, args):
out = subprocess.check_output(args).rstrip()
logging.debug(out)
return out
def __init__(self, token, owner, name, team, pr_number, target_branch):
self._gh = RemoteRepo(token, owner=owner, name=name, team=team)
self._pr = self._gh.get_pull_request(pr_number)
self.target_branch = target_branch
self.ssh_url = self._gh.ssh_url
# TODO: check if pull-request is merged.
self.update_pr_branch(self._pr, self.target_branch)
def update_pr_branch(self, pr_data, target_branch):
"""The method is here to avoid unnecessary creation of new objects"""
self._pr = pr_data
self.target_branch = target_branch
self.merge_commit_oid = self._pr["mergeCommit"]["oid"]
self.backport_branch = f"backport/{target_branch}/{pr_data['number']}"
self.cherrypick_branch = f"cherrypick/{target_branch}/{self.merge_commit_oid}"
def getCherryPickPullRequest(self):
return self._gh.find_pull_request(
base=self.backport_branch, head=self.cherrypick_branch
)
def createCherryPickPullRequest(self, repo_path):
DESCRIPTION = (
"This pull-request is a first step of an automated backporting.\n"
"It contains changes like after calling a local command `git cherry-pick`.\n"
"If you intend to continue backporting this changes, then resolve all conflicts if any.\n"
"Otherwise, if you do not want to backport them, then just close this pull-request.\n"
"\n"
"The check results does not matter at this step - you can safely ignore them.\n"
"Also this pull-request will be merged automatically as it reaches the mergeable state, but you always can merge it manually.\n"
)
# FIXME: replace with something better than os.system()
git_prefix = [
"git",
"-C",
repo_path,
"-c",
"user.email=robot-clickhouse@yandex-team.ru",
"-c",
"user.name=robot-clickhouse",
]
base_commit_oid = self._pr["mergeCommit"]["parents"]["nodes"][0]["oid"]
# Create separate branch for backporting, and make it look like real cherry-pick.
self._run(git_prefix + ["checkout", "-f", self.target_branch])
self._run(git_prefix + ["checkout", "-B", self.backport_branch])
self._run(git_prefix + ["merge", "-s", "ours", "--no-edit", base_commit_oid])
# Create secondary branch to allow pull request with cherry-picked commit.
self._run(
git_prefix + ["branch", "-f", self.cherrypick_branch, self.merge_commit_oid]
)
self._run(
git_prefix
+ [
"push",
"-f",
"origin",
"{branch}:{branch}".format(branch=self.backport_branch),
]
)
self._run(
git_prefix
+ [
"push",
"-f",
"origin",
"{branch}:{branch}".format(branch=self.cherrypick_branch),
]
)
# Create pull-request like a local cherry-pick
title = self._pr["title"].replace('"', r"\"")
pr = self._gh.create_pull_request(
source=self.cherrypick_branch,
target=self.backport_branch,
title=(
f'Cherry pick #{self._pr["number"]} '
f"to {self.target_branch}: "
f"{title}"
),
description=f'Original pull-request #{self._pr["number"]}\n\n{DESCRIPTION}',
)
# FIXME: use `team` to leave a single eligible assignee.
self._gh.add_assignee(pr, self._pr["author"])
self._gh.add_assignee(pr, self._pr["mergedBy"])
self._gh.set_label(pr, "do not test")
self._gh.set_label(pr, "pr-cherrypick")
return pr
def mergeCherryPickPullRequest(self, cherrypick_pr):
return self._gh.merge_pull_request(cherrypick_pr["id"])
def getBackportPullRequest(self):
return self._gh.find_pull_request(
base=self.target_branch, head=self.backport_branch
)
def createBackportPullRequest(self, cherrypick_pr, repo_path):
DESCRIPTION = (
"This pull-request is a last step of an automated backporting.\n"
"Treat it as a standard pull-request: look at the checks and resolve conflicts.\n"
"Merge it only if you intend to backport changes to the target branch, otherwise just close it.\n"
)
git_prefix = [
"git",
"-C",
repo_path,
"-c",
"user.email=robot-clickhouse@clickhouse.com",
"-c",
"user.name=robot-clickhouse",
]
title = self._pr["title"].replace('"', r"\"")
pr_title = f"Backport #{self._pr['number']} to {self.target_branch}: {title}"
self._run(git_prefix + ["checkout", "-f", self.backport_branch])
self._run(git_prefix + ["pull", "--ff-only", "origin", self.backport_branch])
self._run(
git_prefix
+ [
"reset",
"--soft",
self._run(
git_prefix
+ [
"merge-base",
"origin/" + self.target_branch,
self.backport_branch,
]
),
]
)
self._run(git_prefix + ["commit", "-a", "--allow-empty", "-m", pr_title])
self._run(
git_prefix
+ [
"push",
"-f",
"origin",
"{branch}:{branch}".format(branch=self.backport_branch),
]
)
pr = self._gh.create_pull_request(
source=self.backport_branch,
target=self.target_branch,
title=pr_title,
description=f"Original pull-request #{self._pr['number']}\n"
f"Cherry-pick pull-request #{cherrypick_pr['number']}\n\n{DESCRIPTION}",
)
# FIXME: use `team` to leave a single eligible assignee.
self._gh.add_assignee(pr, self._pr["author"])
self._gh.add_assignee(pr, self._pr["mergedBy"])
self._gh.set_label(pr, "pr-backport")
return pr
def execute(self, repo_path, dry_run=False):
pr1 = self.getCherryPickPullRequest()
if not pr1:
if not dry_run:
pr1 = self.createCherryPickPullRequest(repo_path)
logging.debug(
"Created PR with cherry-pick of %s to %s: %s",
self._pr["number"],
self.target_branch,
pr1["url"],
)
else:
return CherryPick.Status.NOT_INITIATED
else:
logging.debug(
"Found PR with cherry-pick of %s to %s: %s",
self._pr["number"],
self.target_branch,
pr1["url"],
)
if not pr1["merged"] and pr1["mergeable"] == "MERGEABLE" and not pr1["closed"]:
if not dry_run:
pr1 = self.mergeCherryPickPullRequest(pr1)
logging.debug(
"Merged PR with cherry-pick of %s to %s: %s",
self._pr["number"],
self.target_branch,
pr1["url"],
)
if not pr1["merged"]:
logging.debug(
"Waiting for PR with cherry-pick of %s to %s: %s",
self._pr["number"],
self.target_branch,
pr1["url"],
)
if pr1["closed"]:
return CherryPick.Status.DISCARDED
elif pr1["mergeable"] == "CONFLICTING":
return CherryPick.Status.FIRST_CONFLICTS
else:
return CherryPick.Status.FIRST_MERGEABLE
pr2 = self.getBackportPullRequest()
if not pr2:
if not dry_run:
pr2 = self.createBackportPullRequest(pr1, repo_path)
logging.debug(
"Created PR with backport of %s to %s: %s",
self._pr["number"],
self.target_branch,
pr2["url"],
)
else:
return CherryPick.Status.FIRST_MERGEABLE
else:
logging.debug(
"Found PR with backport of %s to %s: %s",
self._pr["number"],
self.target_branch,
pr2["url"],
)
if pr2["merged"]:
return CherryPick.Status.MERGED
elif pr2["closed"]:
return CherryPick.Status.DISCARDED
elif pr2["mergeable"] == "CONFLICTING":
return CherryPick.Status.SECOND_CONFLICTS
else:
return CherryPick.Status.SECOND_MERGEABLE
if __name__ == "__main__":
logging.basicConfig(format="%(message)s", stream=sys.stdout, level=logging.DEBUG)
parser = argparse.ArgumentParser()
parser.add_argument(
"--token", "-t", type=str, required=True, help="token for Github access"
)
parser.add_argument("--pr", type=str, required=True, help="PR# to cherry-pick")
parser.add_argument(
"--branch",
"-b",
type=str,
required=True,
help="target branch name for cherry-pick",
)
parser.add_argument(
"--repo",
"-r",
type=str,
required=True,
help="path to full repository",
metavar="PATH",
)
args = parser.parse_args()
cp = CherryPick(
args.token, "ClickHouse", "ClickHouse", "core", args.pr, args.branch
)
cp.execute(args.repo)

View File

@ -1,109 +0,0 @@
# -*- coding: utf-8 -*-
import functools
import logging
import os
import re
import git
class RepositoryBase:
def __init__(self, repo_path):
self._repo = git.Repo(repo_path, search_parent_directories=(not repo_path))
# comparator of commits
def cmp(x, y):
if str(x) == str(y):
return 0
if self._repo.is_ancestor(x, y):
return -1
else:
return 1
self.comparator = functools.cmp_to_key(cmp)
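# with this key an ancestor commit sorts before any of its descendants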
def iterate(self, begin, end):
rev_range = f"{begin}...{end}"
for commit in self._repo.iter_commits(rev_range, first_parent=True):
yield commit
class Repository(RepositoryBase):
def __init__(self, repo_path, remote_name, default_branch_name):
super().__init__(repo_path)
self._remote = self._repo.remotes[remote_name]
self._remote.fetch()
self._default = self._remote.refs[default_branch_name]
def get_head_commit(self):
return self._repo.commit(self._default)
def get_release_branches(self):
"""
Returns sorted list of tuples:
* remote branch (git.refs.remote.RemoteReference),
* base commit (git.Commit),
* head (git.Commit)).
List is sorted by commits in ascending order.
"""
release_branches = []
RE_RELEASE_BRANCH_REF = re.compile(r"^refs/remotes/.+/\d+\.\d+$")
for branch in [
r for r in self._remote.refs if RE_RELEASE_BRANCH_REF.match(r.path)
]:
base = self._repo.merge_base(self._default, self._repo.commit(branch))
if not base:
logging.info(
"Branch %s is not based on branch %s. Ignoring.",
branch.path,
self._default,
)
elif len(base) > 1:
logging.info(
"Branch %s has more than one base commit. Ignoring.", branch.path
)
else:
release_branches.append((os.path.basename(branch.name), base[0]))
return sorted(release_branches, key=lambda x: self.comparator(x[1]))
class BareRepository(RepositoryBase):
def __init__(self, repo_path, default_branch_name):
super().__init__(repo_path)
self._default = self._repo.branches[default_branch_name]
def get_release_branches(self):
"""
Returns sorted list of tuples:
* branch (git.refs.head?),
* base commit (git.Commit),
* head (git.Commit)).
List is sorted by commits in ascending order.
"""
release_branches = []
RE_RELEASE_BRANCH_REF = re.compile(r"^refs/heads/\d+\.\d+$")
for branch in [
r for r in self._repo.branches if RE_RELEASE_BRANCH_REF.match(r.path)
]:
base = self._repo.merge_base(self._default, self._repo.commit(branch))
if not base:
logging.info(
"Branch %s is not based on branch %s. Ignoring.",
branch.path,
self._default,
)
elif len(base) > 1:
logging.info(
"Branch %s has more than one base commit. Ignoring.", branch.path
)
else:
release_branches.append((os.path.basename(branch.name), base[0]))
return sorted(release_branches, key=lambda x: self.comparator(x[1]))

View File

@ -1,56 +0,0 @@
# -*- coding: utf-8 -*-
class Description:
"""Parsed description representation"""
MAP_CATEGORY_TO_LABEL = {
"New Feature": "pr-feature",
"Bug Fix": "pr-bugfix",
"Improvement": "pr-improvement",
"Performance Improvement": "pr-performance",
# 'Backward Incompatible Change': doesn't match anything
"Build/Testing/Packaging Improvement": "pr-build",
"Non-significant (changelog entry is not needed)": "pr-non-significant",
"Non-significant (changelog entry is not required)": "pr-non-significant",
"Non-significant": "pr-non-significant",
"Documentation (changelog entry is not required)": "pr-documentation",
# 'Other': doesn't match anything
}
def __init__(self, pull_request):
self.label_name = str()
self._parse(pull_request["bodyText"])
def _parse(self, text):
lines = text.splitlines()
next_category = False
category = str()
for line in lines:
stripped = line.strip()
if not stripped:
continue
if next_category:
category = stripped
next_category = False
category_headers = (
"Category (leave one):",
"Changelog category (leave one):",
"Changelog category:",
"Category:",
)
if stripped in category_headers:
next_category = True
if category in Description.MAP_CATEGORY_TO_LABEL:
self.label_name = Description.MAP_CATEGORY_TO_LABEL[category]
else:
if not category:
print("Cannot find category in pr description")
else:
print(("Unknown category: " + category))

View File

@ -1,532 +0,0 @@
# -*- coding: utf-8 -*-
import json
import inspect
import logging
import time
from urllib3.util.retry import Retry # type: ignore
import requests # type: ignore
from requests.adapters import HTTPAdapter # type: ignore
class Query:
"""
Implements queries to the Github API using GraphQL
"""
_PULL_REQUEST = """
author {{
... on User {{
id
login
}}
}}
baseRepository {{
nameWithOwner
}}
mergeCommit {{
oid
parents(first: {min_page_size}) {{
totalCount
nodes {{
oid
}}
}}
}}
mergedBy {{
... on User {{
id
login
}}
}}
baseRefName
closed
headRefName
id
mergeable
merged
number
title
url
"""
def __init__(self, token, owner, name, team, max_page_size=100, min_page_size=10):
self._PULL_REQUEST = Query._PULL_REQUEST.format(min_page_size=min_page_size)
self._token = token
self._owner = owner
self._name = name
self._team = team
self._session = None
self._max_page_size = max_page_size
self._min_page_size = min_page_size
self.api_costs = {}
repo = self.get_repository()
self._id = repo["id"]
self.ssh_url = repo["sshUrl"]
self.default_branch = repo["defaultBranchRef"]["name"]
self.members = set(self.get_members())
def get_repository(self):
_QUERY = """
repository(owner: "{owner}" name: "{name}") {{
defaultBranchRef {{
name
}}
id
sshUrl
}}
"""
query = _QUERY.format(owner=self._owner, name=self._name)
return self._run(query)["repository"]
def get_members(self):
"""Get all team members for organization
Returns:
members: a map of members' logins to ids
"""
_QUERY = """
organization(login: "{organization}") {{
team(slug: "{team}") {{
members(first: {max_page_size} {next}) {{
pageInfo {{
hasNextPage
endCursor
}}
nodes {{
id
login
}}
}}
}}
}}
"""
members = {}
not_end = True
query = _QUERY.format(
organization=self._owner,
team=self._team,
max_page_size=self._max_page_size,
next="",
)
while not_end:
result = self._run(query)["organization"]["team"]
if result is None:
break
result = result["members"]
not_end = result["pageInfo"]["hasNextPage"]
query = _QUERY.format(
organization=self._owner,
team=self._team,
max_page_size=self._max_page_size,
next=f'after: "{result["pageInfo"]["endCursor"]}"',
)
# Update members with new nodes compatible with py3.8-py3.10
members = {
**members,
**{node["login"]: node["id"] for node in result["nodes"]},
}
return members
def get_pull_request(self, number):
_QUERY = """
repository(owner: "{owner}" name: "{name}") {{
pullRequest(number: {number}) {{
{pull_request_data}
}}
}}
"""
query = _QUERY.format(
owner=self._owner,
name=self._name,
number=number,
pull_request_data=self._PULL_REQUEST,
min_page_size=self._min_page_size,
)
return self._run(query)["repository"]["pullRequest"]
def find_pull_request(self, base, head):
_QUERY = """
repository(owner: "{owner}" name: "{name}") {{
pullRequests(
first: {min_page_size} baseRefName: "{base}" headRefName: "{head}"
) {{
nodes {{
{pull_request_data}
}}
totalCount
}}
}}
"""
query = _QUERY.format(
owner=self._owner,
name=self._name,
base=base,
head=head,
pull_request_data=self._PULL_REQUEST,
min_page_size=self._min_page_size,
)
result = self._run(query)["repository"]["pullRequests"]
if result["totalCount"] > 0:
return result["nodes"][0]
else:
return {}
def find_pull_requests(self, label_name):
"""
Get all pull-requests filtered by label name
"""
_QUERY = """
repository(owner: "{owner}" name: "{name}") {{
pullRequests(first: {min_page_size} labels: "{label_name}" states: OPEN) {{
nodes {{
{pull_request_data}
}}
}}
}}
"""
query = _QUERY.format(
owner=self._owner,
name=self._name,
label_name=label_name,
pull_request_data=self._PULL_REQUEST,
min_page_size=self._min_page_size,
)
return self._run(query)["repository"]["pullRequests"]["nodes"]
def get_pull_requests(self, before_commit):
"""
Get all merged pull-requests from the HEAD of the default branch down to the given commit (exclusive)
"""
_QUERY = """
repository(owner: "{owner}" name: "{name}") {{
defaultBranchRef {{
target {{
... on Commit {{
history(first: {max_page_size} {next}) {{
pageInfo {{
hasNextPage
endCursor
}}
nodes {{
oid
associatedPullRequests(first: {min_page_size}) {{
totalCount
nodes {{
... on PullRequest {{
{pull_request_data}
labels(first: {min_page_size}) {{
totalCount
pageInfo {{
hasNextPage
endCursor
}}
nodes {{
name
color
}}
}}
}}
}}
}}
}}
}}
}}
}}
}}
}}
"""
pull_requests = []
not_end = True
query = _QUERY.format(
owner=self._owner,
name=self._name,
max_page_size=self._max_page_size,
min_page_size=self._min_page_size,
pull_request_data=self._PULL_REQUEST,
next="",
)
while not_end:
result = self._run(query)["repository"]["defaultBranchRef"]["target"][
"history"
]
not_end = result["pageInfo"]["hasNextPage"]
query = _QUERY.format(
owner=self._owner,
name=self._name,
max_page_size=self._max_page_size,
min_page_size=self._min_page_size,
pull_request_data=self._PULL_REQUEST,
next=f'after: "{result["pageInfo"]["endCursor"]}"',
)
for commit in result["nodes"]:
# FIXME: maybe include `before_commit`?
if str(commit["oid"]) == str(before_commit):
not_end = False
break
# TODO: fetch all pull-requests that were merged in a single commit.
assert (
commit["associatedPullRequests"]["totalCount"]
<= self._min_page_size
)
for pull_request in commit["associatedPullRequests"]["nodes"]:
if (
pull_request["baseRepository"]["nameWithOwner"]
== f"{self._owner}/{self._name}"
and pull_request["baseRefName"] == self.default_branch
and pull_request["mergeCommit"]["oid"] == commit["oid"]
):
pull_requests.append(pull_request)
return pull_requests
def create_pull_request(
self, source, target, title, description="", draft=False, can_modify=True
):
_QUERY = """
createPullRequest(input: {{
baseRefName: "{target}",
headRefName: "{source}",
repositoryId: "{id}",
title: "{title}",
body: "{body}",
draft: {draft},
maintainerCanModify: {modify}
}}) {{
pullRequest {{
{pull_request_data}
}}
}}
"""
query = _QUERY.format(
target=target,
source=source,
id=self._id,
title=title,
body=description,
draft="true" if draft else "false",
modify="true" if can_modify else "false",
pull_request_data=self._PULL_REQUEST,
)
return self._run(query, is_mutation=True)["createPullRequest"]["pullRequest"]
def merge_pull_request(self, pr_id):
_QUERY = """
mergePullRequest(input: {{
pullRequestId: "{pr_id}"
}}) {{
pullRequest {{
{pull_request_data}
}}
}}
"""
query = _QUERY.format(pr_id=pr_id, pull_request_data=self._PULL_REQUEST)
return self._run(query, is_mutation=True)["mergePullRequest"]["pullRequest"]
# FIXME: figure out how to add more assignees at once
def add_assignee(self, pr, assignee):
_QUERY = """
addAssigneesToAssignable(input: {{
assignableId: "{id1}",
assigneeIds: "{id2}"
}}) {{
clientMutationId
}}
"""
query = _QUERY.format(id1=pr["id"], id2=assignee["id"])
self._run(query, is_mutation=True)
def set_label(self, pull_request, label_name):
"""
Set label by name to the pull request
Args:
pull_request: JSON object returned by `get_pull_requests()`
label_name (string): label name
"""
_GET_LABEL = """
repository(owner: "{owner}" name: "{name}") {{
labels(first: {max_page_size} {next} query: "{label_name}") {{
pageInfo {{
hasNextPage
endCursor
}}
nodes {{
id
name
color
}}
}}
}}
"""
_SET_LABEL = """
addLabelsToLabelable(input: {{
labelableId: "{pr_id}",
labelIds: "{label_id}"
}}) {{
clientMutationId
}}
"""
labels = []
not_end = True
query = _GET_LABEL.format(
owner=self._owner,
name=self._name,
label_name=label_name,
max_page_size=self._max_page_size,
next="",
)
while not_end:
result = self._run(query)["repository"]["labels"]
not_end = result["pageInfo"]["hasNextPage"]
query = _GET_LABEL.format(
owner=self._owner,
name=self._name,
label_name=label_name,
max_page_size=self._max_page_size,
next=f'after: "{result["pageInfo"]["endCursor"]}"',
)
labels += list(result["nodes"])
if not labels:
return
query = _SET_LABEL.format(pr_id=pull_request["id"], label_id=labels[0]["id"])
self._run(query, is_mutation=True)
@property
def session(self):
if self._session is not None:
return self._session
retries = 5
self._session = requests.Session()
retry = Retry(
total=retries,
read=retries,
connect=retries,
backoff_factor=1,
status_forcelist=(403, 500, 502, 504),
)
adapter = HTTPAdapter(max_retries=retry)
self._session.mount("http://", adapter)
self._session.mount("https://", adapter)
return self._session
def _run(self, query, is_mutation=False):
# Get caller and parameters from the stack to track the progress
frame = inspect.getouterframes(inspect.currentframe(), 2)[1]
caller = frame[3]
f_parameters = inspect.signature(getattr(self, caller)).parameters
parameters = ", ".join(str(frame[0].f_locals[p]) for p in f_parameters)
mutation = ""
if is_mutation:
mutation = ", is mutation"
print(f"---GraphQL request for {caller}({parameters}){mutation}---")
headers = {"Authorization": f"bearer {self._token}"}
if is_mutation:
query = f"""
mutation {{
{query}
}}
"""
else:
query = f"""
query {{
{query}
rateLimit {{
cost
remaining
}}
}}
"""
def request_with_retry(retry=0):
max_retries = 5
# From time to time we face some specific errors for which it's worth
# retrying instead of failing completely
# We should sleep progressively
progressive_sleep = 5 * sum(i + 1 for i in range(retry))
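# retries 0, 1, 2 and 3 sleep 0, 5, 15 and 30 seconds respectively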
if progressive_sleep:
logging.warning(
"Retry GraphQL request %s time, sleep %s seconds",
retry,
progressive_sleep,
)
time.sleep(progressive_sleep)
response = self.session.post(
"https://api.github.com/graphql", json={"query": query}, headers=headers
)
result = response.json()
if response.status_code == 200:
if "errors" in result:
raise Exception(
f"Errors occurred: {result['errors']}\nOriginal query: {query}"
)
if not is_mutation:
if caller not in self.api_costs:
self.api_costs[caller] = 0
self.api_costs[caller] += result["data"]["rateLimit"]["cost"]
return result["data"]
elif (
response.status_code == 403
and "secondary rate limit" in result["message"]
):
if retry <= max_retries:
logging.warning("Secondary rate limit reached")
return request_with_retry(retry + 1)
elif response.status_code == 502 and "errors" in result:
too_many_data = any(
True
for err in result["errors"]
if "message" in err
and "This may be the result of a timeout" in err["message"]
)
if too_many_data:
logging.warning(
"Too many data is requested, decreasing page size %s by 10%%",
self._max_page_size,
)
self._max_page_size = int(self._max_page_size * 0.9)
return request_with_retry(retry)
data = json.dumps(result, indent=4)
raise Exception(f"Query failed with code {response.status_code}:\n{data}")
return request_with_retry()

View File

@ -1,3 +0,0 @@
# Some scripts for backports implementation
TODO: Remove copy from utils/github

View File

@ -1,10 +1,13 @@
#!/usr/bin/env python
import argparse
import logging
import os.path as p
import re
import subprocess
from typing import List, Optional
logger = logging.getLogger(__name__)
# ^ and $ match a subline in `multiple\nlines`
# \A and \Z match only start and end of the whole string
RELEASE_BRANCH_REGEXP = r"\A\d+[.]\d+\Z"
@ -55,6 +58,7 @@ class Runner:
def run(self, cmd: str, cwd: Optional[str] = None, **kwargs) -> str:
if cwd is None:
cwd = self.cwd
logger.debug("Running command: %s", cmd)
return subprocess.check_output(
cmd, shell=True, cwd=cwd, encoding="utf-8", **kwargs
).strip()
@ -70,6 +74,9 @@ class Runner:
return
self._cwd = value
def __call__(self, *args, **kwargs):
return self.run(*args, **kwargs)
git_runner = Runner()
# Set cwd to abs path of git root
@ -109,8 +116,8 @@ class Git:
def update(self):
"""Is used to refresh all attributes after updates, e.g. checkout or commit"""
self.branch = self.run("git branch --show-current")
self.sha = self.run("git rev-parse HEAD")
self.branch = self.run("git branch --show-current") or self.sha
self.sha_short = self.sha[:11]
# The following command shows the most recent tag in a graph
# Format should match TAG_REGEXP

tests/ci/github_helper.py (new file, 162 lines)
View File

@ -0,0 +1,162 @@
#!/usr/bin/env python
"""Helper for GitHub API requests"""
import logging
from datetime import date, datetime, timedelta
from pathlib import Path
from os import path as p
from time import sleep
from typing import List, Optional
import github
from github.GithubException import RateLimitExceededException
from github.Issue import Issue
from github.PullRequest import PullRequest
from github.Repository import Repository
CACHE_PATH = p.join(p.dirname(p.realpath(__file__)), "gh_cache")
logger = logging.getLogger(__name__)
PullRequests = List[PullRequest]
Issues = List[Issue]
class GitHub(github.Github):
def __init__(self, *args, **kwargs):
# Define meta attribute
self._cache_path = Path(CACHE_PATH)
# And set Path
super().__init__(*args, **kwargs)
# pylint: disable=signature-differs
def search_issues(self, *args, **kwargs) -> Issues: # type: ignore
"""Wrapper around search method with throttling and splitting by date.
We split only by the first"""
splittable = False
for arg, value in kwargs.items():
if arg in ["closed", "created", "merged", "updated"]:
if (
hasattr(value, "__iter__")
and not isinstance(value, str)
and not splittable
):
assert [True for v in value if isinstance(v, (date, datetime))]
assert len(value) == 2
preserved_arg = arg
preserved_value = value
middle_value = value[0] + (value[1] - value[0]) / 2
splittable = middle_value not in value
kwargs[arg] = f"{value[0].isoformat()}..{value[1].isoformat()}"
continue
assert isinstance(value, (date, datetime, str))
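# e.g. merged=[date(2022, 1, 1), date(2022, 1, 3)] becomes the search
# qualifier merged:2022-01-01..2022-01-03; if such a query hits the
# 1000-result cap, the range is halved and each half is searched again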
inter_result = [] # type: Issues
for i in range(3):
try:
logger.debug("Search issues, args=%s, kwards=%s", args, kwargs)
result = super().search_issues(*args, **kwargs)
if result.totalCount == 1000 and splittable:
# The hard limit is 1000. If the range is splittable, we make
# two sub-requests with smaller time frames
logger.debug(
"The search result contains exactly 1000 results, "
"splitting %s=%s by middle point %s",
preserved_arg,
kwargs[preserved_arg],
middle_value,
)
kwargs[preserved_arg] = [preserved_value[0], middle_value]
inter_result.extend(self.search_issues(*args, **kwargs))
if isinstance(middle_value, date):
# When middle_value is a date, 2022-01-01..2022-01-03 would be
# split into 2022-01-01..2022-01-02 and 2022-01-02..2022-01-03,
# and the results for 2022-01-02 would be fetched twice. So we
# split into 2022-01-01..2022-01-02 and 2022-01-03..2022-01-03
# instead; 2022-01-01..2022-01-02 isn't split further, see
# `splittable` above
middle_value += timedelta(days=1)
kwargs[preserved_arg] = [middle_value, preserved_value[1]]
inter_result.extend(self.search_issues(*args, **kwargs))
return inter_result
inter_result.extend(result)
return inter_result
except RateLimitExceededException as e:
if i == 2:
exception = e
self.sleep_on_rate_limit()
raise exception
# pylint: enable=signature-differs
def get_pulls_from_search(self, *args, **kwargs) -> PullRequests:
"""The search api returns actually issues, so we need to fetch PullRequests"""
issues = self.search_issues(*args, **kwargs)
repos = {}
prs = [] # type: PullRequests
for issue in issues:
# See https://github.com/PyGithub/PyGithub/issues/2202,
# obj._rawData doesn't spend additional API requests
# pylint: disable=protected-access
repo_url = issue._rawData["repository_url"] # type: ignore
if repo_url not in repos:
repos[repo_url] = issue.repository
prs.append(
self.get_pull_cached(repos[repo_url], issue.number, issue.updated_at)
)
return prs
def sleep_on_rate_limit(self):
for limit, data in self.get_rate_limit().raw_data.items():
if data["remaining"] == 0:
sleep_time = data["reset"] - int(datetime.now().timestamp()) + 1
if sleep_time > 0:
logger.warning(
"Faced rate limit for '%s' requests type, sleeping %s",
limit,
sleep_time,
)
sleep(sleep_time)
return
def get_pull_cached(
self, repo: Repository, number: int, updated_at: Optional[datetime] = None
) -> PullRequest:
pr_cache_file = self.cache_path / f"{number}.pickle"
if updated_at is None:
updated_at = datetime.now() - timedelta(hours=-1)
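# i.e. one hour in the future: with an unknown update time the freshness
# check below always fails and the PR is re-fetched from the API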
def _get_pr(path: Path) -> PullRequest:
with open(path, "rb") as prfd:
return self.load(prfd) # type: ignore
if pr_cache_file.is_file():
cached_pr = _get_pr(pr_cache_file)
if updated_at <= cached_pr.updated_at:
logger.debug("Getting PR #%s from cache", number)
return cached_pr
for i in range(3):
try:
pr = repo.get_pull(number)
break
except RateLimitExceededException:
if i == 2:
raise
self.sleep_on_rate_limit()
logger.debug("Getting PR #%s from API", number)
with open(pr_cache_file, "wb") as prfd:
self.dump(pr, prfd) # type: ignore
return pr
@property
def cache_path(self):
return self._cache_path
@cache_path.setter
def cache_path(self, value: str):
self._cache_path = Path(value)
if self._cache_path.exists():
assert self._cache_path.is_dir()
else:
self._cache_path.mkdir(parents=True)
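# A minimal usage sketch (the query arguments mirror GitHub search
# qualifiers; the token and paths are placeholders):
#   gh = GitHub("<token>", per_page=100)
#   gh.cache_path = "/tmp/gh_cache"
#   release_prs = gh.get_pulls_from_search(
#       query="type:pr repo:ClickHouse/ClickHouse is:open", label="release"
#   )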