mergestate: extract a base class to be shared by future memmergestate
author Martin von Zweigbergk <martinvonz@google.com>
Tue, 15 Sep 2020 11:17:24 -0700
changeset 45498 cc5f811b1f15
parent 45497 e833ff4dd0ea
child 45499 19590b126764
mergestate: extract a base class to be shared by future memmergestate

This extracts a new base class from `mergestate` and leaves all the vfs-touching code in `mergestate`.

Differential Revision: https://phab.mercurial-scm.org/D9039
mercurial/mergestate.py
--- a/mercurial/mergestate.py	Tue Sep 15 11:33:26 2020 -0700
+++ b/mercurial/mergestate.py	Tue Sep 15 11:17:24 2020 -0700
@@ -127,7 +127,7 @@
 ACTION_CREATED_MERGE = b'cm'
 
 
-class mergestate(object):
+class _mergestate_base(object):
     '''track 3-way merge state of individual files
 
     The merge state is stored on disk when needed. Two files are used: one with
@@ -172,24 +172,6 @@
     'pu' and 'pr' for path conflicts.
     '''
 
-    statepathv1 = b'merge/state'
-    statepathv2 = b'merge/state2'
-
-    @staticmethod
-    def clean(repo):
-        """Initialize a brand new merge state, removing any existing state on
-        disk."""
-        ms = mergestate(repo)
-        ms.reset()
-        return ms
-
-    @staticmethod
-    def read(repo):
-        """Initialize the merge state, reading it from disk."""
-        ms = mergestate(repo)
-        ms._read()
-        return ms
-
     def __init__(self, repo):
         """Initialize the merge state.
 
@@ -220,177 +202,6 @@
         if self.mergedriver:
             self._mdstate = MERGE_DRIVER_STATE_SUCCESS
 
-    def _read(self):
-        """Analyse each record content to restore a serialized state from disk
-
-        This function process "record" entry produced by the de-serialization
-        of on disk file.
-        """
-        self._mdstate = MERGE_DRIVER_STATE_SUCCESS
-        unsupported = set()
-        records = self._readrecords()
-        for rtype, record in records:
-            if rtype == RECORD_LOCAL:
-                self._local = bin(record)
-            elif rtype == RECORD_OTHER:
-                self._other = bin(record)
-            elif rtype == RECORD_MERGE_DRIVER_STATE:
-                bits = record.split(b'\0', 1)
-                mdstate = bits[1]
-                if len(mdstate) != 1 or mdstate not in (
-                    MERGE_DRIVER_STATE_UNMARKED,
-                    MERGE_DRIVER_STATE_MARKED,
-                    MERGE_DRIVER_STATE_SUCCESS,
-                ):
-                    # the merge driver should be idempotent, so just rerun it
-                    mdstate = MERGE_DRIVER_STATE_UNMARKED
-
-                self._readmergedriver = bits[0]
-                self._mdstate = mdstate
-            elif rtype in (
-                RECORD_MERGED,
-                RECORD_CHANGEDELETE_CONFLICT,
-                RECORD_PATH_CONFLICT,
-                RECORD_MERGE_DRIVER_MERGE,
-                LEGACY_RECORD_RESOLVED_OTHER,
-            ):
-                bits = record.split(b'\0')
-                # merge entry type MERGE_RECORD_MERGED_OTHER is deprecated
-                # and we now store related information in _stateextras, so
-                # lets write to _stateextras directly
-                if bits[1] == MERGE_RECORD_MERGED_OTHER:
-                    self._stateextras[bits[0]][b'filenode-source'] = b'other'
-                else:
-                    self._state[bits[0]] = bits[1:]
-            elif rtype == RECORD_FILE_VALUES:
-                filename, rawextras = record.split(b'\0', 1)
-                extraparts = rawextras.split(b'\0')
-                extras = {}
-                i = 0
-                while i < len(extraparts):
-                    extras[extraparts[i]] = extraparts[i + 1]
-                    i += 2
-
-                self._stateextras[filename] = extras
-            elif rtype == RECORD_LABELS:
-                labels = record.split(b'\0', 2)
-                self._labels = [l for l in labels if len(l) > 0]
-            elif not rtype.islower():
-                unsupported.add(rtype)
-
-        if unsupported:
-            raise error.UnsupportedMergeRecords(unsupported)
-
-    def _readrecords(self):
-        """Read merge state from disk and return a list of record (TYPE, data)
-
-        We read data from both v1 and v2 files and decide which one to use.
-
-        V1 has been used by version prior to 2.9.1 and contains less data than
-        v2. We read both versions and check if no data in v2 contradicts
-        v1. If there is not contradiction we can safely assume that both v1
-        and v2 were written at the same time and use the extract data in v2. If
-        there is contradiction we ignore v2 content as we assume an old version
-        of Mercurial has overwritten the mergestate file and left an old v2
-        file around.
-
-        returns list of record [(TYPE, data), ...]"""
-        v1records = self._readrecordsv1()
-        v2records = self._readrecordsv2()
-        if self._v1v2match(v1records, v2records):
-            return v2records
-        else:
-            # v1 file is newer than v2 file, use it
-            # we have to infer the "other" changeset of the merge
-            # we cannot do better than that with v1 of the format
-            mctx = self._repo[None].parents()[-1]
-            v1records.append((RECORD_OTHER, mctx.hex()))
-            # add place holder "other" file node information
-            # nobody is using it yet so we do no need to fetch the data
-            # if mctx was wrong `mctx[bits[-2]]` may fails.
-            for idx, r in enumerate(v1records):
-                if r[0] == RECORD_MERGED:
-                    bits = r[1].split(b'\0')
-                    bits.insert(-2, b'')
-                    v1records[idx] = (r[0], b'\0'.join(bits))
-            return v1records
-
-    def _v1v2match(self, v1records, v2records):
-        oldv2 = set()  # old format version of v2 record
-        for rec in v2records:
-            if rec[0] == RECORD_LOCAL:
-                oldv2.add(rec)
-            elif rec[0] == RECORD_MERGED:
-                # drop the onode data (not contained in v1)
-                oldv2.add((RECORD_MERGED, _droponode(rec[1])))
-        for rec in v1records:
-            if rec not in oldv2:
-                return False
-        else:
-            return True
-
-    def _readrecordsv1(self):
-        """read on disk merge state for version 1 file
-
-        returns list of record [(TYPE, data), ...]
-
-        Note: the "F" data from this file are one entry short
-              (no "other file node" entry)
-        """
-        records = []
-        try:
-            f = self._repo.vfs(self.statepathv1)
-            for i, l in enumerate(f):
-                if i == 0:
-                    records.append((RECORD_LOCAL, l[:-1]))
-                else:
-                    records.append((RECORD_MERGED, l[:-1]))
-            f.close()
-        except IOError as err:
-            if err.errno != errno.ENOENT:
-                raise
-        return records
-
-    def _readrecordsv2(self):
-        """read on disk merge state for version 2 file
-
-        This format is a list of arbitrary records of the form:
-
-          [type][length][content]
-
-        `type` is a single character, `length` is a 4 byte integer, and
-        `content` is an arbitrary byte sequence of length `length`.
-
-        Mercurial versions prior to 3.7 have a bug where if there are
-        unsupported mandatory merge records, attempting to clear out the merge
-        state with hg update --clean or similar aborts. The 't' record type
-        works around that by writing out what those versions treat as an
-        advisory record, but later versions interpret as special: the first
-        character is the 'real' record type and everything onwards is the data.
-
-        Returns list of records [(TYPE, data), ...]."""
-        records = []
-        try:
-            f = self._repo.vfs(self.statepathv2)
-            data = f.read()
-            off = 0
-            end = len(data)
-            while off < end:
-                rtype = data[off : off + 1]
-                off += 1
-                length = _unpack(b'>I', data[off : (off + 4)])[0]
-                off += 4
-                record = data[off : (off + length)]
-                off += length
-                if rtype == RECORD_OVERRIDE:
-                    rtype, record = record[0:1], record[1:]
-                records.append((rtype, record))
-            f.close()
-        except IOError as err:
-            if err.errno != errno.ENOENT:
-                raise
-        return records
-
     @util.propertycache
     def mergedriver(self):
         # protect against the following:
@@ -506,38 +317,6 @@
             records.append((RECORD_LABELS, labels))
         return records
 
-    def _writerecords(self, records):
-        """Write current state on disk (both v1 and v2)"""
-        self._writerecordsv1(records)
-        self._writerecordsv2(records)
-
-    def _writerecordsv1(self, records):
-        """Write current state on disk in a version 1 file"""
-        f = self._repo.vfs(self.statepathv1, b'wb')
-        irecords = iter(records)
-        lrecords = next(irecords)
-        assert lrecords[0] == RECORD_LOCAL
-        f.write(hex(self._local) + b'\n')
-        for rtype, data in irecords:
-            if rtype == RECORD_MERGED:
-                f.write(b'%s\n' % _droponode(data))
-        f.close()
-
-    def _writerecordsv2(self, records):
-        """Write current state on disk in a version 2 file
-
-        See the docstring for _readrecordsv2 for why we use 't'."""
-        # these are the records that all version 2 clients can read
-        allowlist = (RECORD_LOCAL, RECORD_OTHER, RECORD_MERGED)
-        f = self._repo.vfs(self.statepathv2, b'wb')
-        for key, data in records:
-            assert len(key) == 1
-            if key not in allowlist:
-                key, data = RECORD_OVERRIDE, b'%s%s' % (key, data)
-            format = b'>sI%is' % len(data)
-            f.write(_pack(format, key, len(data), data))
-        f.close()
-
     @staticmethod
     def getlocalkey(path):
         """hash the path of a local file context for storage in the .hg/merge
@@ -546,11 +325,10 @@
         return hex(hashutil.sha1(path).digest())
 
     def _make_backup(self, fctx, localkey):
-        self._repo.vfs.write(b'merge/' + localkey, fctx.data())
+        raise NotImplementedError()
 
     def _restore_backup(self, fctx, localkey, flags):
-        with self._repo.vfs(b'merge/' + localkey) as f:
-            fctx.write(f.read(), flags)
+        raise NotImplementedError()
 
     def add(self, fcl, fco, fca, fd):
         """add a new (potentially?) conflicting file the merge state
@@ -789,6 +567,240 @@
         self._results[f] = 0, ACTION_GET
 
 
+class mergestate(_mergestate_base):
+
+    statepathv1 = b'merge/state'
+    statepathv2 = b'merge/state2'
+
+    @staticmethod
+    def clean(repo):
+        """Initialize a brand new merge state, removing any existing state on
+        disk."""
+        ms = mergestate(repo)
+        ms.reset()
+        return ms
+
+    @staticmethod
+    def read(repo):
+        """Initialize the merge state, reading it from disk."""
+        ms = mergestate(repo)
+        ms._read()
+        return ms
+
+    def _read(self):
+        """Analyse each record content to restore a serialized state from disk
+
+        This function process "record" entry produced by the de-serialization
+        of on disk file.
+        """
+        self._mdstate = MERGE_DRIVER_STATE_SUCCESS
+        unsupported = set()
+        records = self._readrecords()
+        for rtype, record in records:
+            if rtype == RECORD_LOCAL:
+                self._local = bin(record)
+            elif rtype == RECORD_OTHER:
+                self._other = bin(record)
+            elif rtype == RECORD_MERGE_DRIVER_STATE:
+                bits = record.split(b'\0', 1)
+                mdstate = bits[1]
+                if len(mdstate) != 1 or mdstate not in (
+                    MERGE_DRIVER_STATE_UNMARKED,
+                    MERGE_DRIVER_STATE_MARKED,
+                    MERGE_DRIVER_STATE_SUCCESS,
+                ):
+                    # the merge driver should be idempotent, so just rerun it
+                    mdstate = MERGE_DRIVER_STATE_UNMARKED
+
+                self._readmergedriver = bits[0]
+                self._mdstate = mdstate
+            elif rtype in (
+                RECORD_MERGED,
+                RECORD_CHANGEDELETE_CONFLICT,
+                RECORD_PATH_CONFLICT,
+                RECORD_MERGE_DRIVER_MERGE,
+                LEGACY_RECORD_RESOLVED_OTHER,
+            ):
+                bits = record.split(b'\0')
+                # merge entry type MERGE_RECORD_MERGED_OTHER is deprecated
+                # and we now store related information in _stateextras, so
+                # lets write to _stateextras directly
+                if bits[1] == MERGE_RECORD_MERGED_OTHER:
+                    self._stateextras[bits[0]][b'filenode-source'] = b'other'
+                else:
+                    self._state[bits[0]] = bits[1:]
+            elif rtype == RECORD_FILE_VALUES:
+                filename, rawextras = record.split(b'\0', 1)
+                extraparts = rawextras.split(b'\0')
+                extras = {}
+                i = 0
+                while i < len(extraparts):
+                    extras[extraparts[i]] = extraparts[i + 1]
+                    i += 2
+
+                self._stateextras[filename] = extras
+            elif rtype == RECORD_LABELS:
+                labels = record.split(b'\0', 2)
+                self._labels = [l for l in labels if len(l) > 0]
+            elif not rtype.islower():
+                unsupported.add(rtype)
+
+        if unsupported:
+            raise error.UnsupportedMergeRecords(unsupported)
+
+    def _readrecords(self):
+        """Read merge state from disk and return a list of record (TYPE, data)
+
+        We read data from both v1 and v2 files and decide which one to use.
+
+        V1 has been used by version prior to 2.9.1 and contains less data than
+        v2. We read both versions and check if no data in v2 contradicts
+        v1. If there is not contradiction we can safely assume that both v1
+        and v2 were written at the same time and use the extract data in v2. If
+        there is contradiction we ignore v2 content as we assume an old version
+        of Mercurial has overwritten the mergestate file and left an old v2
+        file around.
+
+        returns list of record [(TYPE, data), ...]"""
+        v1records = self._readrecordsv1()
+        v2records = self._readrecordsv2()
+        if self._v1v2match(v1records, v2records):
+            return v2records
+        else:
+            # v1 file is newer than v2 file, use it
+            # we have to infer the "other" changeset of the merge
+            # we cannot do better than that with v1 of the format
+            mctx = self._repo[None].parents()[-1]
+            v1records.append((RECORD_OTHER, mctx.hex()))
+            # add place holder "other" file node information
+            # nobody is using it yet so we do no need to fetch the data
+            # if mctx was wrong `mctx[bits[-2]]` may fails.
+            for idx, r in enumerate(v1records):
+                if r[0] == RECORD_MERGED:
+                    bits = r[1].split(b'\0')
+                    bits.insert(-2, b'')
+                    v1records[idx] = (r[0], b'\0'.join(bits))
+            return v1records
+
+    def _v1v2match(self, v1records, v2records):
+        oldv2 = set()  # old format version of v2 record
+        for rec in v2records:
+            if rec[0] == RECORD_LOCAL:
+                oldv2.add(rec)
+            elif rec[0] == RECORD_MERGED:
+                # drop the onode data (not contained in v1)
+                oldv2.add((RECORD_MERGED, _droponode(rec[1])))
+        for rec in v1records:
+            if rec not in oldv2:
+                return False
+        else:
+            return True
+
+    def _readrecordsv1(self):
+        """read on disk merge state for version 1 file
+
+        returns list of record [(TYPE, data), ...]
+
+        Note: the "F" data from this file are one entry short
+              (no "other file node" entry)
+        """
+        records = []
+        try:
+            f = self._repo.vfs(self.statepathv1)
+            for i, l in enumerate(f):
+                if i == 0:
+                    records.append((RECORD_LOCAL, l[:-1]))
+                else:
+                    records.append((RECORD_MERGED, l[:-1]))
+            f.close()
+        except IOError as err:
+            if err.errno != errno.ENOENT:
+                raise
+        return records
+
+    def _readrecordsv2(self):
+        """read on disk merge state for version 2 file
+
+        This format is a list of arbitrary records of the form:
+
+          [type][length][content]
+
+        `type` is a single character, `length` is a 4 byte integer, and
+        `content` is an arbitrary byte sequence of length `length`.
+
+        Mercurial versions prior to 3.7 have a bug where if there are
+        unsupported mandatory merge records, attempting to clear out the merge
+        state with hg update --clean or similar aborts. The 't' record type
+        works around that by writing out what those versions treat as an
+        advisory record, but later versions interpret as special: the first
+        character is the 'real' record type and everything onwards is the data.
+
+        Returns list of records [(TYPE, data), ...]."""
+        records = []
+        try:
+            f = self._repo.vfs(self.statepathv2)
+            data = f.read()
+            off = 0
+            end = len(data)
+            while off < end:
+                rtype = data[off : off + 1]
+                off += 1
+                length = _unpack(b'>I', data[off : (off + 4)])[0]
+                off += 4
+                record = data[off : (off + length)]
+                off += length
+                if rtype == RECORD_OVERRIDE:
+                    rtype, record = record[0:1], record[1:]
+                records.append((rtype, record))
+            f.close()
+        except IOError as err:
+            if err.errno != errno.ENOENT:
+                raise
+        return records
+
+    def _writerecords(self, records):
+        """Write current state on disk (both v1 and v2)"""
+        self._writerecordsv1(records)
+        self._writerecordsv2(records)
+
+    def _writerecordsv1(self, records):
+        """Write current state on disk in a version 1 file"""
+        f = self._repo.vfs(self.statepathv1, b'wb')
+        irecords = iter(records)
+        lrecords = next(irecords)
+        assert lrecords[0] == RECORD_LOCAL
+        f.write(hex(self._local) + b'\n')
+        for rtype, data in irecords:
+            if rtype == RECORD_MERGED:
+                f.write(b'%s\n' % _droponode(data))
+        f.close()
+
+    def _writerecordsv2(self, records):
+        """Write current state on disk in a version 2 file
+
+        See the docstring for _readrecordsv2 for why we use 't'."""
+        # these are the records that all version 2 clients can read
+        allowlist = (RECORD_LOCAL, RECORD_OTHER, RECORD_MERGED)
+        f = self._repo.vfs(self.statepathv2, b'wb')
+        for key, data in records:
+            assert len(key) == 1
+            if key not in allowlist:
+                key, data = RECORD_OVERRIDE, b'%s%s' % (key, data)
+            format = b'>sI%is' % len(data)
+            f.write(_pack(format, key, len(data), data))
+        f.close()
+
+    def _make_backup(self, fctx, localkey):
+        self._repo.vfs.write(b'merge/' + localkey, fctx.data())
+
+    def _restore_backup(self, fctx, localkey, flags):
+        with self._repo.vfs(b'merge/' + localkey) as f:
+            fctx.write(f.read(), flags)
+
+    def reset(self):
+        shutil.rmtree(self._repo.vfs.join(b'merge'), True)
+
+
 def recordupdates(repo, actions, branchmerge, getfiledata):
     """record merge actions to the dirstate"""
     # remove (must come first)