Improvement of cherry-pick/backport script

- cherry_pick.py now can ba launched locally, with dry-run
- get rid of fallback import paths
- do not create a huge pile of objects for every sneezing
- the same for hidden imports in deep local functions
- improve logging
- fix imports for cherry_pick_utils entities
- Significantly reduced requests to GraphQL API
This commit is contained in:
Mikhail f. Shiryaev 2022-06-16 13:20:03 +02:00
parent 7fce1d54fe
commit 7ed305f9b1
No known key found for this signature in database
GPG Key ID: 4B02ED204C7D93F4
6 changed files with 171 additions and 129 deletions

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python3
import sys
import argparse
import logging
import os
import subprocess
@ -12,37 +12,61 @@ from cherry_pick_utils.backport import Backport
from cherry_pick_utils.cherrypick import CherryPick
def parse_args():
parser = argparse.ArgumentParser("Create cherry-pick and backport PRs")
parser.add_argument("--token", help="github token, if not set, used from smm")
parser.add_argument("--dry-run", action="store_true", help="do not create anything")
return parser.parse_args()
def main():
args = parse_args()
token = args.token or get_parameter_from_ssm("github_robot_token_1")
bp = Backport(
token,
os.environ.get("REPO_OWNER"),
os.environ.get("REPO_NAME"),
os.environ.get("REPO_TEAM"),
)
cherry_pick = CherryPick(
token,
os.environ.get("REPO_OWNER"),
os.environ.get("REPO_NAME"),
os.environ.get("REPO_TEAM"),
1,
"master",
)
# Use the same _gh in both objects to have a proper cost
# pylint: disable=protected-access
for key in bp._gh.api_costs:
if key in cherry_pick._gh.api_costs:
bp._gh.api_costs[key] += cherry_pick._gh.api_costs[key]
for key in cherry_pick._gh.api_costs:
if key not in bp._gh.api_costs:
bp._gh.api_costs[key] = cherry_pick._gh.api_costs[key]
cherry_pick._gh = bp._gh
# pylint: enable=protected-access
def cherrypick_run(pr_data, branch):
cherry_pick.update_pr_branch(pr_data, branch)
return cherry_pick.execute(GITHUB_WORKSPACE, args.dry_run)
try:
bp.execute(GITHUB_WORKSPACE, "origin", None, cherrypick_run)
except subprocess.CalledProcessError as e:
logging.error(e.output)
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
repo_path = GITHUB_WORKSPACE
temp_path = TEMP_PATH
if not os.path.exists(temp_path):
os.makedirs(temp_path)
if not os.path.exists(TEMP_PATH):
os.makedirs(TEMP_PATH)
sys.path.append(os.path.join(repo_path, "utils/github"))
with SSHKey("ROBOT_CLICKHOUSE_SSH_KEY"):
token = get_parameter_from_ssm("github_robot_token_1")
bp = Backport(
token,
os.environ.get("REPO_OWNER"),
os.environ.get("REPO_NAME"),
os.environ.get("REPO_TEAM"),
)
def cherrypick_run(token, pr, branch):
return CherryPick(
token,
os.environ.get("REPO_OWNER"),
os.environ.get("REPO_NAME"),
os.environ.get("REPO_TEAM"),
pr,
branch,
).execute(repo_path, False)
try:
bp.execute(repo_path, "origin", None, cherrypick_run)
except subprocess.CalledProcessError as e:
logging.error(e.output)
if os.getenv("ROBOT_CLICKHOUSE_SSH_KEY", ""):
with SSHKey("ROBOT_CLICKHOUSE_SSH_KEY"):
main()
else:
main()

View File

@ -1 +1,2 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

View File

@ -1,19 +1,17 @@
# -*- coding: utf-8 -*-
try:
from clickhouse.utils.github.cherrypick import CherryPick
from clickhouse.utils.github.query import Query as RemoteRepo
from clickhouse.utils.github.local import Repository as LocalRepo
except:
from .cherrypick import CherryPick
from .query import Query as RemoteRepo
from .local import Repository as LocalRepo
import argparse
import logging
import os
import re
import sys
sys.path.append(os.path.dirname(__file__))
from cherrypick import CherryPick
from query import Query as RemoteRepo
from local import Repository as LocalRepo
class Backport:
def __init__(self, token, owner, name, team):
@ -49,14 +47,16 @@ class Backport:
logging.info("No release branches found!")
return
for branch in branches:
logging.info("Found release branch: %s", branch[0])
logging.info(
"Found release branches: %s", ", ".join([br[0] for br in branches])
)
if not until_commit:
until_commit = branches[0][1]
pull_requests = self.getPullRequests(until_commit)
backport_map = {}
pr_map = {pr["number"]: pr for pr in pull_requests}
RE_MUST_BACKPORT = re.compile(r"^v(\d+\.\d+)-must-backport$")
RE_NO_BACKPORT = re.compile(r"^v(\d+\.\d+)-no-backport$")
@ -68,17 +68,17 @@ class Backport:
pr["mergeCommit"]["oid"]
):
logging.info(
"PR #{} is already inside {}. Dropping this branch for further PRs".format(
pr["number"], branches[-1][0]
)
"PR #%s is already inside %s. Dropping this branch for further PRs",
pr["number"],
branches[-1][0],
)
branches.pop()
logging.info("Processing PR #{}".format(pr["number"]))
logging.info("Processing PR #%s", pr["number"])
assert len(branches)
assert len(branches) != 0
branch_set = set([branch[0] for branch in branches])
branch_set = {branch[0] for branch in branches}
# First pass. Find all must-backports
for label in pr["labels"]["nodes"]:
@ -120,16 +120,16 @@ class Backport:
)
for pr, branches in list(backport_map.items()):
logging.info("PR #%s needs to be backported to:", pr)
statuses = []
for branch in branches:
logging.info(
"\t%s, and the status is: %s",
branch,
run_cherrypick(self._token, pr, branch),
)
branch_status = run_cherrypick(pr_map[pr], branch)
statuses.append(f"{branch}, and the status is: {branch_status}")
logging.info(
"PR #%s needs to be backported to:\n\t%s", pr, "\n\t".join(statuses)
)
# print API costs
logging.info("\nGitHub API total costs per query:")
logging.info("\nGitHub API total costs for backporting per query:")
for name, value in list(self._gh.api_costs.items()):
logging.info("%s : %s", name, value)
@ -178,8 +178,13 @@ if __name__ == "__main__":
else:
logging.basicConfig(format="%(message)s", stream=sys.stdout, level=logging.INFO)
cherrypick_run = lambda token, pr, branch: CherryPick(
token, "ClickHouse", "ClickHouse", "core", pr, branch
).execute(args.repo, args.dry_run)
cherry_pick = CherryPick(
args.token, "ClickHouse", "ClickHouse", "core", 1, "master"
)
def cherrypick_run(pr_data, branch):
cherry_pick.update_pr_branch(pr_data, branch)
return cherry_pick.execute(args.repo, args.dry_run)
bp = Backport(args.token, "ClickHouse", "ClickHouse", "core")
bp.execute(args.repo, args.upstream, args.til, cherrypick_run)

View File

@ -14,10 +14,6 @@ Second run checks PR from previous run to be merged or at least being mergeable.
Third run creates PR from backport branch (with merged previous PR) to release branch.
"""
try:
from clickhouse.utils.github.query import Query as RemoteRepo
except:
from .query import Query as RemoteRepo
import argparse
from enum import Enum
@ -26,6 +22,10 @@ import os
import subprocess
import sys
sys.path.append(os.path.dirname(__file__))
from query import Query as RemoteRepo
class CherryPick:
class Status(Enum):
@ -45,20 +45,21 @@ class CherryPick:
def __init__(self, token, owner, name, team, pr_number, target_branch):
self._gh = RemoteRepo(token, owner=owner, name=name, team=team)
self._pr = self._gh.get_pull_request(pr_number)
self.target_branch = target_branch
self.ssh_url = self._gh.ssh_url
# TODO: check if pull-request is merged.
self.update_pr_branch(self._pr, self.target_branch)
def update_pr_branch(self, pr_data, target_branch):
"""The method is here to avoid unnecessary creation of new objects"""
self._pr = pr_data
self.target_branch = target_branch
self.merge_commit_oid = self._pr["mergeCommit"]["oid"]
self.target_branch = target_branch
self.backport_branch = "backport/{branch}/{pr}".format(
branch=target_branch, pr=pr_number
)
self.cherrypick_branch = "cherrypick/{branch}/{oid}".format(
branch=target_branch, oid=self.merge_commit_oid
)
self.backport_branch = f"backport/{target_branch}/{pr_data['number']}"
self.cherrypick_branch = f"cherrypick/{target_branch}/{self.merge_commit_oid}"
def getCherryPickPullRequest(self):
return self._gh.find_pull_request(

View File

@ -5,10 +5,11 @@ import logging
import os
import re
import git
class RepositoryBase:
def __init__(self, repo_path):
import git
self._repo = git.Repo(repo_path, search_parent_directories=(not repo_path))
@ -23,22 +24,22 @@ class RepositoryBase:
self.comparator = functools.cmp_to_key(cmp)
def get_head_commit(self):
return self._repo.commit(self._default)
def iterate(self, begin, end):
rev_range = "{}...{}".format(begin, end)
rev_range = f"{begin}...{end}"
for commit in self._repo.iter_commits(rev_range, first_parent=True):
yield commit
class Repository(RepositoryBase):
def __init__(self, repo_path, remote_name, default_branch_name):
super(Repository, self).__init__(repo_path)
super().__init__(repo_path)
self._remote = self._repo.remotes[remote_name]
self._remote.fetch()
self._default = self._remote.refs[default_branch_name]
def get_head_commit(self):
return self._repo.commit(self._default)
def get_release_branches(self):
"""
Returns sorted list of tuples:
@ -73,7 +74,7 @@ class Repository(RepositoryBase):
class BareRepository(RepositoryBase):
def __init__(self, repo_path, default_branch_name):
super(BareRepository, self).__init__(repo_path)
super().__init__(repo_path)
self._default = self._repo.branches[default_branch_name]
def get_release_branches(self):

View File

@ -1,7 +1,13 @@
# -*- coding: utf-8 -*-
import requests
import json
import inspect
import logging
import time
from urllib3.util.retry import Retry # type: ignore
import requests # type: ignore
from requests.adapters import HTTPAdapter # type: ignore
class Query:
@ -56,6 +62,7 @@ class Query:
self._owner = owner
self._name = name
self._team = team
self._session = None
self._max_page_size = max_page_size
self._min_page_size = min_page_size
@ -129,7 +136,11 @@ class Query:
next='after: "{}"'.format(result["pageInfo"]["endCursor"]),
)
members += dict([(node["login"], node["id"]) for node in result["nodes"]])
# Update members with new nodes compatible with py3.8-py3.10
members = {
**members,
**{node["login"]: node["id"] for node in result["nodes"]},
}
return members
@ -415,32 +426,37 @@ class Query:
query = _SET_LABEL.format(pr_id=pull_request["id"], label_id=labels[0]["id"])
self._run(query, is_mutation=True)
@property
def session(self):
if self._session is not None:
return self._session
retries = 5
self._session = requests.Session()
retry = Retry(
total=retries,
read=retries,
connect=retries,
backoff_factor=1,
status_forcelist=(403, 500, 502, 504),
)
adapter = HTTPAdapter(max_retries=retry)
self._session.mount("http://", adapter)
self._session.mount("https://", adapter)
return self._session
def _run(self, query, is_mutation=False):
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
# Get caller and parameters from the stack to track the progress
frame = inspect.getouterframes(inspect.currentframe(), 2)[1]
caller = frame[3]
f_parameters = inspect.signature(getattr(self, caller)).parameters
parameters = ", ".join(str(frame[0].f_locals[p]) for p in f_parameters)
mutation = ""
if is_mutation:
mutation = ", is mutation"
print(f"---GraphQL request for {caller}({parameters}){mutation}---")
# sleep a little, because we querying github too often
print("Request, is mutation", is_mutation)
time.sleep(0.5)
def requests_retry_session(
retries=5,
backoff_factor=0.5,
status_forcelist=(403, 500, 502, 504),
session=None,
):
session = session or requests.Session()
retry = Retry(
total=retries,
read=retries,
connect=retries,
backoff_factor=backoff_factor,
status_forcelist=status_forcelist,
)
adapter = HTTPAdapter(max_retries=retry)
session.mount("http://", adapter)
session.mount("https://", adapter)
return session
time.sleep(0.1)
headers = {"Authorization": "bearer {}".format(self._token)}
if is_mutation:
@ -464,34 +480,28 @@ class Query:
query=query
)
while True:
request = requests_retry_session().post(
"https://api.github.com/graphql", json={"query": query}, headers=headers
)
if request.status_code == 200:
result = request.json()
if "errors" in result:
raise Exception(
"Errors occurred: {}\nOriginal query: {}".format(
result["errors"], query
)
)
if not is_mutation:
import inspect
caller = inspect.getouterframes(inspect.currentframe(), 2)[1][3]
if caller not in list(self.api_costs.keys()):
self.api_costs[caller] = 0
self.api_costs[caller] += result["data"]["rateLimit"]["cost"]
return result["data"]
else:
import json
response = self.session.post(
"https://api.github.com/graphql", json={"query": query}, headers=headers
)
if response.status_code == 200:
result = response.json()
if "errors" in result:
raise Exception(
"Query failed with code {code}:\n{json}".format(
code=request.status_code,
json=json.dumps(request.json(), indent=4),
"Errors occurred: {}\nOriginal query: {}".format(
result["errors"], query
)
)
if not is_mutation:
if caller not in list(self.api_costs.keys()):
self.api_costs[caller] = 0
self.api_costs[caller] += result["data"]["rateLimit"]["cost"]
return result["data"]
else:
raise Exception(
"Query failed with code {code}:\n{json}".format(
code=response.status_code,
json=json.dumps(response.json(), indent=4),
)
)