--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/hgext/remotefilelog/fileserverclient.py Thu Sep 27 13:03:19 2018 -0400
@@ -0,0 +1,648 @@
+# fileserverclient.py - client for communicating with the cache process
+#
+# Copyright 2013 Facebook, Inc.
+#
+# This software may be used and distributed according to the terms of the
+# GNU General Public License version 2 or any later version.
+
+from __future__ import absolute_import
+
+import hashlib
+import io
+import os
+import struct
+import threading
+import time
+
+from mercurial.i18n import _
+from mercurial.node import bin, hex, nullid
+from mercurial import (
+ error,
+ revlog,
+ sshpeer,
+ util,
+ wireprotov1peer,
+)
+from mercurial.utils import procutil
+
+from . import (
+ constants,
+ contentstore,
+ lz4wrapper,
+ metadatastore,
+ shallowutil,
+ wirepack,
+)
+
+_sshv1peer = sshpeer.sshv1peer
+
+# Statistics for debugging
+fetchcost = 0
+fetches = 0
+fetched = 0
+fetchmisses = 0
+
+_lfsmod = None
+_downloading = _('downloading')
+
+def getcachekey(reponame, file, id):
+ pathhash = hashlib.sha1(file).hexdigest()
+ return os.path.join(reponame, pathhash[:2], pathhash[2:], id)
+
+def getlocalkey(file, id):
+ pathhash = hashlib.sha1(file).hexdigest()
+ return os.path.join(pathhash, id)
+
+def peersetup(ui, peer):
+
+ class remotefilepeer(peer.__class__):
+ @wireprotov1peer.batchable
+ def getfile(self, file, node):
+ if not self.capable('getfile'):
+ raise error.Abort(
+ 'configured remotefile server does not support getfile')
+ f = wireprotov1peer.future()
+ yield {'file': file, 'node': node}, f
+ code, data = f.value.split('\0', 1)
+ if int(code):
+ raise error.LookupError(file, node, data)
+ yield data
+
+ @wireprotov1peer.batchable
+ def getflogheads(self, path):
+ if not self.capable('getflogheads'):
+ raise error.Abort('configured remotefile server does not '
+ 'support getflogheads')
+ f = wireprotov1peer.future()
+ yield {'path': path}, f
+ heads = f.value.split('\n') if f.value else []
+ yield heads
+
+ def _updatecallstreamopts(self, command, opts):
+ if command != 'getbundle':
+ return
+ if 'remotefilelog' not in self.capabilities():
+ return
+ if not util.safehasattr(self, '_localrepo'):
+ return
+ if constants.REQUIREMENT not in self._localrepo.requirements:
+ return
+
+ bundlecaps = opts.get('bundlecaps')
+ if bundlecaps:
+ bundlecaps = [bundlecaps]
+ else:
+ bundlecaps = []
+
+ # shallow, includepattern, and excludepattern are a hacky way of
+ # carrying over data from the local repo to this getbundle
+ # command. We need to do it this way because bundle1 getbundle
+ # doesn't provide any other place we can hook in to manipulate
+ # getbundle args before it goes across the wire. Once we get rid
+ # of bundle1, we can use bundle2's _pullbundle2extraprepare to
+ # do this more cleanly.
+ bundlecaps.append('remotefilelog')
+ if self._localrepo.includepattern:
+ patterns = '\0'.join(self._localrepo.includepattern)
+ includecap = "includepattern=" + patterns
+ bundlecaps.append(includecap)
+ if self._localrepo.excludepattern:
+ patterns = '\0'.join(self._localrepo.excludepattern)
+ excludecap = "excludepattern=" + patterns
+ bundlecaps.append(excludecap)
+ opts['bundlecaps'] = ','.join(bundlecaps)
+
+ def _sendrequest(self, command, args, **opts):
+ self._updatecallstreamopts(command, args)
+ return super(remotefilepeer, self)._sendrequest(command, args,
+ **opts)
+
+ def _callstream(self, command, **opts):
+ supertype = super(remotefilepeer, self)
+ if not util.safehasattr(supertype, '_sendrequest'):
+ self._updatecallstreamopts(command, opts)
+ return super(remotefilepeer, self)._callstream(command, **opts)
+
+ peer.__class__ = remotefilepeer
+
+class cacheconnection(object):
+ """The connection for communicating with the remote cache. Performs
+ gets and sets by communicating with an external process that has the
+ cache-specific implementation.
+ """
+ def __init__(self):
+ self.pipeo = self.pipei = self.pipee = None
+ self.subprocess = None
+ self.connected = False
+
+ def connect(self, cachecommand):
+ if self.pipeo:
+ raise error.Abort(_("cache connection already open"))
+ self.pipei, self.pipeo, self.pipee, self.subprocess = \
+ procutil.popen4(cachecommand)
+ self.connected = True
+
+ def close(self):
+ def tryclose(pipe):
+ try:
+ pipe.close()
+ except Exception:
+ pass
+ if self.connected:
+ try:
+ self.pipei.write("exit\n")
+ except Exception:
+ pass
+ tryclose(self.pipei)
+ self.pipei = None
+ tryclose(self.pipeo)
+ self.pipeo = None
+ tryclose(self.pipee)
+ self.pipee = None
+ try:
+ # Wait for process to terminate, making sure to avoid deadlock.
+ # See https://docs.python.org/2/library/subprocess.html for
+ # warnings about wait() and deadlocking.
+ self.subprocess.communicate()
+ except Exception:
+ pass
+ self.subprocess = None
+ self.connected = False
+
+ def request(self, request, flush=True):
+ if self.connected:
+ try:
+ self.pipei.write(request)
+ if flush:
+ self.pipei.flush()
+ except IOError:
+ self.close()
+
+ def receiveline(self):
+ if not self.connected:
+ return None
+ try:
+ result = self.pipeo.readline()[:-1]
+ if not result:
+ self.close()
+ except IOError:
+ self.close()
+
+ return result
+
+def _getfilesbatch(
+ remote, receivemissing, progresstick, missed, idmap, batchsize):
+ # Over http(s), iterbatch is a streamy method and we can start
+ # looking at results early. This means we send one (potentially
+ # large) request, but then we show nice progress as we process
+ # file results, rather than showing chunks of $batchsize in
+ # progress.
+ #
+ # Over ssh, iterbatch isn't streamy because batch() wasn't
+ # explicitly designed as a streaming method. In the future we
+ # should probably introduce a streambatch() method upstream and
+ # use that for this.
+ with remote.commandexecutor() as e:
+ futures = []
+ for m in missed:
+ futures.append(e.callcommand('getfile', {
+ 'file': idmap[m],
+ 'node': m[-40:]
+ }))
+
+ for i, m in enumerate(missed):
+ r = futures[i].result()
+ futures[i] = None # release memory
+ file_ = idmap[m]
+ node = m[-40:]
+ receivemissing(io.BytesIO('%d\n%s' % (len(r), r)), file_, node)
+ progresstick()
+
+def _getfiles_optimistic(
+ remote, receivemissing, progresstick, missed, idmap, step):
+ remote._callstream("getfiles")
+ i = 0
+ pipeo = remote._pipeo
+ pipei = remote._pipei
+ while i < len(missed):
+ # issue a batch of requests
+ start = i
+ end = min(len(missed), start + step)
+ i = end
+ for missingid in missed[start:end]:
+ # issue new request
+ versionid = missingid[-40:]
+ file = idmap[missingid]
+ sshrequest = "%s%s\n" % (versionid, file)
+ pipeo.write(sshrequest)
+ pipeo.flush()
+
+ # receive batch results
+ for missingid in missed[start:end]:
+ versionid = missingid[-40:]
+ file = idmap[missingid]
+ receivemissing(pipei, file, versionid)
+ progresstick()
+
+ # End the command
+ pipeo.write('\n')
+ pipeo.flush()
+
+def _getfiles_threaded(
+ remote, receivemissing, progresstick, missed, idmap, step):
+ remote._callstream("getfiles")
+ pipeo = remote._pipeo
+ pipei = remote._pipei
+
+ def writer():
+ for missingid in missed:
+ versionid = missingid[-40:]
+ file = idmap[missingid]
+ sshrequest = "%s%s\n" % (versionid, file)
+ pipeo.write(sshrequest)
+ pipeo.flush()
+ writerthread = threading.Thread(target=writer)
+ writerthread.daemon = True
+ writerthread.start()
+
+ for missingid in missed:
+ versionid = missingid[-40:]
+ file = idmap[missingid]
+ receivemissing(pipei, file, versionid)
+ progresstick()
+
+ writerthread.join()
+ # End the command
+ pipeo.write('\n')
+ pipeo.flush()
+
+class fileserverclient(object):
+ """A client for requesting files from the remote file server.
+ """
+ def __init__(self, repo):
+ ui = repo.ui
+ self.repo = repo
+ self.ui = ui
+ self.cacheprocess = ui.config("remotefilelog", "cacheprocess")
+ if self.cacheprocess:
+ self.cacheprocess = util.expandpath(self.cacheprocess)
+
+ # This option causes remotefilelog to pass the full file path to the
+ # cacheprocess instead of a hashed key.
+ self.cacheprocesspasspath = ui.configbool(
+ "remotefilelog", "cacheprocess.includepath")
+
+ self.debugoutput = ui.configbool("remotefilelog", "debug")
+
+ self.remotecache = cacheconnection()
+
+ def setstore(self, datastore, historystore, writedata, writehistory):
+ self.datastore = datastore
+ self.historystore = historystore
+ self.writedata = writedata
+ self.writehistory = writehistory
+
+ def _connect(self):
+ return self.repo.connectionpool.get(self.repo.fallbackpath)
+
+ def request(self, fileids):
+ """Takes a list of filename/node pairs and fetches them from the
+ server. Files are stored in the local cache.
+ A list of nodes that the server couldn't find is returned.
+ If the connection fails, an exception is raised.
+ """
+ if not self.remotecache.connected:
+ self.connect()
+ cache = self.remotecache
+ writedata = self.writedata
+
+ if self.ui.configbool('remotefilelog', 'fetchpacks'):
+ self.requestpack(fileids)
+ return
+
+ repo = self.repo
+ count = len(fileids)
+ request = "get\n%d\n" % count
+ idmap = {}
+ reponame = repo.name
+ for file, id in fileids:
+ fullid = getcachekey(reponame, file, id)
+ if self.cacheprocesspasspath:
+ request += file + '\0'
+ request += fullid + "\n"
+ idmap[fullid] = file
+
+ cache.request(request)
+
+ total = count
+ self.ui.progress(_downloading, 0, total=count)
+
+ missed = []
+ count = 0
+ while True:
+ missingid = cache.receiveline()
+ if not missingid:
+ missedset = set(missed)
+ for missingid in idmap.iterkeys():
+ if not missingid in missedset:
+ missed.append(missingid)
+ self.ui.warn(_("warning: cache connection closed early - " +
+ "falling back to server\n"))
+ break
+ if missingid == "0":
+ break
+ if missingid.startswith("_hits_"):
+ # receive progress reports
+ parts = missingid.split("_")
+ count += int(parts[2])
+ self.ui.progress(_downloading, count, total=total)
+ continue
+
+ missed.append(missingid)
+
+ global fetchmisses
+ fetchmisses += len(missed)
+
+ count = [total - len(missed)]
+ fromcache = count[0]
+ self.ui.progress(_downloading, count[0], total=total)
+ self.ui.log("remotefilelog", "remote cache hit rate is %r of %r\n",
+ count[0], total, hit=count[0], total=total)
+
+ oldumask = os.umask(0o002)
+ try:
+ # receive cache misses from master
+ if missed:
+ def progresstick():
+ count[0] += 1
+ self.ui.progress(_downloading, count[0], total=total)
+ # When verbose is true, sshpeer prints 'running ssh...'
+ # to stdout, which can interfere with some command
+ # outputs
+ verbose = self.ui.verbose
+ self.ui.verbose = False
+ try:
+ with self._connect() as conn:
+ remote = conn.peer
+ # TODO: deduplicate this with the constant in
+ # shallowrepo
+ if remote.capable("remotefilelog"):
+ if not isinstance(remote, _sshv1peer):
+ raise error.Abort('remotefilelog requires ssh '
+ 'servers')
+ step = self.ui.configint('remotefilelog',
+ 'getfilesstep')
+ getfilestype = self.ui.config('remotefilelog',
+ 'getfilestype')
+ if getfilestype == 'threaded':
+ _getfiles = _getfiles_threaded
+ else:
+ _getfiles = _getfiles_optimistic
+ _getfiles(remote, self.receivemissing, progresstick,
+ missed, idmap, step)
+ elif remote.capable("getfile"):
+ if remote.capable('batch'):
+ batchdefault = 100
+ else:
+ batchdefault = 10
+ batchsize = self.ui.configint(
+ 'remotefilelog', 'batchsize', batchdefault)
+ _getfilesbatch(
+ remote, self.receivemissing, progresstick,
+ missed, idmap, batchsize)
+ else:
+ raise error.Abort("configured remotefilelog server"
+ " does not support remotefilelog")
+
+ self.ui.log("remotefilefetchlog",
+ "Success\n",
+ fetched_files = count[0] - fromcache,
+ total_to_fetch = total - fromcache)
+ except Exception:
+ self.ui.log("remotefilefetchlog",
+ "Fail\n",
+ fetched_files = count[0] - fromcache,
+ total_to_fetch = total - fromcache)
+ raise
+ finally:
+ self.ui.verbose = verbose
+ # send to memcache
+ count[0] = len(missed)
+ request = "set\n%d\n%s\n" % (count[0], "\n".join(missed))
+ cache.request(request)
+
+ self.ui.progress(_downloading, None)
+
+ # mark ourselves as a user of this cache
+ writedata.markrepo(self.repo.path)
+ finally:
+ os.umask(oldumask)
+
+ def receivemissing(self, pipe, filename, node):
+ line = pipe.readline()[:-1]
+ if not line:
+ raise error.ResponseError(_("error downloading file contents:"),
+ _("connection closed early"))
+ size = int(line)
+ data = pipe.read(size)
+ if len(data) != size:
+ raise error.ResponseError(_("error downloading file contents:"),
+ _("only received %s of %s bytes")
+ % (len(data), size))
+
+ self.writedata.addremotefilelognode(filename, bin(node),
+ lz4wrapper.lz4decompress(data))
+
+ def requestpack(self, fileids):
+ """Requests the given file revisions from the server in a pack format.
+
+ See `remotefilelogserver.getpack` for the file format.
+ """
+ try:
+ with self._connect() as conn:
+ total = len(fileids)
+ rcvd = 0
+
+ remote = conn.peer
+ remote._callstream("getpackv1")
+
+ self._sendpackrequest(remote, fileids)
+
+ packpath = shallowutil.getcachepackpath(
+ self.repo, constants.FILEPACK_CATEGORY)
+ pipei = remote._pipei
+ receiveddata, receivedhistory = wirepack.receivepack(
+ self.repo.ui, pipei, packpath)
+ rcvd = len(receiveddata)
+
+ self.ui.log("remotefilefetchlog",
+ "Success(pack)\n" if (rcvd==total) else "Fail(pack)\n",
+ fetched_files = rcvd,
+ total_to_fetch = total)
+ except Exception:
+ self.ui.log("remotefilefetchlog",
+ "Fail(pack)\n",
+ fetched_files = rcvd,
+ total_to_fetch = total)
+ raise
+
+ def _sendpackrequest(self, remote, fileids):
+ """Formats and writes the given fileids to the remote as part of a
+ getpackv1 call.
+ """
+ # Sort the requests by name, so we receive requests in batches by name
+ grouped = {}
+ for filename, node in fileids:
+ grouped.setdefault(filename, set()).add(node)
+
+ # Issue request
+ pipeo = remote._pipeo
+ for filename, nodes in grouped.iteritems():
+ filenamelen = struct.pack(constants.FILENAMESTRUCT, len(filename))
+ countlen = struct.pack(constants.PACKREQUESTCOUNTSTRUCT, len(nodes))
+ rawnodes = ''.join(bin(n) for n in nodes)
+
+ pipeo.write('%s%s%s%s' % (filenamelen, filename, countlen,
+ rawnodes))
+ pipeo.flush()
+ pipeo.write(struct.pack(constants.FILENAMESTRUCT, 0))
+ pipeo.flush()
+
+ def connect(self):
+ if self.cacheprocess:
+ cmd = "%s %s" % (self.cacheprocess, self.writedata._path)
+ self.remotecache.connect(cmd)
+ else:
+ # If no cache process is specified, we fake one that always
+ # returns cache misses. This enables tests to run easily
+ # and may eventually allow us to be a drop in replacement
+ # for the largefiles extension.
+ class simplecache(object):
+ def __init__(self):
+ self.missingids = []
+ self.connected = True
+
+ def close(self):
+ pass
+
+ def request(self, value, flush=True):
+ lines = value.split("\n")
+ if lines[0] != "get":
+ return
+ self.missingids = lines[2:-1]
+ self.missingids.append('0')
+
+ def receiveline(self):
+ if len(self.missingids) > 0:
+ return self.missingids.pop(0)
+ return None
+
+ self.remotecache = simplecache()
+
+ def close(self):
+ if fetches:
+ msg = ("%s files fetched over %d fetches - " +
+ "(%d misses, %0.2f%% hit ratio) over %0.2fs\n") % (
+ fetched,
+ fetches,
+ fetchmisses,
+ float(fetched - fetchmisses) / float(fetched) * 100.0,
+ fetchcost)
+ if self.debugoutput:
+ self.ui.warn(msg)
+ self.ui.log("remotefilelog.prefetch", msg.replace("%", "%%"),
+ remotefilelogfetched=fetched,
+ remotefilelogfetches=fetches,
+ remotefilelogfetchmisses=fetchmisses,
+ remotefilelogfetchtime=fetchcost * 1000)
+
+ if self.remotecache.connected:
+ self.remotecache.close()
+
+ def prefetch(self, fileids, force=False, fetchdata=True,
+ fetchhistory=False):
+ """downloads the given file versions to the cache
+ """
+ repo = self.repo
+ idstocheck = []
+ for file, id in fileids:
+ # hack
+ # - we don't use .hgtags
+ # - workingctx produces ids with length 42,
+ # which we skip since they aren't in any cache
+ if (file == '.hgtags' or len(id) == 42
+ or not repo.shallowmatch(file)):
+ continue
+
+ idstocheck.append((file, bin(id)))
+
+ datastore = self.datastore
+ historystore = self.historystore
+ if force:
+ datastore = contentstore.unioncontentstore(*repo.shareddatastores)
+ historystore = metadatastore.unionmetadatastore(
+ *repo.sharedhistorystores)
+
+ missingids = set()
+ if fetchdata:
+ missingids.update(datastore.getmissing(idstocheck))
+ if fetchhistory:
+ missingids.update(historystore.getmissing(idstocheck))
+
+ # partition missing nodes into nullid and not-nullid so we can
+ # warn about this filtering potentially shadowing bugs.
+ nullids = len([None for unused, id in missingids if id == nullid])
+ if nullids:
+ missingids = [(f, id) for f, id in missingids if id != nullid]
+ repo.ui.develwarn(
+ ('remotefilelog not fetching %d null revs'
+ ' - this is likely hiding bugs' % nullids),
+ config='remotefilelog-ext')
+ if missingids:
+ global fetches, fetched, fetchcost
+ fetches += 1
+
+ # We want to be able to detect excess individual file downloads, so
+ # let's log that information for debugging.
+ if fetches >= 15 and fetches < 18:
+ if fetches == 15:
+ fetchwarning = self.ui.config('remotefilelog',
+ 'fetchwarning')
+ if fetchwarning:
+ self.ui.warn(fetchwarning + '\n')
+ self.logstacktrace()
+ missingids = [(file, hex(id)) for file, id in missingids]
+ fetched += len(missingids)
+ start = time.time()
+ missingids = self.request(missingids)
+ if missingids:
+ raise error.Abort(_("unable to download %d files") %
+ len(missingids))
+ fetchcost += time.time() - start
+ self._lfsprefetch(fileids)
+
+ def _lfsprefetch(self, fileids):
+ if not _lfsmod or not util.safehasattr(
+ self.repo.svfs, 'lfslocalblobstore'):
+ return
+ if not _lfsmod.wrapper.candownload(self.repo):
+ return
+ pointers = []
+ store = self.repo.svfs.lfslocalblobstore
+ for file, id in fileids:
+ node = bin(id)
+ rlog = self.repo.file(file)
+ if rlog.flags(node) & revlog.REVIDX_EXTSTORED:
+ text = rlog.revision(node, raw=True)
+ p = _lfsmod.pointer.deserialize(text)
+ oid = p.oid()
+ if not store.has(oid):
+ pointers.append(p)
+ if len(pointers) > 0:
+ self.repo.svfs.lfsremoteblobstore.readbatch(pointers, store)
+ assert all(store.has(p.oid()) for p in pointers)
+
+ def logstacktrace(self):
+ import traceback
+ self.ui.log('remotefilelog', 'excess remotefilelog fetching:\n%s\n',
+ ''.join(traceback.format_stack()))