diff --git a/tests/ci/cherry_pick.py b/tests/ci/cherry_pick.py index 91a03e55d87..f81d84365ac 100644 --- a/tests/ci/cherry_pick.py +++ b/tests/ci/cherry_pick.py @@ -91,7 +91,7 @@ close it. name: str, pr: PullRequest, repo: Repository, - backport_created_label: str = Labels.PR_BACKPORTS_CREATED, + backport_created_label: str, ): self.name = name self.pr = pr @@ -119,8 +119,6 @@ close it. """the method processes all prs and pops the ReleaseBranch related prs""" to_pop = [] # type: List[int] for i, pr in enumerate(prs): - if self.name not in pr.head.ref: - continue if pr.head.ref.startswith(f"cherrypick/{self.name}"): self.cherrypick_pr = pr to_pop.append(i) @@ -128,10 +126,7 @@ close it. self.backport_pr = pr to_pop.append(i) else: - logging.error( - "head ref of PR #%s isn't starting with known suffix", - pr.number, - ) + assert False, "BUG! Invalid branch suffix" for i in reversed(to_pop): # Going from the tail to keep the order and pop greater index first prs.pop(i) @@ -225,7 +220,7 @@ close it. # There are changes to apply, so continue git_runner(f"{GIT_PREFIX} reset --merge") - # Push, create the cherrypick PR, lable and assign it + # Push, create the cherry-pick PR, label and assign it for branch in [self.cherrypick_branch, self.backport_branch]: git_runner(f"{GIT_PREFIX} push -f {self.REMOTE} {branch}:{branch}") @@ -351,16 +346,22 @@ class Backport: repo: str, fetch_from: Optional[str], dry_run: bool, - must_create_backport_labels: List[str], - backport_created_label: str, ): self.gh = gh self._repo_name = repo self._fetch_from = fetch_from self.dry_run = dry_run - self.must_create_backport_labels = must_create_backport_labels - self.backport_created_label = backport_created_label + self.must_create_backport_label = ( + Labels.MUST_BACKPORT + if self._repo_name == self._fetch_from + else Labels.MUST_BACKPORT_CLOUD + ) + self.backport_created_label = ( + Labels.PR_BACKPORTS_CREATED + if self._repo_name == self._fetch_from + else Labels.PR_BACKPORTS_CREATED_CLOUD + ) self._remote = "" self._remote_line = "" @@ -460,7 +461,7 @@ class Backport: query_args = { "query": f"type:pr repo:{self._fetch_from} -label:{self.backport_created_label}", "label": ",".join( - self.labels_to_backport + self.must_create_backport_labels + self.labels_to_backport + [self.must_create_backport_label] ), "merged": [since_date, tomorrow], } @@ -484,16 +485,12 @@ class Backport: def process_pr(self, pr: PullRequest) -> None: pr_labels = [label.name for label in pr.labels] - for label in self.must_create_backport_labels: - # We backport any vXXX-must-backport to all branches of the fetch repo (better than no backport) - if label in pr_labels or self._fetch_from: - branches = [ - ReleaseBranch(br, pr, self.repo, self.backport_created_label) - for br in self.release_branches - ] # type: List[ReleaseBranch] - break - - if not branches: + if self.must_create_backport_label in pr_labels: + branches = [ + ReleaseBranch(br, pr, self.repo, self.backport_created_label) + for br in self.release_branches + ] # type: List[ReleaseBranch] + else: branches = [ ReleaseBranch(br, pr, self.repo, self.backport_created_label) for br in [ @@ -502,20 +499,20 @@ class Backport: if label in self.labels_to_backport ] ] - if not branches: - # This is definitely some error. There must be at least one branch - # It also make the whole program exit code non-zero - self.error = Exception( - f"There are no branches to backport PR #{pr.number}, logical error" - ) - raise self.error + if not branches: + # This is definitely some error. There must be at least one branch + # It also make the whole program exit code non-zero + self.error = Exception( + f"There are no branches to backport PR #{pr.number}, logical error" + ) + raise self.error logging.info( " PR #%s is supposed to be backported to %s", pr.number, ", ".join(map(str, branches)), ) - # All PRs for cherrypick and backport branches as heads + # All PRs for cherry-pick and backport branches as heads query_suffix = " ".join( [ f"head:{branch.backport_branch} head:{branch.cherrypick_branch}" @@ -539,16 +536,10 @@ class Backport: ) raise self.error - if all(br.backported for br in branches): - # Let's check if the PR is already backported - self.mark_pr_backported(pr) - return - for br in branches: br.process(self.dry_run) if all(br.backported for br in branches): - # And check it after the running self.mark_pr_backported(pr) def mark_pr_backported(self, pr: PullRequest) -> None: @@ -586,19 +577,6 @@ def parse_args(): ) parser.add_argument("--dry-run", action="store_true", help="do not create anything") - parser.add_argument( - "--must-create-backport-label", - default=Labels.MUST_BACKPORT, - choices=(Labels.MUST_BACKPORT, Labels.MUST_BACKPORT_CLOUD), - help="label to filter PRs to backport", - nargs="+", - ) - parser.add_argument( - "--backport-created-label", - default=Labels.PR_BACKPORTS_CREATED, - choices=(Labels.PR_BACKPORTS_CREATED, Labels.PR_BACKPORTS_CREATED_CLOUD), - help="label to mark PRs as backported", - ) parser.add_argument( "--reserve-search-days", default=0, @@ -663,12 +641,6 @@ def main(): args.repo, args.from_repo, args.dry_run, - ( - args.must_create_backport_label - if isinstance(args.must_create_backport_label, list) - else [args.must_create_backport_label] - ), - args.backport_created_label, ) # https://github.com/python/mypy/issues/3004 bp.gh.cache_path = temp_path / "gh_cache"