hgext/journal.py
changeset 29503 0103b673d6ca
parent 29502 8361131b4768
child 29504 7503d8874617
--- a/hgext/journal.py	Mon Jul 11 13:39:24 2016 +0100
+++ b/hgext/journal.py	Mon Jul 11 14:45:41 2016 +0100
@@ -14,6 +14,7 @@
 from __future__ import absolute_import
 
 import collections
+import errno
 import os
 import weakref
 
@@ -27,12 +28,15 @@
     dispatch,
     error,
     extensions,
+    hg,
     localrepo,
     lock,
     node,
     util,
 )
 
+from . import share
+
 cmdtable = {}
 command = cmdutil.command(cmdtable)
 
@@ -48,6 +52,11 @@
 # namespaces
 bookmarktype = 'bookmark'
 wdirparenttype = 'wdirparent'
+# for a shared repository, the shared feature name that indicates a journal
+# namespace is shared with the source repository
+sharednamespaces = {
+    bookmarktype: hg.sharedbookmarks,
+}
 
 # Journal recording, register hooks and storage object
 def extsetup(ui):
@@ -57,6 +66,8 @@
         dirstate.dirstate, '_writedirstate', recorddirstateparents)
     extensions.wrapfunction(
         localrepo.localrepository.dirstate, 'func', wrapdirstate)
+    extensions.wrapfunction(hg, 'postshare', wrappostshare)
+    extensions.wrapfunction(hg, 'copystore', unsharejournal)
 
 def reposetup(ui, repo):
     if repo.local():
@@ -114,6 +125,74 @@
                 repo.journal.record(bookmarktype, mark, oldvalue, value)
     return orig(store, fp)
 
+# shared repository support
+def _readsharedfeatures(repo):
+    """A set of shared features for this repository"""
+    try:
+        return set(repo.vfs.read('shared').splitlines())
+    except IOError as inst:
+        if inst.errno != errno.ENOENT:
+            raise
+        return set()
+
+def _mergeentriesiter(*iterables, **kwargs):
+    """Given a set of sorted iterables, yield the next entry in merged order
+
+    Note that by default entries go from most recent to oldest.
+    """
+    order = kwargs.pop('order', max)
+    iterables = [iter(it) for it in iterables]
+    # this tracks still active iterables; iterables are deleted as they are
+    # exhausted, which is why this is a dictionary and why each entry also
+    # stores the key. Entries are mutable so we can store the next value each
+    # time.
+    iterable_map = {}
+    for key, it in enumerate(iterables):
+        try:
+            iterable_map[key] = [next(it), key, it]
+        except StopIteration:
+            # empty iterable, can be ignored
+            pass
+
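+    # each tracked value is [next_entry, key, iterator]; comparing these
+    # compares the journalentry tuples first, which sort on their timestamp
+    # field, so the default `max` yields newest-first and `min` oldest-first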
+    while iterable_map:
+        value, key, it = order(iterable_map.itervalues())
+        yield value
+        try:
+            iterable_map[key][0] = next(it)
+        except StopIteration:
+            # this iterable is empty, remove it from consideration
+            del iterable_map[key]
+
+def wrappostshare(orig, sourcerepo, destrepo, **kwargs):
+    """Mark this shared working copy as sharing journal information"""
+    orig(sourcerepo, destrepo, **kwargs)
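+    # registering 'journal' in the destination's .hg/shared file lets
+    # _readsharedfeatures detect that journal entries for shared namespaces
+    # live in the source repository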
+    with destrepo.vfs('shared', 'a') as fp:
+        fp.write('journal\n')
+
+def unsharejournal(orig, ui, repo, repopath):
+    """Copy shared journal entries into this repo when unsharing"""
+    if (repo.path == repopath and repo.shared() and
+            util.safehasattr(repo, 'journal')):
+        sharedrepo = share._getsrcrepo(repo)
+        sharedfeatures = _readsharedfeatures(repo)
+        if sharedrepo and sharedfeatures > set(['journal']):
+            # there is a shared repository and there are shared journal entries
+            # to copy. move shared data over from source to destination but
+            # move the local file first
+            if repo.vfs.exists('journal'):
+                journalpath = repo.join('journal')
+                util.rename(journalpath, journalpath + '.bak')
+            storage = repo.journal
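+            # read the renamed local file and the source's journal
+            # oldest-first so the merged entries can be rewritten in
+            # chronological order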
+            local = storage._open(
+                repo.vfs, filename='journal.bak', _newestfirst=False)
+            shared = (
+                e for e in storage._open(sharedrepo.vfs, _newestfirst=False)
+                if sharednamespaces.get(e.namespace) in sharedfeatures)
+            for entry in _mergeentriesiter(local, shared, order=min):
+                storage._write(repo.vfs, entry)
+
+    return orig(ui, repo, repopath)
+
 class journalentry(collections.namedtuple(
         'journalentry',
         'timestamp user command namespace name oldhashes newhashes')):
@@ -157,6 +236,10 @@
 class journalstorage(object):
     """Storage for journal entries
 
+    Entries are divided between two files: one with entries that pertain to
+    the local working copy *only*, and one with entries that are shared
+    across multiple working copies via the share extension.
+
     Entries are stored with NUL bytes as separators. See the journalentry
     class for the per-entry structure.
 
@@ -175,6 +258,15 @@
         self.ui = repo.ui
         self.vfs = repo.vfs
 
+        # is this working copy using shared storage?
+        self.sharedfeatures = self.sharedvfs = None
+        if repo.shared():
+            features = _readsharedfeatures(repo)
+            sharedrepo = share._getsrcrepo(repo)
+            if sharedrepo is not None and 'journal' in features:
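+                # remember the source repository's vfs so entries for shared
+                # namespaces can be read from and written to its journal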
+                self.sharedvfs = sharedrepo.vfs
+                self.sharedfeatures = features
+
     # track the current command for recording in journal entries
     @property
     def command(self):
@@ -192,19 +284,19 @@
         # with a non-local repo (cloning for example).
         cls._currentcommand = fullargs
 
-    def jlock(self):
+    def jlock(self, vfs):
         """Create a lock for the journal file"""
         if self._lockref and self._lockref():
             raise error.Abort(_('journal lock does not support nesting'))
-        desc = _('journal of %s') % self.vfs.base
+        desc = _('journal of %s') % vfs.base
         try:
-            l = lock.lock(self.vfs, 'journal.lock', 0, desc=desc)
+            l = lock.lock(vfs, 'journal.lock', 0, desc=desc)
         except error.LockHeld as inst:
             self.ui.warn(
                 _("waiting for lock on %s held by %r\n") % (desc, inst.locker))
             # default to 600 seconds timeout
             l = lock.lock(
-                self.vfs, 'journal.lock',
+                vfs, 'journal.lock',
                 int(self.ui.config("ui", "timeout", "600")), desc=desc)
             self.ui.warn(_("got lock after %s seconds\n") % l.delay)
         self._lockref = weakref.ref(l)
@@ -231,10 +323,20 @@
             util.makedate(), self.user, self.command, namespace, name,
             oldhashes, newhashes)
 
-        with self.jlock():
+        vfs = self.vfs
+        if self.sharedvfs is not None:
+            # write to the shared repository if this feature is being
+            # shared between working copies.
+            if sharednamespaces.get(namespace) in self.sharedfeatures:
+                vfs = self.sharedvfs
+
+        self._write(vfs, entry)
+
+    def _write(self, vfs, entry):
+        with self.jlock(vfs):
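+            # the lock is taken on whichever vfs (local or shared) the
+            # entry is being written to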
             version = None
             # open file in append mode to ensure it is created if missing
-            with self.vfs('journal', mode='a+b', atomictemp=True) as f:
+            with vfs('journal', mode='a+b', atomictemp=True) as f:
                 f.seek(0, os.SEEK_SET)
                 # Read just enough bytes to get a version number (up to 2
                 # digits plus separator)
@@ -273,10 +375,23 @@
         Yields journalentry instances for each contained journal record.
 
         """
-        if not self.vfs.exists('journal'):
+        local = self._open(self.vfs)
+
+        if self.sharedvfs is None:
+            return local
+
+        # iterate over both local and shared entries, but only those
+        # shared entries that are among the currently shared features
+        shared = (
+            e for e in self._open(self.sharedvfs)
+            if sharednamespaces.get(e.namespace) in self.sharedfeatures)
+        return _mergeentriesiter(local, shared)
+
+    def _open(self, vfs, filename='journal', _newestfirst=True):
+        if not vfs.exists(filename):
             return
 
-        with self.vfs('journal') as f:
+        with vfs(filename) as f:
             raw = f.read()
 
         lines = raw.split('\0')
@@ -285,8 +400,12 @@
             version = version or _('not available')
             raise error.Abort(_("unknown journal file version '%s'") % version)
 
-        # Skip the first line, it's a version number. Reverse the rest.
-        lines = reversed(lines[1:])
+        # Skip the first line, it's a version number. Normally we iterate over
+        # these in reverse order to list newest first; only when copying entries
+        # across from shared storage do we forgo reversing.
+        lines = lines[1:]
+        if _newestfirst:
+            lines = reversed(lines)
         for line in lines:
             if not line:
                 continue