branchcache: pass the target repository when copying
authorPierre-Yves David <pierre-yves.david@octobus.net>
Mon, 19 Feb 2024 11:43:19 +0100
changeset 51450 3aba79ce52a9
parent 51449 7f7086a42b2b
child 51451 fd30c4301929
branchcache: pass the target repository when copying Branchmap are usually copied to be used on a different repoview using a different filter level. Passing the repository around means the repository in `branchcache._repo` will drift from the actual branchmap filter. This is currently "fine" because the repo is only used to retrieve the `nullid` value. However, this is a fairly big trap for any extension or future code using the `_repo` attribute. The replace logic is now using a copy to ensure the right repository view is used to initialized the cached value. We add a couple of assert for make sure this inconsistency does not sneak back.
contrib/perf.py
mercurial/branchmap.py
--- a/contrib/perf.py	Fri Jan 19 11:30:10 2024 +0100
+++ b/contrib/perf.py	Mon Feb 19 11:43:19 2024 +0100
@@ -4303,6 +4303,11 @@
         baserepo = repo.filtered(b'__perf_branchmap_update_base')
         targetrepo = repo.filtered(b'__perf_branchmap_update_target')
 
+        copy_base_kwargs = copy_base_kwargs = {}
+        if 'repo' in getargspec(repo.branchmap().copy).args:
+            copy_base_kwargs = {"repo": baserepo}
+            copy_target_kwargs = {"repo": targetrepo}
+
         # try to find an existing branchmap to reuse
         subsettable = getbranchmapsubsettable()
         candidatefilter = subsettable.get(None)
@@ -4311,7 +4316,7 @@
             if candidatebm.validfor(baserepo):
                 filtered = repoview.filterrevs(repo, candidatefilter)
                 missing = [r for r in allbaserevs if r in filtered]
-                base = candidatebm.copy()
+                base = candidatebm.copy(**copy_base_kwargs)
                 base.update(baserepo, missing)
                 break
             candidatefilter = subsettable.get(candidatefilter)
@@ -4321,7 +4326,7 @@
             base.update(baserepo, allbaserevs)
 
         def setup():
-            x[0] = base.copy()
+            x[0] = base.copy(**copy_target_kwargs)
             if clearcaches:
                 unfi._revbranchcache = None
                 clearchangelog(repo)
--- a/mercurial/branchmap.py	Fri Jan 19 11:30:10 2024 +0100
+++ b/mercurial/branchmap.py	Mon Feb 19 11:43:19 2024 +0100
@@ -60,6 +60,10 @@
     def __getitem__(self, repo):
         self.updatecache(repo)
         bcache = self._per_filter[repo.filtername]
+        assert bcache._repo.filtername == repo.filtername, (
+            bcache._repo.filtername,
+            repo.filtername,
+        )
         return bcache
 
     def update_disk(self, repo):
@@ -76,6 +80,10 @@
         """
         self.updatecache(repo)
         bcache = self._per_filter[repo.filtername]
+        assert bcache._repo.filtername == repo.filtername, (
+            bcache._repo.filtername,
+            repo.filtername,
+        )
         bcache.write(repo)
 
     def updatecache(self, repo):
@@ -99,7 +107,7 @@
             subsetname = subsettable.get(filtername)
             if subsetname is not None:
                 subset = repo.filtered(subsetname)
-                bcache = self[subset].copy()
+                bcache = self[subset].copy(repo)
                 extrarevs = subset.changelog.filteredrevs - cl.filteredrevs
                 revs.extend(r for r in extrarevs if r <= bcache.tiprev)
             else:
@@ -148,7 +156,7 @@
             for candidate in (b'base', b'immutable', b'served'):
                 rview = repo.filtered(candidate)
                 if cache.validfor(rview):
-                    self._per_filter[candidate] = cache
+                    cache = self._per_filter[candidate] = cache.copy(rview)
                     cache.write(rview)
                     return
 
@@ -415,10 +423,10 @@
         self._verifyall()
         return self._entries.values()
 
-    def copy(self):
+    def copy(self, repo):
         """return an deep copy of the branchcache object"""
         return type(self)(
-            self._repo,
+            repo,
             self._entries,
             self.tipnode,
             self.tiprev,
@@ -427,6 +435,10 @@
         )
 
     def write(self, repo):
+        assert self._repo.filtername == repo.filtername, (
+            self._repo.filtername,
+            repo.filtername,
+        )
         tr = repo.currenttransaction()
         if not getattr(tr, 'finalized', True):
             # Avoid premature writing.
@@ -471,6 +483,10 @@
         missing heads, and a generator of nodes that are strictly a superset of
         heads missing, this function updates self to be correct.
         """
+        assert self._repo.filtername == repo.filtername, (
+            self._repo.filtername,
+            repo.filtername,
+        )
         starttime = util.timer()
         cl = repo.changelog
         # collect new branch entries