branchcache: dispatch the code into the dedicated subclass
authorPierre-Yves David <pierre-yves.david@octobus.net>
Mon, 26 Feb 2024 15:46:24 +0100
changeset 51455 7a063dd9d64e
parent 51454 84fca6d79e25
child 51456 19b2736c8e45
branchcache: dispatch the code into the dedicated subclass The code useful only to the local branchcache has now been moved into the dedicated subclass. This will help improve the branchcache code without subtly breaking the remote variants.
mercurial/branchmap.py
--- a/mercurial/branchmap.py	Sun Feb 25 14:09:36 2024 +0100
+++ b/mercurial/branchmap.py	Mon Feb 26 15:46:24 2024 +0100
@@ -211,73 +211,18 @@
         entries: Union[
             Dict[bytes, List[bytes]], Iterable[Tuple[bytes, List[bytes]]]
         ] = (),
-        tipnode: Optional[bytes] = None,
-        tiprev: Optional[int] = nullrev,
-        filteredhash: Optional[bytes] = None,
-        closednodes: Optional[Set[bytes]] = None,
-        hasnode: Optional[Callable[[bytes], bool]] = None,
-        verify_node: bool = False,
+        closed_nodes: Optional[Set[bytes]] = None,
     ) -> None:
         """hasnode is a function which can be used to verify whether changelog
         has a given node or not. If it's not provided, we assume that every node
         we have exists in changelog"""
-        self._filtername = repo.filtername
-        self._delayed = False
-        if tipnode is None:
-            self.tipnode = repo.nullid
-        else:
-            self.tipnode = tipnode
-        self.tiprev = tiprev
-        self.filteredhash = filteredhash
         # closednodes is a set of nodes that close their branch. If the branch
         # cache has been updated, it may contain nodes that are no longer
         # heads.
-        if closednodes is None:
-            self._closednodes = set()
-        else:
-            self._closednodes = closednodes
+        if closed_nodes is None:
+            closed_nodes = set()
+        self._closednodes = set(closed_nodes)
         self._entries = dict(entries)
-        # Do we need to verify branch at all ?
-        self._verify_node = verify_node
-        # whether closed nodes are verified or not
-        self._closedverified = False
-        # branches for which nodes are verified
-        self._verifiedbranches = set()
-        self._hasnode = None
-        if self._verify_node:
-            self._hasnode = repo.changelog.hasnode
-
-    def _verifyclosed(self):
-        """verify the closed nodes we have"""
-        if not self._verify_node:
-            return
-        if self._closedverified:
-            return
-        assert self._hasnode is not None
-        for node in self._closednodes:
-            if not self._hasnode(node):
-                _unknownnode(node)
-
-        self._closedverified = True
-
-    def _verifybranch(self, branch):
-        """verify head nodes for the given branch."""
-        if not self._verify_node:
-            return
-        if branch not in self._entries or branch in self._verifiedbranches:
-            return
-        assert self._hasnode is not None
-        for n in self._entries[branch]:
-            if not self._hasnode(n):
-                _unknownnode(n)
-
-        self._verifiedbranches.add(branch)
-
-    def _verifyall(self):
-        """verifies nodes of all the branches"""
-        needverification = set(self._entries.keys()) - self._verifiedbranches
-        for b in needverification:
-            self._verifybranch(b)
 
     def __iter__(self):
         return iter(self._entries)
@@ -286,114 +231,20 @@
         self._entries[key] = value
 
     def __getitem__(self, key):
-        self._verifybranch(key)
         return self._entries[key]
 
     def __contains__(self, key):
-        self._verifybranch(key)
         return key in self._entries
 
     def iteritems(self):
-        for k, v in self._entries.items():
-            self._verifybranch(k)
-            yield k, v
+        return self._entries.items()
 
     items = iteritems
 
     def hasbranch(self, label):
         """checks whether a branch of this name exists or not"""
-        self._verifybranch(label)
         return label in self._entries
 
-    @classmethod
-    def fromfile(cls, repo):
-        f = None
-        try:
-            f = repo.cachevfs(cls._filename(repo))
-            lineiter = iter(f)
-            cachekey = next(lineiter).rstrip(b'\n').split(b" ", 2)
-            last, lrev = cachekey[:2]
-            last, lrev = bin(last), int(lrev)
-            filteredhash = None
-            if len(cachekey) > 2:
-                filteredhash = bin(cachekey[2])
-            bcache = cls(
-                repo,
-                tipnode=last,
-                tiprev=lrev,
-                filteredhash=filteredhash,
-                verify_node=True,
-            )
-            if not bcache.validfor(repo):
-                # invalidate the cache
-                raise ValueError('tip differs')
-            bcache.load(repo, lineiter)
-        except (IOError, OSError):
-            return None
-
-        except Exception as inst:
-            if repo.ui.debugflag:
-                msg = b'invalid %s: %s\n'
-                repo.ui.debug(
-                    msg
-                    % (
-                        _branchcachedesc(repo),
-                        stringutil.forcebytestr(inst),
-                    )
-                )
-            bcache = None
-
-        finally:
-            if f:
-                f.close()
-
-        return bcache
-
-    def load(self, repo, lineiter):
-        """fully loads the branchcache by reading from the file using the line
-        iterator passed"""
-        for line in lineiter:
-            line = line.rstrip(b'\n')
-            if not line:
-                continue
-            node, state, label = line.split(b" ", 2)
-            if state not in b'oc':
-                raise ValueError('invalid branch state')
-            label = encoding.tolocal(label.strip())
-            node = bin(node)
-            self._entries.setdefault(label, []).append(node)
-            if state == b'c':
-                self._closednodes.add(node)
-
-    @staticmethod
-    def _filename(repo):
-        """name of a branchcache file for a given repo or repoview"""
-        filename = b"branch2"
-        if repo.filtername:
-            filename = b'%s-%s' % (filename, repo.filtername)
-        return filename
-
-    def validfor(self, repo):
-        """check that cache contents are valid for (a subset of) this repo
-
-        - False when the order of changesets changed or if we detect a strip.
-        - True when cache is up-to-date for the current repo or its subset."""
-        try:
-            node = repo.changelog.node(self.tiprev)
-        except IndexError:
-            # changesets were stripped and now we don't even have enough to
-            # find tiprev
-            return False
-        if self.tipnode != node:
-            # tiprev doesn't correspond to tipnode: repo was stripped, or this
-            # repo has a different order of changesets
-            return False
-        tiphash = scmutil.filteredhash(repo, self.tiprev, needobsolete=True)
-        # hashes don't match if this repo view has a different set of filtered
-        # revisions (e.g. due to phase changes) or obsolete revisions (e.g.
-        # history was rewritten)
-        return self.filteredhash == tiphash
-
     def _branchtip(self, heads):
         """Return tuple with last open head in heads and false,
         otherwise return last closed head and true."""
@@ -416,7 +267,6 @@
         return (n for n in nodes if n not in self._closednodes)
 
     def branchheads(self, branch, closed=False):
-        self._verifybranch(branch)
         heads = self._entries[branch]
         if not closed:
             heads = list(self.iteropen(heads))
@@ -428,98 +278,27 @@
 
     def iterheads(self):
         """returns all the heads"""
-        self._verifyall()
         return self._entries.values()
 
-    def copy(self, repo):
-        """return a deep copy of the branchcache object"""
-        other = type(self)(
-            repo=repo,
-            # we always do a shally copy of self._entries, and the values is
-            # always replaced, so no need to deepcopy until the above remains
-            # true.
-            entries=self._entries,
-            tipnode=self.tipnode,
-            tiprev=self.tiprev,
-            filteredhash=self.filteredhash,
-            closednodes=set(self._closednodes),
-            verify_node=self._verify_node,
-        )
-        # we copy will likely schedule a write anyway, but that does not seems
-        # to hurt to overschedule
-        other._delayed = self._delayed
-        # also copy information about the current verification state
-        other._closedverified = self._closedverified
-        other._verifiedbranches = set(self._verifiedbranches)
-        return other
-
-    def write(self, repo):
-        assert self._filtername == repo.filtername, (
-            self._filtername,
-            repo.filtername,
-        )
-        tr = repo.currenttransaction()
-        if not getattr(tr, 'finalized', True):
-            # Avoid premature writing.
-            #
-            # (The cache warming setup by localrepo will update the file later.)
-            self._delayed = True
-            return
-        try:
-            filename = self._filename(repo)
-            with repo.cachevfs(filename, b"w", atomictemp=True) as f:
-                cachekey = [hex(self.tipnode), b'%d' % self.tiprev]
-                if self.filteredhash is not None:
-                    cachekey.append(hex(self.filteredhash))
-                f.write(b" ".join(cachekey) + b'\n')
-                nodecount = 0
-                for label, nodes in sorted(self._entries.items()):
-                    label = encoding.fromlocal(label)
-                    for node in nodes:
-                        nodecount += 1
-                        if node in self._closednodes:
-                            state = b'c'
-                        else:
-                            state = b'o'
-                        f.write(b"%s %s %s\n" % (hex(node), state, label))
-            repo.ui.log(
-                b'branchcache',
-                b'wrote %s with %d labels and %d nodes\n',
-                _branchcachedesc(repo),
-                len(self._entries),
-                nodecount,
-            )
-            self._delayed = False
-        except (IOError, OSError, error.Abort) as inst:
-            # Abort may be raised by read only opener, so log and continue
-            repo.ui.debug(
-                b"couldn't write branch cache: %s\n"
-                % stringutil.forcebytestr(inst)
-            )
-
     def update(self, repo, revgen):
         """Given a branchhead cache, self, that may have extra nodes or be
         missing heads, and a generator of nodes that are strictly a superset of
         heads missing, this function updates self to be correct.
         """
-        assert self._filtername == repo.filtername, (
-            self._filtername,
-            repo.filtername,
-        )
         starttime = util.timer()
         cl = repo.changelog
         # collect new branch entries
         newbranches = {}
         getbranchinfo = repo.revbranchcache().branchinfo
+        max_rev = -1
         for r in revgen:
             branch, closesbranch = getbranchinfo(r)
             newbranches.setdefault(branch, []).append(r)
             if closesbranch:
                 self._closednodes.add(cl.node(r))
-
-        # new tip revision which we found after iterating items from new
-        # branches
-        ntiprev = self.tiprev
+            max_rev = max(max_rev, r)
+        if max_rev < 0:
+            max_rev = None
 
         # Delay fetching the topological heads until they are needed.
         # A repository without non-continous branches can skip this part.
@@ -613,13 +392,287 @@
                         bheadset -= ancestors
             if bheadset:
                 self[branch] = [cl.node(rev) for rev in sorted(bheadset)]
-            tiprev = max(newheadrevs)
-            if tiprev > ntiprev:
-                ntiprev = tiprev
+
+        duration = util.timer() - starttime
+        repo.ui.log(
+            b'branchcache',
+            b'updated %s in %.4f seconds\n',
+            _branchcachedesc(repo),
+            duration,
+        )
+        return max_rev
+
+
+class branchcache(_BaseBranchCache):
+    """Branchmap info for a local repo or repoview"""
+
+    def __init__(
+        self,
+        repo: "localrepo.localrepository",
+        entries: Union[
+            Dict[bytes, List[bytes]], Iterable[Tuple[bytes, List[bytes]]]
+        ] = (),
+        tipnode: Optional[bytes] = None,
+        tiprev: Optional[int] = nullrev,
+        filteredhash: Optional[bytes] = None,
+        closednodes: Optional[Set[bytes]] = None,
+        hasnode: Optional[Callable[[bytes], bool]] = None,
+        verify_node: bool = False,
+    ) -> None:
+        """hasnode is a function which can be used to verify whether changelog
+        has a given node or not. If it's not provided, we assume that every node
+        we have exists in changelog"""
+        self._filtername = repo.filtername
+        self._delayed = False
+        if tipnode is None:
+            self.tipnode = repo.nullid
+        else:
+            self.tipnode = tipnode
+        self.tiprev = tiprev
+        self.filteredhash = filteredhash
+
+        super().__init__(repo=repo, entries=entries, closed_nodes=closednodes)
+        # closednodes is a set of nodes that close their branch. If the branch
+        # cache has been updated, it may contain nodes that are no longer
+        # heads.
+
+        # Do we need to verify branch at all ?
+        self._verify_node = verify_node
+        # whether closed nodes are verified or not
+        self._closedverified = False
+        # branches for which nodes are verified
+        self._verifiedbranches = set()
+        self._hasnode = None
+        if self._verify_node:
+            self._hasnode = repo.changelog.hasnode
+
+    def validfor(self, repo):
+        """check that cache contents are valid for (a subset of) this repo
+
+        - False when the order of changesets changed or if we detect a strip.
+        - True when cache is up-to-date for the current repo or its subset."""
+        try:
+            node = repo.changelog.node(self.tiprev)
+        except IndexError:
+            # changesets were stripped and now we don't even have enough to
+            # find tiprev
+            return False
+        if self.tipnode != node:
+            # tiprev doesn't correspond to tipnode: repo was stripped, or this
+            # repo has a different order of changesets
+            return False
+        tiphash = scmutil.filteredhash(repo, self.tiprev, needobsolete=True)
+        # hashes don't match if this repo view has a different set of filtered
+        # revisions (e.g. due to phase changes) or obsolete revisions (e.g.
+        # history was rewritten)
+        return self.filteredhash == tiphash
+
+    @classmethod
+    def fromfile(cls, repo):
+        f = None
+        try:
+            f = repo.cachevfs(cls._filename(repo))
+            lineiter = iter(f)
+            cachekey = next(lineiter).rstrip(b'\n').split(b" ", 2)
+            last, lrev = cachekey[:2]
+            last, lrev = bin(last), int(lrev)
+            filteredhash = None
+            if len(cachekey) > 2:
+                filteredhash = bin(cachekey[2])
+            bcache = cls(
+                repo,
+                tipnode=last,
+                tiprev=lrev,
+                filteredhash=filteredhash,
+                verify_node=True,
+            )
+            if not bcache.validfor(repo):
+                # invalidate the cache
+                raise ValueError('tip differs')
+            bcache.load(repo, lineiter)
+        except (IOError, OSError):
+            return None
+
+        except Exception as inst:
+            if repo.ui.debugflag:
+                msg = b'invalid %s: %s\n'
+                repo.ui.debug(
+                    msg
+                    % (
+                        _branchcachedesc(repo),
+                        stringutil.forcebytestr(inst),
+                    )
+                )
+            bcache = None
+
+        finally:
+            if f:
+                f.close()
+
+        return bcache
+
+    def load(self, repo, lineiter):
+        """fully loads the branchcache by reading from the file using the line
+        iterator passed"""
+        for line in lineiter:
+            line = line.rstrip(b'\n')
+            if not line:
+                continue
+            node, state, label = line.split(b" ", 2)
+            if state not in b'oc':
+                raise ValueError('invalid branch state')
+            label = encoding.tolocal(label.strip())
+            node = bin(node)
+            self._entries.setdefault(label, []).append(node)
+            if state == b'c':
+                self._closednodes.add(node)
 
-        if ntiprev > self.tiprev:
-            self.tiprev = ntiprev
-            self.tipnode = cl.node(ntiprev)
+    @staticmethod
+    def _filename(repo):
+        """name of a branchcache file for a given repo or repoview"""
+        filename = b"branch2"
+        if repo.filtername:
+            filename = b'%s-%s' % (filename, repo.filtername)
+        return filename
+
+    def copy(self, repo):
+        """return a deep copy of the branchcache object"""
+        other = type(self)(
+            repo=repo,
+            # we always do a shally copy of self._entries, and the values is
+            # always replaced, so no need to deepcopy until the above remains
+            # true.
+            entries=self._entries,
+            tipnode=self.tipnode,
+            tiprev=self.tiprev,
+            filteredhash=self.filteredhash,
+            closednodes=set(self._closednodes),
+            verify_node=self._verify_node,
+        )
+        # we copy will likely schedule a write anyway, but that does not seems
+        # to hurt to overschedule
+        other._delayed = self._delayed
+        # also copy information about the current verification state
+        other._closedverified = self._closedverified
+        other._verifiedbranches = set(self._verifiedbranches)
+        return other
+
+    def write(self, repo):
+        assert self._filtername == repo.filtername, (
+            self._filtername,
+            repo.filtername,
+        )
+        tr = repo.currenttransaction()
+        if not getattr(tr, 'finalized', True):
+            # Avoid premature writing.
+            #
+            # (The cache warming setup by localrepo will update the file later.)
+            self._delayed = True
+            return
+        try:
+            filename = self._filename(repo)
+            with repo.cachevfs(filename, b"w", atomictemp=True) as f:
+                cachekey = [hex(self.tipnode), b'%d' % self.tiprev]
+                if self.filteredhash is not None:
+                    cachekey.append(hex(self.filteredhash))
+                f.write(b" ".join(cachekey) + b'\n')
+                nodecount = 0
+                for label, nodes in sorted(self._entries.items()):
+                    label = encoding.fromlocal(label)
+                    for node in nodes:
+                        nodecount += 1
+                        if node in self._closednodes:
+                            state = b'c'
+                        else:
+                            state = b'o'
+                        f.write(b"%s %s %s\n" % (hex(node), state, label))
+            repo.ui.log(
+                b'branchcache',
+                b'wrote %s with %d labels and %d nodes\n',
+                _branchcachedesc(repo),
+                len(self._entries),
+                nodecount,
+            )
+            self._delayed = False
+        except (IOError, OSError, error.Abort) as inst:
+            # Abort may be raised by read only opener, so log and continue
+            repo.ui.debug(
+                b"couldn't write branch cache: %s\n"
+                % stringutil.forcebytestr(inst)
+            )
+
+    def _verifyclosed(self):
+        """verify the closed nodes we have"""
+        if not self._verify_node:
+            return
+        if self._closedverified:
+            return
+        assert self._hasnode is not None
+        for node in self._closednodes:
+            if not self._hasnode(node):
+                _unknownnode(node)
+
+        self._closedverified = True
+
+    def _verifybranch(self, branch):
+        """verify head nodes for the given branch."""
+        if not self._verify_node:
+            return
+        if branch not in self._entries or branch in self._verifiedbranches:
+            return
+        assert self._hasnode is not None
+        for n in self._entries[branch]:
+            if not self._hasnode(n):
+                _unknownnode(n)
+
+        self._verifiedbranches.add(branch)
+
+    def _verifyall(self):
+        """verifies nodes of all the branches"""
+        for b in self._entries.keys():
+            if b not in self._verifiedbranches:
+                self._verifybranch(b)
+
+    def __getitem__(self, key):
+        self._verifybranch(key)
+        return super().__getitem__(key)
+
+    def __contains__(self, key):
+        self._verifybranch(key)
+        return super().__contains__(key)
+
+    def iteritems(self):
+        self._verifyall()
+        return super().iteritems()
+
+    items = iteritems
+
+    def iterheads(self):
+        """returns all the heads"""
+        self._verifyall()
+        return super().iterheads()
+
+    def hasbranch(self, label):
+        """checks whether a branch of this name exists or not"""
+        self._verifybranch(label)
+        return super().hasbranch(label)
+
+    def branchheads(self, branch, closed=False):
+        self._verifybranch(branch)
+        return super().branchheads(branch, closed=closed)
+
+    def update(self, repo, revgen):
+        assert self._filtername == repo.filtername, (
+            self._filtername,
+            repo.filtername,
+        )
+        cl = repo.changelog
+        max_rev = super().update(repo, revgen)
+        # new tip revision which we found after iterating items from new
+        # branches
+        if max_rev is not None and max_rev > self.tiprev:
+            self.tiprev = max_rev
+            self.tipnode = cl.node(max_rev)
 
         if not self.validfor(repo):
             # old cache key is now invalid for the repo, but we've just updated
@@ -641,24 +694,22 @@
             repo, self.tiprev, needobsolete=True
         )
 
-        duration = util.timer() - starttime
-        repo.ui.log(
-            b'branchcache',
-            b'updated %s in %.4f seconds\n',
-            _branchcachedesc(repo),
-            duration,
-        )
-
         self.write(repo)
 
 
-class branchcache(_BaseBranchCache):
-    """Branchmap info for a local repo or repoview"""
-
-
 class remotebranchcache(_BaseBranchCache):
     """Branchmap info for a remote connection, should not write locally"""
 
+    def __init__(
+        self,
+        repo: "localrepo.localrepository",
+        entries: Union[
+            Dict[bytes, List[bytes]], Iterable[Tuple[bytes, List[bytes]]]
+        ] = (),
+        closednodes: Optional[Set[bytes]] = None,
+    ) -> None:
+        super().__init__(repo=repo, entries=entries, closed_nodes=closednodes)
+
 
 # Revision branch info cache