mercurial/branchmap.py
changeset 51450 3aba79ce52a9
parent 51449 7f7086a42b2b
child 51451 fd30c4301929
--- a/mercurial/branchmap.py	Fri Jan 19 11:30:10 2024 +0100
+++ b/mercurial/branchmap.py	Mon Feb 19 11:43:19 2024 +0100
@@ -60,6 +60,10 @@
     def __getitem__(self, repo):
         self.updatecache(repo)
         bcache = self._per_filter[repo.filtername]
+        assert bcache._repo.filtername == repo.filtername, (
+            bcache._repo.filtername,
+            repo.filtername,
+        )
         return bcache
 
     def update_disk(self, repo):
@@ -76,6 +80,10 @@
         """
         self.updatecache(repo)
         bcache = self._per_filter[repo.filtername]
+        assert bcache._repo.filtername == repo.filtername, (
+            bcache._repo.filtername,
+            repo.filtername,
+        )
         bcache.write(repo)
 
     def updatecache(self, repo):
@@ -99,7 +107,7 @@
             subsetname = subsettable.get(filtername)
             if subsetname is not None:
                 subset = repo.filtered(subsetname)
-                bcache = self[subset].copy()
+                bcache = self[subset].copy(repo)
                 extrarevs = subset.changelog.filteredrevs - cl.filteredrevs
                 revs.extend(r for r in extrarevs if r <= bcache.tiprev)
             else:
@@ -148,7 +156,7 @@
             for candidate in (b'base', b'immutable', b'served'):
                 rview = repo.filtered(candidate)
                 if cache.validfor(rview):
-                    self._per_filter[candidate] = cache
+                    cache = self._per_filter[candidate] = cache.copy(rview)
                     cache.write(rview)
                     return
 
@@ -415,10 +423,10 @@
         self._verifyall()
         return self._entries.values()
 
-    def copy(self):
+    def copy(self, repo):
         """return an deep copy of the branchcache object"""
         return type(self)(
-            self._repo,
+            repo,
             self._entries,
             self.tipnode,
             self.tiprev,
@@ -427,6 +435,10 @@
         )
 
     def write(self, repo):
+        assert self._repo.filtername == repo.filtername, (
+            self._repo.filtername,
+            repo.filtername,
+        )
         tr = repo.currenttransaction()
         if not getattr(tr, 'finalized', True):
             # Avoid premature writing.
@@ -471,6 +483,10 @@
         missing heads, and a generator of nodes that are strictly a superset of
         heads missing, this function updates self to be correct.
         """
+        assert self._repo.filtername == repo.filtername, (
+            self._repo.filtername,
+            repo.filtername,
+        )
         starttime = util.timer()
         cl = repo.changelog
         # collect new branch entries