mercurial/branchmap.py
changeset 51455 7a063dd9d64e
parent 51454 84fca6d79e25
child 51456 19b2736c8e45
--- a/mercurial/branchmap.py	Sun Feb 25 14:09:36 2024 +0100
+++ b/mercurial/branchmap.py	Mon Feb 26 15:46:24 2024 +0100
@@ -211,73 +211,18 @@
         entries: Union[
             Dict[bytes, List[bytes]], Iterable[Tuple[bytes, List[bytes]]]
         ] = (),
-        tipnode: Optional[bytes] = None,
-        tiprev: Optional[int] = nullrev,
-        filteredhash: Optional[bytes] = None,
-        closednodes: Optional[Set[bytes]] = None,
-        hasnode: Optional[Callable[[bytes], bool]] = None,
-        verify_node: bool = False,
+        closed_nodes: Optional[Set[bytes]] = None,
     ) -> None:
         """hasnode is a function which can be used to verify whether changelog
         has a given node or not. If it's not provided, we assume that every node
         we have exists in changelog"""
-        self._filtername = repo.filtername
-        self._delayed = False
-        if tipnode is None:
-            self.tipnode = repo.nullid
-        else:
-            self.tipnode = tipnode
-        self.tiprev = tiprev
-        self.filteredhash = filteredhash
         # closednodes is a set of nodes that close their branch. If the branch
         # cache has been updated, it may contain nodes that are no longer
         # heads.
-        if closednodes is None:
-            self._closednodes = set()
-        else:
-            self._closednodes = closednodes
+        if closed_nodes is None:
+            closed_nodes = set()
+        self._closednodes = set(closed_nodes)
         self._entries = dict(entries)
-        # Do we need to verify branch at all ?
-        self._verify_node = verify_node
-        # whether closed nodes are verified or not
-        self._closedverified = False
-        # branches for which nodes are verified
-        self._verifiedbranches = set()
-        self._hasnode = None
-        if self._verify_node:
-            self._hasnode = repo.changelog.hasnode
-
-    def _verifyclosed(self):
-        """verify the closed nodes we have"""
-        if not self._verify_node:
-            return
-        if self._closedverified:
-            return
-        assert self._hasnode is not None
-        for node in self._closednodes:
-            if not self._hasnode(node):
-                _unknownnode(node)
-
-        self._closedverified = True
-
-    def _verifybranch(self, branch):
-        """verify head nodes for the given branch."""
-        if not self._verify_node:
-            return
-        if branch not in self._entries or branch in self._verifiedbranches:
-            return
-        assert self._hasnode is not None
-        for n in self._entries[branch]:
-            if not self._hasnode(n):
-                _unknownnode(n)
-
-        self._verifiedbranches.add(branch)
-
-    def _verifyall(self):
-        """verifies nodes of all the branches"""
-        needverification = set(self._entries.keys()) - self._verifiedbranches
-        for b in needverification:
-            self._verifybranch(b)
 
     def __iter__(self):
         return iter(self._entries)
@@ -286,114 +231,20 @@
         self._entries[key] = value
 
     def __getitem__(self, key):
-        self._verifybranch(key)
         return self._entries[key]
 
     def __contains__(self, key):
-        self._verifybranch(key)
         return key in self._entries
 
     def iteritems(self):
-        for k, v in self._entries.items():
-            self._verifybranch(k)
-            yield k, v
+        return self._entries.items()
 
     items = iteritems
 
     def hasbranch(self, label):
         """checks whether a branch of this name exists or not"""
-        self._verifybranch(label)
         return label in self._entries
 
-    @classmethod
-    def fromfile(cls, repo):
-        f = None
-        try:
-            f = repo.cachevfs(cls._filename(repo))
-            lineiter = iter(f)
-            cachekey = next(lineiter).rstrip(b'\n').split(b" ", 2)
-            last, lrev = cachekey[:2]
-            last, lrev = bin(last), int(lrev)
-            filteredhash = None
-            if len(cachekey) > 2:
-                filteredhash = bin(cachekey[2])
-            bcache = cls(
-                repo,
-                tipnode=last,
-                tiprev=lrev,
-                filteredhash=filteredhash,
-                verify_node=True,
-            )
-            if not bcache.validfor(repo):
-                # invalidate the cache
-                raise ValueError('tip differs')
-            bcache.load(repo, lineiter)
-        except (IOError, OSError):
-            return None
-
-        except Exception as inst:
-            if repo.ui.debugflag:
-                msg = b'invalid %s: %s\n'
-                repo.ui.debug(
-                    msg
-                    % (
-                        _branchcachedesc(repo),
-                        stringutil.forcebytestr(inst),
-                    )
-                )
-            bcache = None
-
-        finally:
-            if f:
-                f.close()
-
-        return bcache
-
-    def load(self, repo, lineiter):
-        """fully loads the branchcache by reading from the file using the line
-        iterator passed"""
-        for line in lineiter:
-            line = line.rstrip(b'\n')
-            if not line:
-                continue
-            node, state, label = line.split(b" ", 2)
-            if state not in b'oc':
-                raise ValueError('invalid branch state')
-            label = encoding.tolocal(label.strip())
-            node = bin(node)
-            self._entries.setdefault(label, []).append(node)
-            if state == b'c':
-                self._closednodes.add(node)
-
-    @staticmethod
-    def _filename(repo):
-        """name of a branchcache file for a given repo or repoview"""
-        filename = b"branch2"
-        if repo.filtername:
-            filename = b'%s-%s' % (filename, repo.filtername)
-        return filename
-
-    def validfor(self, repo):
-        """check that cache contents are valid for (a subset of) this repo
-
-        - False when the order of changesets changed or if we detect a strip.
-        - True when cache is up-to-date for the current repo or its subset."""
-        try:
-            node = repo.changelog.node(self.tiprev)
-        except IndexError:
-            # changesets were stripped and now we don't even have enough to
-            # find tiprev
-            return False
-        if self.tipnode != node:
-            # tiprev doesn't correspond to tipnode: repo was stripped, or this
-            # repo has a different order of changesets
-            return False
-        tiphash = scmutil.filteredhash(repo, self.tiprev, needobsolete=True)
-        # hashes don't match if this repo view has a different set of filtered
-        # revisions (e.g. due to phase changes) or obsolete revisions (e.g.
-        # history was rewritten)
-        return self.filteredhash == tiphash
-
     def _branchtip(self, heads):
         """Return tuple with last open head in heads and false,
         otherwise return last closed head and true."""
@@ -416,7 +267,6 @@
         return (n for n in nodes if n not in self._closednodes)
 
     def branchheads(self, branch, closed=False):
-        self._verifybranch(branch)
         heads = self._entries[branch]
         if not closed:
             heads = list(self.iteropen(heads))
@@ -428,98 +278,27 @@
 
     def iterheads(self):
         """returns all the heads"""
-        self._verifyall()
         return self._entries.values()
 
-    def copy(self, repo):
-        """return a deep copy of the branchcache object"""
-        other = type(self)(
-            repo=repo,
-            # we always do a shally copy of self._entries, and the values is
-            # always replaced, so no need to deepcopy until the above remains
-            # true.
-            entries=self._entries,
-            tipnode=self.tipnode,
-            tiprev=self.tiprev,
-            filteredhash=self.filteredhash,
-            closednodes=set(self._closednodes),
-            verify_node=self._verify_node,
-        )
-        # we copy will likely schedule a write anyway, but that does not seems
-        # to hurt to overschedule
-        other._delayed = self._delayed
-        # also copy information about the current verification state
-        other._closedverified = self._closedverified
-        other._verifiedbranches = set(self._verifiedbranches)
-        return other
-
-    def write(self, repo):
-        assert self._filtername == repo.filtername, (
-            self._filtername,
-            repo.filtername,
-        )
-        tr = repo.currenttransaction()
-        if not getattr(tr, 'finalized', True):
-            # Avoid premature writing.
-            #
-            # (The cache warming setup by localrepo will update the file later.)
-            self._delayed = True
-            return
-        try:
-            filename = self._filename(repo)
-            with repo.cachevfs(filename, b"w", atomictemp=True) as f:
-                cachekey = [hex(self.tipnode), b'%d' % self.tiprev]
-                if self.filteredhash is not None:
-                    cachekey.append(hex(self.filteredhash))
-                f.write(b" ".join(cachekey) + b'\n')
-                nodecount = 0
-                for label, nodes in sorted(self._entries.items()):
-                    label = encoding.fromlocal(label)
-                    for node in nodes:
-                        nodecount += 1
-                        if node in self._closednodes:
-                            state = b'c'
-                        else:
-                            state = b'o'
-                        f.write(b"%s %s %s\n" % (hex(node), state, label))
-            repo.ui.log(
-                b'branchcache',
-                b'wrote %s with %d labels and %d nodes\n',
-                _branchcachedesc(repo),
-                len(self._entries),
-                nodecount,
-            )
-            self._delayed = False
-        except (IOError, OSError, error.Abort) as inst:
-            # Abort may be raised by read only opener, so log and continue
-            repo.ui.debug(
-                b"couldn't write branch cache: %s\n"
-                % stringutil.forcebytestr(inst)
-            )
-
     def update(self, repo, revgen):
         """Given a branchhead cache, self, that may have extra nodes or be
         missing heads, and a generator of nodes that are strictly a superset of
         heads missing, this function updates self to be correct.
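+
+        Returns the largest revision number processed, or None if `revgen`
+        yielded no revision.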
         """
-        assert self._filtername == repo.filtername, (
-            self._filtername,
-            repo.filtername,
-        )
         starttime = util.timer()
         cl = repo.changelog
         # collect new branch entries
         newbranches = {}
         getbranchinfo = repo.revbranchcache().branchinfo
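+        # track the highest revision seen; it is returned to the caller
+        # (as None when revgen yields no revision)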
+        max_rev = -1
         for r in revgen:
             branch, closesbranch = getbranchinfo(r)
             newbranches.setdefault(branch, []).append(r)
             if closesbranch:
                 self._closednodes.add(cl.node(r))
-
-        # new tip revision which we found after iterating items from new
-        # branches
-        ntiprev = self.tiprev
+            max_rev = max(max_rev, r)
+        if max_rev < 0:
+            max_rev = None
 
         # Delay fetching the topological heads until they are needed.
         # A repository without non-contiguous branches can skip this part.
@@ -613,13 +392,287 @@
                         bheadset -= ancestors
             if bheadset:
                 self[branch] = [cl.node(rev) for rev in sorted(bheadset)]
-            tiprev = max(newheadrevs)
-            if tiprev > ntiprev:
-                ntiprev = tiprev
+
+        duration = util.timer() - starttime
+        repo.ui.log(
+            b'branchcache',
+            b'updated %s in %.4f seconds\n',
+            _branchcachedesc(repo),
+            duration,
+        )
+        return max_rev
+
+
+class branchcache(_BaseBranchCache):
+    """Branchmap info for a local repo or repoview"""
+
+    def __init__(
+        self,
+        repo: "localrepo.localrepository",
+        entries: Union[
+            Dict[bytes, List[bytes]], Iterable[Tuple[bytes, List[bytes]]]
+        ] = (),
+        tipnode: Optional[bytes] = None,
+        tiprev: Optional[int] = nullrev,
+        filteredhash: Optional[bytes] = None,
+        closednodes: Optional[Set[bytes]] = None,
+        hasnode: Optional[Callable[[bytes], bool]] = None,
+        verify_node: bool = False,
+    ) -> None:
+        """hasnode is a function which can be used to verify whether changelog
+        has a given node or not. If it's not provided, we assume that every node
+        we have exists in changelog"""
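+        # Note: the `hasnode` argument is currently unused here; when
+        # verify_node is True, repo.changelog.hasnode is used instead.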
+        self._filtername = repo.filtername
+        self._delayed = False
+        if tipnode is None:
+            self.tipnode = repo.nullid
+        else:
+            self.tipnode = tipnode
+        self.tiprev = tiprev
+        self.filteredhash = filteredhash
+
+        super().__init__(repo=repo, entries=entries, closed_nodes=closednodes)
+
+        # Do we need to verify branches at all?
+        self._verify_node = verify_node
+        # whether closed nodes are verified or not
+        self._closedverified = False
+        # branches for which nodes are verified
+        self._verifiedbranches = set()
+        self._hasnode = None
+        if self._verify_node:
+            self._hasnode = repo.changelog.hasnode
+
+    def validfor(self, repo):
+        """check that cache contents are valid for (a subset of) this repo
+
+        - False when the order of changesets changed or if we detect a strip.
+        - True when cache is up-to-date for the current repo or its subset."""
+        try:
+            node = repo.changelog.node(self.tiprev)
+        except IndexError:
+            # changesets were stripped and now we don't even have enough to
+            # find tiprev
+            return False
+        if self.tipnode != node:
+            # tiprev doesn't correspond to tipnode: repo was stripped, or this
+            # repo has a different order of changesets
+            return False
+        tiphash = scmutil.filteredhash(repo, self.tiprev, needobsolete=True)
+        # hashes don't match if this repo view has a different set of filtered
+        # revisions (e.g. due to phase changes) or obsolete revisions (e.g.
+        # history was rewritten)
+        return self.filteredhash == tiphash
+
+    @classmethod
+    def fromfile(cls, repo):
+        f = None
+        try:
+            f = repo.cachevfs(cls._filename(repo))
+            lineiter = iter(f)
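+            # The first line is the cache key:
+            #   "<tip node hex> <tip rev>[ <filtered hash hex>]"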
+            cachekey = next(lineiter).rstrip(b'\n').split(b" ", 2)
+            last, lrev = cachekey[:2]
+            last, lrev = bin(last), int(lrev)
+            filteredhash = None
+            if len(cachekey) > 2:
+                filteredhash = bin(cachekey[2])
+            bcache = cls(
+                repo,
+                tipnode=last,
+                tiprev=lrev,
+                filteredhash=filteredhash,
+                verify_node=True,
+            )
+            if not bcache.validfor(repo):
+                # invalidate the cache
+                raise ValueError('tip differs')
+            bcache.load(repo, lineiter)
+        except (IOError, OSError):
+            return None
+
+        except Exception as inst:
+            if repo.ui.debugflag:
+                msg = b'invalid %s: %s\n'
+                repo.ui.debug(
+                    msg
+                    % (
+                        _branchcachedesc(repo),
+                        stringutil.forcebytestr(inst),
+                    )
+                )
+            bcache = None
+
+        finally:
+            if f:
+                f.close()
+
+        return bcache
+
+    def load(self, repo, lineiter):
+        """fully loads the branchcache by reading from the file using the line
+        iterator passed"""
+        for line in lineiter:
+            line = line.rstrip(b'\n')
+            if not line:
+                continue
+            node, state, label = line.split(b" ", 2)
+            if state not in b'oc':
+                raise ValueError('invalid branch state')
+            label = encoding.tolocal(label.strip())
+            node = bin(node)
+            self._entries.setdefault(label, []).append(node)
+            if state == b'c':
+                self._closednodes.add(node)
 
-        if ntiprev > self.tiprev:
-            self.tiprev = ntiprev
-            self.tipnode = cl.node(ntiprev)
+    @staticmethod
+    def _filename(repo):
+        """name of a branchcache file for a given repo or repoview"""
+        filename = b"branch2"
+        if repo.filtername:
+            filename = b'%s-%s' % (filename, repo.filtername)
+        return filename
+
+    def copy(self, repo):
+        """return a deep copy of the branchcache object"""
+        other = type(self)(
+            repo=repo,
+            # we always do a shallow copy of self._entries, and the values
+            # are always replaced, so there is no need to deepcopy as long as
+            # the above remains true.
+            entries=self._entries,
+            tipnode=self.tipnode,
+            tiprev=self.tiprev,
+            filteredhash=self.filteredhash,
+            closednodes=set(self._closednodes),
+            verify_node=self._verify_node,
+        )
+        # the copy will likely schedule a write anyway, but it does not seem
+        # to hurt to over-schedule one
+        other._delayed = self._delayed
+        # also copy information about the current verification state
+        other._closedverified = self._closedverified
+        other._verifiedbranches = set(self._verifiedbranches)
+        return other
+
+    def write(self, repo):
+        assert self._filtername == repo.filtername, (
+            self._filtername,
+            repo.filtername,
+        )
+        tr = repo.currenttransaction()
+        if not getattr(tr, 'finalized', True):
+            # Avoid premature writing.
+            #
+            # (The cache warming setup by localrepo will update the file later.)
+            self._delayed = True
+            return
+        try:
+            filename = self._filename(repo)
+            with repo.cachevfs(filename, b"w", atomictemp=True) as f:
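+                # atomictemp writes to a temporary file and renames it into
+                # place on close, so readers never see a partial cache file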
+                cachekey = [hex(self.tipnode), b'%d' % self.tiprev]
+                if self.filteredhash is not None:
+                    cachekey.append(hex(self.filteredhash))
+                f.write(b" ".join(cachekey) + b'\n')
+                nodecount = 0
+                for label, nodes in sorted(self._entries.items()):
+                    label = encoding.fromlocal(label)
+                    for node in nodes:
+                        nodecount += 1
+                        if node in self._closednodes:
+                            state = b'c'
+                        else:
+                            state = b'o'
+                        f.write(b"%s %s %s\n" % (hex(node), state, label))
+            repo.ui.log(
+                b'branchcache',
+                b'wrote %s with %d labels and %d nodes\n',
+                _branchcachedesc(repo),
+                len(self._entries),
+                nodecount,
+            )
+            self._delayed = False
+        except (IOError, OSError, error.Abort) as inst:
+            # Abort may be raised by a read-only opener, so log and continue
+            repo.ui.debug(
+                b"couldn't write branch cache: %s\n"
+                % stringutil.forcebytestr(inst)
+            )
+
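+    # Verification is lazy: nodes are only checked against the changelog the
+    # first time a branch (or the closed-node set) is actually accessed, and
+    # unknown nodes are reported through _unknownnode().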
+    def _verifyclosed(self):
+        """verify the closed nodes we have"""
+        if not self._verify_node:
+            return
+        if self._closedverified:
+            return
+        assert self._hasnode is not None
+        for node in self._closednodes:
+            if not self._hasnode(node):
+                _unknownnode(node)
+
+        self._closedverified = True
+
+    def _verifybranch(self, branch):
+        """verify head nodes for the given branch."""
+        if not self._verify_node:
+            return
+        if branch not in self._entries or branch in self._verifiedbranches:
+            return
+        assert self._hasnode is not None
+        for n in self._entries[branch]:
+            if not self._hasnode(n):
+                _unknownnode(n)
+
+        self._verifiedbranches.add(branch)
+
+    def _verifyall(self):
+        """verifies nodes of all the branches"""
+        for b in self._entries.keys():
+            if b not in self._verifiedbranches:
+                self._verifybranch(b)
+
+    def __getitem__(self, key):
+        self._verifybranch(key)
+        return super().__getitem__(key)
+
+    def __contains__(self, key):
+        self._verifybranch(key)
+        return super().__contains__(key)
+
+    def iteritems(self):
+        self._verifyall()
+        return super().iteritems()
+
+    items = iteritems
+
+    def iterheads(self):
+        """returns all the heads"""
+        self._verifyall()
+        return super().iterheads()
+
+    def hasbranch(self, label):
+        """checks whether a branch of this name exists or not"""
+        self._verifybranch(label)
+        return super().hasbranch(label)
+
+    def branchheads(self, branch, closed=False):
+        self._verifybranch(branch)
+        return super().branchheads(branch, closed=closed)
+
+    def update(self, repo, revgen):
+        assert self._filtername == repo.filtername, (
+            self._filtername,
+            repo.filtername,
+        )
+        cl = repo.changelog
+        max_rev = super().update(repo, revgen)
+        # if we found a new tip revision while iterating the new branches,
+        # move our tip to it
+        if max_rev is not None and max_rev > self.tiprev:
+            self.tiprev = max_rev
+            self.tipnode = cl.node(max_rev)
 
         if not self.validfor(repo):
             # old cache key is now invalid for the repo, but we've just updated
@@ -641,24 +694,22 @@
             repo, self.tiprev, needobsolete=True
         )
 
-        duration = util.timer() - starttime
-        repo.ui.log(
-            b'branchcache',
-            b'updated %s in %.4f seconds\n',
-            _branchcachedesc(repo),
-            duration,
-        )
-
         self.write(repo)
 
 
-class branchcache(_BaseBranchCache):
-    """Branchmap info for a local repo or repoview"""
-
-
 class remotebranchcache(_BaseBranchCache):
     """Branchmap info for a remote connection, should not write locally"""
 
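+    # Accepts the historical `closednodes` argument name and forwards it to
+    # the base class as `closed_nodes`.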
+    def __init__(
+        self,
+        repo: "localrepo.localrepository",
+        entries: Union[
+            Dict[bytes, List[bytes]], Iterable[Tuple[bytes, List[bytes]]]
+        ] = (),
+        closednodes: Optional[Set[bytes]] = None,
+    ) -> None:
+        super().__init__(repo=repo, entries=entries, closed_nodes=closednodes)
+
 
 # Revision branch info cache