mercurial/phases.py
changeset 51406 f8bf1a8e9181
parent 51405 12881244e48a
child 51407 71ae6fee2b9d
--- a/mercurial/phases.py	Tue Feb 20 21:40:08 2024 +0100
+++ b/mercurial/phases.py	Tue Feb 20 21:40:13 2024 +0100
@@ -133,7 +133,7 @@
     util,
 )
 
-Phaseroots = Dict[int, Set[bytes]]
+Phaseroots = Dict[int, Set[int]]
 
 if typing.TYPE_CHECKING:
     from . import (
@@ -210,7 +210,7 @@
     repo = repo.unfiltered()
     dirty = False
     roots = {i: set() for i in allphases}
-    has_node = repo.changelog.index.has_node
+    to_rev = repo.changelog.index.get_rev
     unknown_msg = b'removing unknown node %s from %i-phase boundary\n'
     try:
         f, pending = txnutil.trypending(repo.root, repo.svfs, b'phaseroots')
@@ -219,11 +219,12 @@
                 str_phase, hex_node = line.split()
                 phase = int(str_phase)
                 node = bin(hex_node)
-                if not has_node(node):
+                rev = to_rev(node)
+                if rev is None:
                     repo.ui.debug(unknown_msg % (short(hex_node), phase))
                     dirty = True
                 else:
-                    roots[phase].add(node)
+                    roots[phase].add(rev)
         finally:
             f.close()
     except FileNotFoundError:
@@ -391,7 +392,7 @@
 
     def nonpublicphaseroots(
         self, repo: "localrepo.localrepository"
-    ) -> Set[bytes]:
+    ) -> Set[int]:
         """returns the roots of all non-public phases
 
         The roots are not minimized, so if the secret revisions are
@@ -499,7 +500,7 @@
         self._phasesets = {phase: set() for phase in allphases}
         lowerroots = set()
         for phase in reversed(trackedphases):
-            roots = pycompat.maplist(cl.rev, self._phaseroots[phase])
+            roots = self._phaseroots[phase]
             if roots:
                 ps = set(cl.descendants(roots))
                 for root in roots:
@@ -551,8 +552,10 @@
 
     def _write(self, repo, fp):
         assert repo.filtername is None
+        to_node = repo.changelog.node
         for phase, roots in self._phaseroots.items():
-            for h in sorted(roots):
+            for r in sorted(roots):
+                h = to_node(r)
                 fp.write(b'%i %s\n' % (phase, hex(h)))
         self.dirty = False
 
@@ -584,7 +587,7 @@
         repo.invalidatevolatilesets()
 
     def advanceboundary(
-        self, repo, tr, targetphase, nodes, revs=None, dryrun=None
+        self, repo, tr, targetphase, nodes=None, revs=None, dryrun=None
     ):
         """Set all 'nodes' to phase 'targetphase'
 
@@ -598,6 +601,8 @@
         # phaseroots values, replace them.
         if revs is None:
             revs = []
+        if not revs and not nodes:
+            return set()
         if tr is None:
             phasetracking = None
         else:
@@ -616,7 +621,7 @@
 
             olds = self._phaseroots[phase]
 
-            affected = repo.revs(b'%ln::%ld', olds, revs)
+            affected = repo.revs(b'%ld::%ld', olds, revs)
             changes.update(affected)
             if dryrun:
                 continue
@@ -625,10 +630,7 @@
                     phasetracking, r, self.phase(repo, r), targetphase
                 )
 
-            roots = {
-                ctx.node()
-                for ctx in repo.set(b'roots((%ln::) - %ld)', olds, affected)
-            }
+            roots = set(repo.revs(b'roots((%ld::) - %ld)', olds, affected))
             if olds != roots:
                 self._updateroots(repo, phase, roots, tr)
                 # some roots may need to be declared for lower phases
@@ -636,7 +638,7 @@
         if not dryrun:
             # declare deleted root in the target phase
             if targetphase != 0:
-                self._retractboundary(repo, tr, targetphase, delroots)
+                self._retractboundary(repo, tr, targetphase, revs=delroots)
             repo.invalidatevolatilesets()
         return changes
 
@@ -651,21 +653,19 @@
         else:
             phasetracking = tr.changes.get(b'phases')
         repo = repo.unfiltered()
-        if (
-            self._retractboundary(repo, tr, targetphase, nodes)
-            and phasetracking is not None
-        ):
+        retracted = self._retractboundary(repo, tr, targetphase, nodes)
+        if retracted and phasetracking is not None:
 
             # find the affected revisions
             new = self._phaseroots[targetphase]
             old = oldroots[targetphase]
-            affected = set(repo.revs(b'(%ln::) - (%ln::)', new, old))
+            affected = set(repo.revs(b'(%ld::) - (%ld::)', new, old))
 
             # find the phase of the affected revision
             for phase in range(targetphase, -1, -1):
                 if phase:
                     roots = oldroots.get(phase, [])
-                    revs = set(repo.revs(b'%ln::%ld', roots, affected))
+                    revs = set(repo.revs(b'%ld::%ld', roots, affected))
                     affected -= revs
                 else:  # public phase
                     revs = affected
@@ -673,11 +673,15 @@
                     _trackphasechange(phasetracking, r, phase, targetphase)
         repo.invalidatevolatilesets()
 
-    def _retractboundary(self, repo, tr, targetphase, nodes, revs=None):
+    def _retractboundary(self, repo, tr, targetphase, nodes=None, revs=None):
         # Be careful to preserve shallow-copied values: do not update
         # phaseroots values, replace them.
         if revs is None:
             revs = []
+        if nodes is None:
+            nodes = []
+        if not revs and not nodes:
+            return False
         if (
             targetphase == internal
             and not supportinternal(repo)
@@ -688,10 +692,8 @@
             msg = b'this repository does not support the %s phase' % name
             raise error.ProgrammingError(msg)
 
-        repo = repo.unfiltered()
-        torev = repo.changelog.rev
-        tonode = repo.changelog.node
-        currentroots = {torev(node) for node in self._phaseroots[targetphase]}
+        torev = repo.changelog.index.rev
+        currentroots = self._phaseroots[targetphase]
         finalroots = oldroots = set(currentroots)
         newroots = [torev(node) for node in nodes] + [r for r in revs]
         newroots = [
@@ -701,6 +703,8 @@
         if newroots:
             if nullrev in newroots:
                 raise error.Abort(_(b'cannot change null revision phase'))
+            # do not break the CoW assumption of the shallow copy
+            currentroots = currentroots.copy()
             currentroots.update(newroots)
 
             # Only compute new roots for revs above the roots that are being
@@ -712,18 +716,13 @@
             finalroots = {rev for rev in currentroots if rev < minnewroot}
             finalroots.update(updatedroots)
         if finalroots != oldroots:
-            self._updateroots(
-                repo,
-                targetphase,
-                {tonode(rev) for rev in finalroots},
-                tr,
-            )
+            self._updateroots(repo, targetphase, finalroots, tr)
             return True
         return False
 
     def register_strip(
         self,
-        repo: "localrepo.localrepository",
+        repo,
         tr,
         strip_rev: int,
     ):
@@ -731,12 +730,10 @@
 
         Any roots higher than the stripped revision should be dropped.
         """
-        assert repo.filtername is None
-        to_rev = repo.changelog.index.rev
-        for targetphase, nodes in list(self._phaseroots.items()):
-            filtered = {n for n in nodes if to_rev(n) >= strip_rev}
+        for targetphase, roots in list(self._phaseroots.items()):
+            filtered = {r for r in roots if r >= strip_rev}
             if filtered:
-                self._updateroots(repo, targetphase, nodes - filtered, tr)
+                self._updateroots(repo, targetphase, roots - filtered, tr)
         self.invalidate()
 
 
@@ -793,9 +790,10 @@
     keys = util.sortdict()
     value = b'%i' % draft
     cl = repo.unfiltered().changelog
+    to_node = cl.node
     for root in repo._phasecache._phaseroots[draft]:
-        if repo._phasecache.phase(repo, cl.rev(root)) <= draft:
-            keys[hex(root)] = value
+        if repo._phasecache.phase(repo, root) <= draft:
+            keys[hex(to_node(root))] = value
 
     if repo.publishing():
         # Add an extra data to let remote know we are a publishing