mercurial/wireproto.py
branchstable
changeset 37788 ed5448edcbfa
parent 37287 fb92df8b634c
parent 37787 92213f6745ed
child 37789 bfd32db06952
equal deleted inserted replaced
37287:fb92df8b634c 37788:ed5448edcbfa
     1 # wireproto.py - generic wire protocol support functions
       
     2 #
       
     3 # Copyright 2005-2010 Matt Mackall <mpm@selenic.com>
       
     4 #
       
     5 # This software may be used and distributed according to the terms of the
       
     6 # GNU General Public License version 2 or any later version.
       
     7 
       
     8 from __future__ import absolute_import
       
     9 
       
    10 import hashlib
       
    11 import os
       
    12 import tempfile
       
    13 
       
    14 from .i18n import _
       
    15 from .node import (
       
    16     bin,
       
    17     hex,
       
    18     nullid,
       
    19 )
       
    20 
       
    21 from . import (
       
    22     bundle2,
       
    23     changegroup as changegroupmod,
       
    24     discovery,
       
    25     encoding,
       
    26     error,
       
    27     exchange,
       
    28     peer,
       
    29     pushkey as pushkeymod,
       
    30     pycompat,
       
    31     repository,
       
    32     streamclone,
       
    33     util,
       
    34 )
       
    35 
       
    36 urlerr = util.urlerr
       
    37 urlreq = util.urlreq
       
    38 
       
    39 bundle2requiredmain = _('incompatible Mercurial client; bundle2 required')
       
    40 bundle2requiredhint = _('see https://www.mercurial-scm.org/wiki/'
       
    41                         'IncompatibleClient')
       
    42 bundle2required = '%s\n(%s)\n' % (bundle2requiredmain, bundle2requiredhint)
       
    43 
       
    44 class abstractserverproto(object):
       
    45     """abstract class that summarizes the protocol API
       
    46 
       
    47     Used as reference and documentation.
       
    48     """
       
    49 
       
    50     def getargs(self, args):
       
    51         """return the value for arguments in <args>
       
    52 
       
    53         returns a list of values (same order as <args>)"""
       
    54         raise NotImplementedError()
       
    55 
       
    56     def getfile(self, fp):
       
    57         """write the whole content of a file into a file like object
       
    58 
       
    59         The file is in the form::
       
    60 
       
    61             (<chunk-size>\n<chunk>)+0\n
       
    62 
       
    63         chunk size is the ascii version of the int.
       
    64         """
       
    65         raise NotImplementedError()
       
    66 
       
    67     def redirect(self):
       
    68         """may setup interception for stdout and stderr
       
    69 
       
    70         See also the `restore` method."""
       
    71         raise NotImplementedError()
       
    72 
       
    73     # If the `redirect` function does install interception, the `restore`
       
    74     # function MUST be defined. If interception is not used, this function
       
    75     # MUST NOT be defined.
       
    76     #
       
    77     # left commented here on purpose
       
    78     #
       
    79     #def restore(self):
       
    80     #    """reinstall previous stdout and stderr and return intercepted stdout
       
    81     #    """
       
    82     #    raise NotImplementedError()
       
    83 
       
    84 class remoteiterbatcher(peer.iterbatcher):
       
    85     def __init__(self, remote):
       
    86         super(remoteiterbatcher, self).__init__()
       
    87         self._remote = remote
       
    88 
       
    89     def __getattr__(self, name):
       
    90         # Validate this method is batchable, since submit() only supports
       
    91         # batchable methods.
       
    92         fn = getattr(self._remote, name)
       
    93         if not getattr(fn, 'batchable', None):
       
    94             raise error.ProgrammingError('Attempted to batch a non-batchable '
       
    95                                          'call to %r' % name)
       
    96 
       
    97         return super(remoteiterbatcher, self).__getattr__(name)
       
    98 
       
    99     def submit(self):
       
   100         """Break the batch request into many patch calls and pipeline them.
       
   101 
       
   102         This is mostly valuable over http where request sizes can be
       
   103         limited, but can be used in other places as well.
       
   104         """
       
   105         # 2-tuple of (command, arguments) that represents what will be
       
   106         # sent over the wire.
       
   107         requests = []
       
   108 
       
   109         # 4-tuple of (command, final future, @batchable generator, remote
       
   110         # future).
       
   111         results = []
       
   112 
       
   113         for command, args, opts, finalfuture in self.calls:
       
   114             mtd = getattr(self._remote, command)
       
   115             batchable = mtd.batchable(mtd.__self__, *args, **opts)
       
   116 
       
   117             commandargs, fremote = next(batchable)
       
   118             assert fremote
       
   119             requests.append((command, commandargs))
       
   120             results.append((command, finalfuture, batchable, fremote))
       
   121 
       
   122         if requests:
       
   123             self._resultiter = self._remote._submitbatch(requests)
       
   124 
       
   125         self._results = results
       
   126 
       
   127     def results(self):
       
   128         for command, finalfuture, batchable, remotefuture in self._results:
       
   129             # Get the raw result, set it in the remote future, feed it
       
   130             # back into the @batchable generator so it can be decoded, and
       
   131             # set the result on the final future to this value.
       
   132             remoteresult = next(self._resultiter)
       
   133             remotefuture.set(remoteresult)
       
   134             finalfuture.set(next(batchable))
       
   135 
       
   136             # Verify our @batchable generators only emit 2 values.
       
   137             try:
       
   138                 next(batchable)
       
   139             except StopIteration:
       
   140                 pass
       
   141             else:
       
   142                 raise error.ProgrammingError('%s @batchable generator emitted '
       
   143                                              'unexpected value count' % command)
       
   144 
       
   145             yield finalfuture.value
       
   146 
       
   147 # Forward a couple of names from peer to make wireproto interactions
       
   148 # slightly more sensible.
       
   149 batchable = peer.batchable
       
   150 future = peer.future
       
   151 
       
   152 # list of nodes encoding / decoding
       
   153 
       
   154 def decodelist(l, sep=' '):
       
   155     if l:
       
   156         return [bin(v) for v in  l.split(sep)]
       
   157     return []
       
   158 
       
   159 def encodelist(l, sep=' '):
       
   160     try:
       
   161         return sep.join(map(hex, l))
       
   162     except TypeError:
       
   163         raise
       
   164 
       
   165 # batched call argument encoding
       
   166 
       
   167 def escapearg(plain):
       
   168     return (plain
       
   169             .replace(':', ':c')
       
   170             .replace(',', ':o')
       
   171             .replace(';', ':s')
       
   172             .replace('=', ':e'))
       
   173 
       
   174 def unescapearg(escaped):
       
   175     return (escaped
       
   176             .replace(':e', '=')
       
   177             .replace(':s', ';')
       
   178             .replace(':o', ',')
       
   179             .replace(':c', ':'))
       
   180 
       
   181 def encodebatchcmds(req):
       
   182     """Return a ``cmds`` argument value for the ``batch`` command."""
       
   183     cmds = []
       
   184     for op, argsdict in req:
       
   185         # Old servers didn't properly unescape argument names. So prevent
       
   186         # the sending of argument names that may not be decoded properly by
       
   187         # servers.
       
   188         assert all(escapearg(k) == k for k in argsdict)
       
   189 
       
   190         args = ','.join('%s=%s' % (escapearg(k), escapearg(v))
       
   191                         for k, v in argsdict.iteritems())
       
   192         cmds.append('%s %s' % (op, args))
       
   193 
       
   194     return ';'.join(cmds)
       
   195 
       
   196 # mapping of options accepted by getbundle and their types
       
   197 #
       
   198 # Meant to be extended by extensions. It is extensions responsibility to ensure
       
   199 # such options are properly processed in exchange.getbundle.
       
   200 #
       
   201 # supported types are:
       
   202 #
       
   203 # :nodes: list of binary nodes
       
   204 # :csv:   list of comma-separated values
       
   205 # :scsv:  list of comma-separated values return as set
       
   206 # :plain: string with no transformation needed.
       
   207 gboptsmap = {'heads':  'nodes',
       
   208              'bookmarks': 'boolean',
       
   209              'common': 'nodes',
       
   210              'obsmarkers': 'boolean',
       
   211              'phases': 'boolean',
       
   212              'bundlecaps': 'scsv',
       
   213              'listkeys': 'csv',
       
   214              'cg': 'boolean',
       
   215              'cbattempted': 'boolean',
       
   216              'stream': 'boolean',
       
   217 }
       
   218 
       
   219 # client side
       
   220 
       
   221 class wirepeer(repository.legacypeer):
       
   222     """Client-side interface for communicating with a peer repository.
       
   223 
       
   224     Methods commonly call wire protocol commands of the same name.
       
   225 
       
   226     See also httppeer.py and sshpeer.py for protocol-specific
       
   227     implementations of this interface.
       
   228     """
       
   229     # Begin of basewirepeer interface.
       
   230 
       
   231     def iterbatch(self):
       
   232         return remoteiterbatcher(self)
       
   233 
       
   234     @batchable
       
   235     def lookup(self, key):
       
   236         self.requirecap('lookup', _('look up remote revision'))
       
   237         f = future()
       
   238         yield {'key': encoding.fromlocal(key)}, f
       
   239         d = f.value
       
   240         success, data = d[:-1].split(" ", 1)
       
   241         if int(success):
       
   242             yield bin(data)
       
   243         else:
       
   244             self._abort(error.RepoError(data))
       
   245 
       
   246     @batchable
       
   247     def heads(self):
       
   248         f = future()
       
   249         yield {}, f
       
   250         d = f.value
       
   251         try:
       
   252             yield decodelist(d[:-1])
       
   253         except ValueError:
       
   254             self._abort(error.ResponseError(_("unexpected response:"), d))
       
   255 
       
   256     @batchable
       
   257     def known(self, nodes):
       
   258         f = future()
       
   259         yield {'nodes': encodelist(nodes)}, f
       
   260         d = f.value
       
   261         try:
       
   262             yield [bool(int(b)) for b in d]
       
   263         except ValueError:
       
   264             self._abort(error.ResponseError(_("unexpected response:"), d))
       
   265 
       
   266     @batchable
       
   267     def branchmap(self):
       
   268         f = future()
       
   269         yield {}, f
       
   270         d = f.value
       
   271         try:
       
   272             branchmap = {}
       
   273             for branchpart in d.splitlines():
       
   274                 branchname, branchheads = branchpart.split(' ', 1)
       
   275                 branchname = encoding.tolocal(urlreq.unquote(branchname))
       
   276                 branchheads = decodelist(branchheads)
       
   277                 branchmap[branchname] = branchheads
       
   278             yield branchmap
       
   279         except TypeError:
       
   280             self._abort(error.ResponseError(_("unexpected response:"), d))
       
   281 
       
   282     @batchable
       
   283     def listkeys(self, namespace):
       
   284         if not self.capable('pushkey'):
       
   285             yield {}, None
       
   286         f = future()
       
   287         self.ui.debug('preparing listkeys for "%s"\n' % namespace)
       
   288         yield {'namespace': encoding.fromlocal(namespace)}, f
       
   289         d = f.value
       
   290         self.ui.debug('received listkey for "%s": %i bytes\n'
       
   291                       % (namespace, len(d)))
       
   292         yield pushkeymod.decodekeys(d)
       
   293 
       
   294     @batchable
       
   295     def pushkey(self, namespace, key, old, new):
       
   296         if not self.capable('pushkey'):
       
   297             yield False, None
       
   298         f = future()
       
   299         self.ui.debug('preparing pushkey for "%s:%s"\n' % (namespace, key))
       
   300         yield {'namespace': encoding.fromlocal(namespace),
       
   301                'key': encoding.fromlocal(key),
       
   302                'old': encoding.fromlocal(old),
       
   303                'new': encoding.fromlocal(new)}, f
       
   304         d = f.value
       
   305         d, output = d.split('\n', 1)
       
   306         try:
       
   307             d = bool(int(d))
       
   308         except ValueError:
       
   309             raise error.ResponseError(
       
   310                 _('push failed (unexpected response):'), d)
       
   311         for l in output.splitlines(True):
       
   312             self.ui.status(_('remote: '), l)
       
   313         yield d
       
   314 
       
   315     def stream_out(self):
       
   316         return self._callstream('stream_out')
       
   317 
       
   318     def getbundle(self, source, **kwargs):
       
   319         kwargs = pycompat.byteskwargs(kwargs)
       
   320         self.requirecap('getbundle', _('look up remote changes'))
       
   321         opts = {}
       
   322         bundlecaps = kwargs.get('bundlecaps')
       
   323         if bundlecaps is not None:
       
   324             kwargs['bundlecaps'] = sorted(bundlecaps)
       
   325         else:
       
   326             bundlecaps = () # kwargs could have it to None
       
   327         for key, value in kwargs.iteritems():
       
   328             if value is None:
       
   329                 continue
       
   330             keytype = gboptsmap.get(key)
       
   331             if keytype is None:
       
   332                 raise error.ProgrammingError(
       
   333                     'Unexpectedly None keytype for key %s' % key)
       
   334             elif keytype == 'nodes':
       
   335                 value = encodelist(value)
       
   336             elif keytype in ('csv', 'scsv'):
       
   337                 value = ','.join(value)
       
   338             elif keytype == 'boolean':
       
   339                 value = '%i' % bool(value)
       
   340             elif keytype != 'plain':
       
   341                 raise KeyError('unknown getbundle option type %s'
       
   342                                % keytype)
       
   343             opts[key] = value
       
   344         f = self._callcompressable("getbundle", **pycompat.strkwargs(opts))
       
   345         if any((cap.startswith('HG2') for cap in bundlecaps)):
       
   346             return bundle2.getunbundler(self.ui, f)
       
   347         else:
       
   348             return changegroupmod.cg1unpacker(f, 'UN')
       
   349 
       
   350     def unbundle(self, cg, heads, url):
       
   351         '''Send cg (a readable file-like object representing the
       
   352         changegroup to push, typically a chunkbuffer object) to the
       
   353         remote server as a bundle.
       
   354 
       
   355         When pushing a bundle10 stream, return an integer indicating the
       
   356         result of the push (see changegroup.apply()).
       
   357 
       
   358         When pushing a bundle20 stream, return a bundle20 stream.
       
   359 
       
   360         `url` is the url the client thinks it's pushing to, which is
       
   361         visible to hooks.
       
   362         '''
       
   363 
       
   364         if heads != ['force'] and self.capable('unbundlehash'):
       
   365             heads = encodelist(['hashed',
       
   366                                 hashlib.sha1(''.join(sorted(heads))).digest()])
       
   367         else:
       
   368             heads = encodelist(heads)
       
   369 
       
   370         if util.safehasattr(cg, 'deltaheader'):
       
   371             # this a bundle10, do the old style call sequence
       
   372             ret, output = self._callpush("unbundle", cg, heads=heads)
       
   373             if ret == "":
       
   374                 raise error.ResponseError(
       
   375                     _('push failed:'), output)
       
   376             try:
       
   377                 ret = int(ret)
       
   378             except ValueError:
       
   379                 raise error.ResponseError(
       
   380                     _('push failed (unexpected response):'), ret)
       
   381 
       
   382             for l in output.splitlines(True):
       
   383                 self.ui.status(_('remote: '), l)
       
   384         else:
       
   385             # bundle2 push. Send a stream, fetch a stream.
       
   386             stream = self._calltwowaystream('unbundle', cg, heads=heads)
       
   387             ret = bundle2.getunbundler(self.ui, stream)
       
   388         return ret
       
   389 
       
   390     # End of basewirepeer interface.
       
   391 
       
   392     # Begin of baselegacywirepeer interface.
       
   393 
       
   394     def branches(self, nodes):
       
   395         n = encodelist(nodes)
       
   396         d = self._call("branches", nodes=n)
       
   397         try:
       
   398             br = [tuple(decodelist(b)) for b in d.splitlines()]
       
   399             return br
       
   400         except ValueError:
       
   401             self._abort(error.ResponseError(_("unexpected response:"), d))
       
   402 
       
   403     def between(self, pairs):
       
   404         batch = 8 # avoid giant requests
       
   405         r = []
       
   406         for i in xrange(0, len(pairs), batch):
       
   407             n = " ".join([encodelist(p, '-') for p in pairs[i:i + batch]])
       
   408             d = self._call("between", pairs=n)
       
   409             try:
       
   410                 r.extend(l and decodelist(l) or [] for l in d.splitlines())
       
   411             except ValueError:
       
   412                 self._abort(error.ResponseError(_("unexpected response:"), d))
       
   413         return r
       
   414 
       
   415     def changegroup(self, nodes, kind):
       
   416         n = encodelist(nodes)
       
   417         f = self._callcompressable("changegroup", roots=n)
       
   418         return changegroupmod.cg1unpacker(f, 'UN')
       
   419 
       
   420     def changegroupsubset(self, bases, heads, kind):
       
   421         self.requirecap('changegroupsubset', _('look up remote changes'))
       
   422         bases = encodelist(bases)
       
   423         heads = encodelist(heads)
       
   424         f = self._callcompressable("changegroupsubset",
       
   425                                    bases=bases, heads=heads)
       
   426         return changegroupmod.cg1unpacker(f, 'UN')
       
   427 
       
   428     # End of baselegacywirepeer interface.
       
   429 
       
   430     def _submitbatch(self, req):
       
   431         """run batch request <req> on the server
       
   432 
       
   433         Returns an iterator of the raw responses from the server.
       
   434         """
       
   435         rsp = self._callstream("batch", cmds=encodebatchcmds(req))
       
   436         chunk = rsp.read(1024)
       
   437         work = [chunk]
       
   438         while chunk:
       
   439             while ';' not in chunk and chunk:
       
   440                 chunk = rsp.read(1024)
       
   441                 work.append(chunk)
       
   442             merged = ''.join(work)
       
   443             while ';' in merged:
       
   444                 one, merged = merged.split(';', 1)
       
   445                 yield unescapearg(one)
       
   446             chunk = rsp.read(1024)
       
   447             work = [merged, chunk]
       
   448         yield unescapearg(''.join(work))
       
   449 
       
   450     def _submitone(self, op, args):
       
   451         return self._call(op, **pycompat.strkwargs(args))
       
   452 
       
   453     def debugwireargs(self, one, two, three=None, four=None, five=None):
       
   454         # don't pass optional arguments left at their default value
       
   455         opts = {}
       
   456         if three is not None:
       
   457             opts[r'three'] = three
       
   458         if four is not None:
       
   459             opts[r'four'] = four
       
   460         return self._call('debugwireargs', one=one, two=two, **opts)
       
   461 
       
   462     def _call(self, cmd, **args):
       
   463         """execute <cmd> on the server
       
   464 
       
   465         The command is expected to return a simple string.
       
   466 
       
   467         returns the server reply as a string."""
       
   468         raise NotImplementedError()
       
   469 
       
   470     def _callstream(self, cmd, **args):
       
   471         """execute <cmd> on the server
       
   472 
       
   473         The command is expected to return a stream. Note that if the
       
   474         command doesn't return a stream, _callstream behaves
       
   475         differently for ssh and http peers.
       
   476 
       
   477         returns the server reply as a file like object.
       
   478         """
       
   479         raise NotImplementedError()
       
   480 
       
   481     def _callcompressable(self, cmd, **args):
       
   482         """execute <cmd> on the server
       
   483 
       
   484         The command is expected to return a stream.
       
   485 
       
   486         The stream may have been compressed in some implementations. This
       
   487         function takes care of the decompression. This is the only difference
       
   488         with _callstream.
       
   489 
       
   490         returns the server reply as a file like object.
       
   491         """
       
   492         raise NotImplementedError()
       
   493 
       
   494     def _callpush(self, cmd, fp, **args):
       
   495         """execute a <cmd> on server
       
   496 
       
   497         The command is expected to be related to a push. Push has a special
       
   498         return method.
       
   499 
       
   500         returns the server reply as a (ret, output) tuple. ret is either
       
   501         empty (error) or a stringified int.
       
   502         """
       
   503         raise NotImplementedError()
       
   504 
       
   505     def _calltwowaystream(self, cmd, fp, **args):
       
   506         """execute <cmd> on server
       
   507 
       
   508         The command will send a stream to the server and get a stream in reply.
       
   509         """
       
   510         raise NotImplementedError()
       
   511 
       
   512     def _abort(self, exception):
       
   513         """clearly abort the wire protocol connection and raise the exception
       
   514         """
       
   515         raise NotImplementedError()
       
   516 
       
   517 # server side
       
   518 
       
   519 # wire protocol command can either return a string or one of these classes.
       
   520 class streamres(object):
       
   521     """wireproto reply: binary stream
       
   522 
       
   523     The call was successful and the result is a stream.
       
   524 
       
   525     Accepts a generator containing chunks of data to be sent to the client.
       
   526 
       
   527     ``prefer_uncompressed`` indicates that the data is expected to be
       
   528     uncompressable and that the stream should therefore use the ``none``
       
   529     engine.
       
   530     """
       
   531     def __init__(self, gen=None, prefer_uncompressed=False):
       
   532         self.gen = gen
       
   533         self.prefer_uncompressed = prefer_uncompressed
       
   534 
       
   535 class streamres_legacy(object):
       
   536     """wireproto reply: uncompressed binary stream
       
   537 
       
   538     The call was successful and the result is a stream.
       
   539 
       
   540     Accepts a generator containing chunks of data to be sent to the client.
       
   541 
       
   542     Like ``streamres``, but sends an uncompressed data for "version 1" clients
       
   543     using the application/mercurial-0.1 media type.
       
   544     """
       
   545     def __init__(self, gen=None):
       
   546         self.gen = gen
       
   547 
       
   548 class pushres(object):
       
   549     """wireproto reply: success with simple integer return
       
   550 
       
   551     The call was successful and returned an integer contained in `self.res`.
       
   552     """
       
   553     def __init__(self, res):
       
   554         self.res = res
       
   555 
       
   556 class pusherr(object):
       
   557     """wireproto reply: failure
       
   558 
       
   559     The call failed. The `self.res` attribute contains the error message.
       
   560     """
       
   561     def __init__(self, res):
       
   562         self.res = res
       
   563 
       
   564 class ooberror(object):
       
   565     """wireproto reply: failure of a batch of operation
       
   566 
       
   567     Something failed during a batch call. The error message is stored in
       
   568     `self.message`.
       
   569     """
       
   570     def __init__(self, message):
       
   571         self.message = message
       
   572 
       
   573 def getdispatchrepo(repo, proto, command):
       
   574     """Obtain the repo used for processing wire protocol commands.
       
   575 
       
   576     The intent of this function is to serve as a monkeypatch point for
       
   577     extensions that need commands to operate on different repo views under
       
   578     specialized circumstances.
       
   579     """
       
   580     return repo.filtered('served')
       
   581 
       
   582 def dispatch(repo, proto, command):
       
   583     repo = getdispatchrepo(repo, proto, command)
       
   584     func, spec = commands[command]
       
   585     args = proto.getargs(spec)
       
   586     return func(repo, proto, *args)
       
   587 
       
   588 def options(cmd, keys, others):
       
   589     opts = {}
       
   590     for k in keys:
       
   591         if k in others:
       
   592             opts[k] = others[k]
       
   593             del others[k]
       
   594     if others:
       
   595         util.stderr.write("warning: %s ignored unexpected arguments %s\n"
       
   596                           % (cmd, ",".join(others)))
       
   597     return opts
       
   598 
       
   599 def bundle1allowed(repo, action):
       
   600     """Whether a bundle1 operation is allowed from the server.
       
   601 
       
   602     Priority is:
       
   603 
       
   604     1. server.bundle1gd.<action> (if generaldelta active)
       
   605     2. server.bundle1.<action>
       
   606     3. server.bundle1gd (if generaldelta active)
       
   607     4. server.bundle1
       
   608     """
       
   609     ui = repo.ui
       
   610     gd = 'generaldelta' in repo.requirements
       
   611 
       
   612     if gd:
       
   613         v = ui.configbool('server', 'bundle1gd.%s' % action)
       
   614         if v is not None:
       
   615             return v
       
   616 
       
   617     v = ui.configbool('server', 'bundle1.%s' % action)
       
   618     if v is not None:
       
   619         return v
       
   620 
       
   621     if gd:
       
   622         v = ui.configbool('server', 'bundle1gd')
       
   623         if v is not None:
       
   624             return v
       
   625 
       
   626     return ui.configbool('server', 'bundle1')
       
   627 
       
   628 def supportedcompengines(ui, proto, role):
       
   629     """Obtain the list of supported compression engines for a request."""
       
   630     assert role in (util.CLIENTROLE, util.SERVERROLE)
       
   631 
       
   632     compengines = util.compengines.supportedwireengines(role)
       
   633 
       
   634     # Allow config to override default list and ordering.
       
   635     if role == util.SERVERROLE:
       
   636         configengines = ui.configlist('server', 'compressionengines')
       
   637         config = 'server.compressionengines'
       
   638     else:
       
   639         # This is currently implemented mainly to facilitate testing. In most
       
   640         # cases, the server should be in charge of choosing a compression engine
       
   641         # because a server has the most to lose from a sub-optimal choice. (e.g.
       
   642         # CPU DoS due to an expensive engine or a network DoS due to poor
       
   643         # compression ratio).
       
   644         configengines = ui.configlist('experimental',
       
   645                                       'clientcompressionengines')
       
   646         config = 'experimental.clientcompressionengines'
       
   647 
       
   648     # No explicit config. Filter out the ones that aren't supposed to be
       
   649     # advertised and return default ordering.
       
   650     if not configengines:
       
   651         attr = 'serverpriority' if role == util.SERVERROLE else 'clientpriority'
       
   652         return [e for e in compengines
       
   653                 if getattr(e.wireprotosupport(), attr) > 0]
       
   654 
       
   655     # If compression engines are listed in the config, assume there is a good
       
   656     # reason for it (like server operators wanting to achieve specific
       
   657     # performance characteristics). So fail fast if the config references
       
   658     # unusable compression engines.
       
   659     validnames = set(e.name() for e in compengines)
       
   660     invalidnames = set(e for e in configengines if e not in validnames)
       
   661     if invalidnames:
       
   662         raise error.Abort(_('invalid compression engine defined in %s: %s') %
       
   663                           (config, ', '.join(sorted(invalidnames))))
       
   664 
       
   665     compengines = [e for e in compengines if e.name() in configengines]
       
   666     compengines = sorted(compengines,
       
   667                          key=lambda e: configengines.index(e.name()))
       
   668 
       
   669     if not compengines:
       
   670         raise error.Abort(_('%s config option does not specify any known '
       
   671                             'compression engines') % config,
       
   672                           hint=_('usable compression engines: %s') %
       
   673                           ', '.sorted(validnames))
       
   674 
       
   675     return compengines
       
   676 
       
   677 # list of commands
       
   678 commands = {}
       
   679 
       
   680 # Maps wire protocol name to operation type. This is used for permissions
       
   681 # checking. All defined @wireiprotocommand should have an entry in this
       
   682 # dict.
       
   683 permissions = {}
       
   684 
       
   685 def wireprotocommand(name, args=''):
       
   686     """decorator for wire protocol command"""
       
   687     def register(func):
       
   688         commands[name] = (func, args)
       
   689         return func
       
   690     return register
       
   691 
       
   692 # TODO define a more appropriate permissions type to use for this.
       
   693 permissions['batch'] = 'pull'
       
   694 @wireprotocommand('batch', 'cmds *')
       
   695 def batch(repo, proto, cmds, others):
       
   696     repo = repo.filtered("served")
       
   697     res = []
       
   698     for pair in cmds.split(';'):
       
   699         op, args = pair.split(' ', 1)
       
   700         vals = {}
       
   701         for a in args.split(','):
       
   702             if a:
       
   703                 n, v = a.split('=')
       
   704                 vals[unescapearg(n)] = unescapearg(v)
       
   705         func, spec = commands[op]
       
   706 
       
   707         # If the protocol supports permissions checking, perform that
       
   708         # checking on each batched command.
       
   709         # TODO formalize permission checking as part of protocol interface.
       
   710         if util.safehasattr(proto, 'checkperm'):
       
   711             # Assume commands with no defined permissions are writes / for
       
   712             # pushes. This is the safest from a security perspective because
       
   713             # it doesn't allow commands with undefined semantics from
       
   714             # bypassing permissions checks.
       
   715             proto.checkperm(permissions.get(op, 'push'))
       
   716 
       
   717         if spec:
       
   718             keys = spec.split()
       
   719             data = {}
       
   720             for k in keys:
       
   721                 if k == '*':
       
   722                     star = {}
       
   723                     for key in vals.keys():
       
   724                         if key not in keys:
       
   725                             star[key] = vals[key]
       
   726                     data['*'] = star
       
   727                 else:
       
   728                     data[k] = vals[k]
       
   729             result = func(repo, proto, *[data[k] for k in keys])
       
   730         else:
       
   731             result = func(repo, proto)
       
   732         if isinstance(result, ooberror):
       
   733             return result
       
   734         res.append(escapearg(result))
       
   735     return ';'.join(res)
       
   736 
       
   737 permissions['between'] = 'pull'
       
   738 @wireprotocommand('between', 'pairs')
       
   739 def between(repo, proto, pairs):
       
   740     pairs = [decodelist(p, '-') for p in pairs.split(" ")]
       
   741     r = []
       
   742     for b in repo.between(pairs):
       
   743         r.append(encodelist(b) + "\n")
       
   744     return "".join(r)
       
   745 
       
   746 permissions['branchmap'] = 'pull'
       
   747 @wireprotocommand('branchmap')
       
   748 def branchmap(repo, proto):
       
   749     branchmap = repo.branchmap()
       
   750     heads = []
       
   751     for branch, nodes in branchmap.iteritems():
       
   752         branchname = urlreq.quote(encoding.fromlocal(branch))
       
   753         branchnodes = encodelist(nodes)
       
   754         heads.append('%s %s' % (branchname, branchnodes))
       
   755     return '\n'.join(heads)
       
   756 
       
   757 permissions['branches'] = 'pull'
       
   758 @wireprotocommand('branches', 'nodes')
       
   759 def branches(repo, proto, nodes):
       
   760     nodes = decodelist(nodes)
       
   761     r = []
       
   762     for b in repo.branches(nodes):
       
   763         r.append(encodelist(b) + "\n")
       
   764     return "".join(r)
       
   765 
       
   766 permissions['clonebundles'] = 'pull'
       
   767 @wireprotocommand('clonebundles', '')
       
   768 def clonebundles(repo, proto):
       
   769     """Server command for returning info for available bundles to seed clones.
       
   770 
       
   771     Clients will parse this response and determine what bundle to fetch.
       
   772 
       
   773     Extensions may wrap this command to filter or dynamically emit data
       
   774     depending on the request. e.g. you could advertise URLs for the closest
       
   775     data center given the client's IP address.
       
   776     """
       
   777     return repo.vfs.tryread('clonebundles.manifest')
       
   778 
       
   779 wireprotocaps = ['lookup', 'changegroupsubset', 'branchmap', 'pushkey',
       
   780                  'known', 'getbundle', 'unbundlehash', 'batch']
       
   781 
       
   782 def _capabilities(repo, proto):
       
   783     """return a list of capabilities for a repo
       
   784 
       
   785     This function exists to allow extensions to easily wrap capabilities
       
   786     computation
       
   787 
       
   788     - returns a lists: easy to alter
       
   789     - change done here will be propagated to both `capabilities` and `hello`
       
   790       command without any other action needed.
       
   791     """
       
   792     # copy to prevent modification of the global list
       
   793     caps = list(wireprotocaps)
       
   794     if streamclone.allowservergeneration(repo):
       
   795         if repo.ui.configbool('server', 'preferuncompressed'):
       
   796             caps.append('stream-preferred')
       
   797         requiredformats = repo.requirements & repo.supportedformats
       
   798         # if our local revlogs are just revlogv1, add 'stream' cap
       
   799         if not requiredformats - {'revlogv1'}:
       
   800             caps.append('stream')
       
   801         # otherwise, add 'streamreqs' detailing our local revlog format
       
   802         else:
       
   803             caps.append('streamreqs=%s' % ','.join(sorted(requiredformats)))
       
   804     if repo.ui.configbool('experimental', 'bundle2-advertise'):
       
   805         capsblob = bundle2.encodecaps(bundle2.getrepocaps(repo, role='server'))
       
   806         caps.append('bundle2=' + urlreq.quote(capsblob))
       
   807     caps.append('unbundle=%s' % ','.join(bundle2.bundlepriority))
       
   808 
       
   809     if proto.name == 'http':
       
   810         caps.append('httpheader=%d' %
       
   811                     repo.ui.configint('server', 'maxhttpheaderlen'))
       
   812         if repo.ui.configbool('experimental', 'httppostargs'):
       
   813             caps.append('httppostargs')
       
   814 
       
   815         # FUTURE advertise 0.2rx once support is implemented
       
   816         # FUTURE advertise minrx and mintx after consulting config option
       
   817         caps.append('httpmediatype=0.1rx,0.1tx,0.2tx')
       
   818 
       
   819         compengines = supportedcompengines(repo.ui, proto, util.SERVERROLE)
       
   820         if compengines:
       
   821             comptypes = ','.join(urlreq.quote(e.wireprotosupport().name)
       
   822                                  for e in compengines)
       
   823             caps.append('compression=%s' % comptypes)
       
   824 
       
   825     return caps
       
   826 
       
   827 # If you are writing an extension and consider wrapping this function. Wrap
       
   828 # `_capabilities` instead.
       
   829 permissions['capabilities'] = 'pull'
       
   830 @wireprotocommand('capabilities')
       
   831 def capabilities(repo, proto):
       
   832     return ' '.join(_capabilities(repo, proto))
       
   833 
       
   834 permissions['changegroup'] = 'pull'
       
   835 @wireprotocommand('changegroup', 'roots')
       
   836 def changegroup(repo, proto, roots):
       
   837     nodes = decodelist(roots)
       
   838     outgoing = discovery.outgoing(repo, missingroots=nodes,
       
   839                                   missingheads=repo.heads())
       
   840     cg = changegroupmod.makechangegroup(repo, outgoing, '01', 'serve')
       
   841     gen = iter(lambda: cg.read(32768), '')
       
   842     return streamres(gen=gen)
       
   843 
       
   844 permissions['changegroupsubset'] = 'pull'
       
   845 @wireprotocommand('changegroupsubset', 'bases heads')
       
   846 def changegroupsubset(repo, proto, bases, heads):
       
   847     bases = decodelist(bases)
       
   848     heads = decodelist(heads)
       
   849     outgoing = discovery.outgoing(repo, missingroots=bases,
       
   850                                   missingheads=heads)
       
   851     cg = changegroupmod.makechangegroup(repo, outgoing, '01', 'serve')
       
   852     gen = iter(lambda: cg.read(32768), '')
       
   853     return streamres(gen=gen)
       
   854 
       
   855 permissions['debugwireargs'] = 'pull'
       
   856 @wireprotocommand('debugwireargs', 'one two *')
       
   857 def debugwireargs(repo, proto, one, two, others):
       
   858     # only accept optional args from the known set
       
   859     opts = options('debugwireargs', ['three', 'four'], others)
       
   860     return repo.debugwireargs(one, two, **pycompat.strkwargs(opts))
       
   861 
       
   862 permissions['getbundle'] = 'pull'
       
   863 @wireprotocommand('getbundle', '*')
       
   864 def getbundle(repo, proto, others):
       
   865     opts = options('getbundle', gboptsmap.keys(), others)
       
   866     for k, v in opts.iteritems():
       
   867         keytype = gboptsmap[k]
       
   868         if keytype == 'nodes':
       
   869             opts[k] = decodelist(v)
       
   870         elif keytype == 'csv':
       
   871             opts[k] = list(v.split(','))
       
   872         elif keytype == 'scsv':
       
   873             opts[k] = set(v.split(','))
       
   874         elif keytype == 'boolean':
       
   875             # Client should serialize False as '0', which is a non-empty string
       
   876             # so it evaluates as a True bool.
       
   877             if v == '0':
       
   878                 opts[k] = False
       
   879             else:
       
   880                 opts[k] = bool(v)
       
   881         elif keytype != 'plain':
       
   882             raise KeyError('unknown getbundle option type %s'
       
   883                            % keytype)
       
   884 
       
   885     if not bundle1allowed(repo, 'pull'):
       
   886         if not exchange.bundle2requested(opts.get('bundlecaps')):
       
   887             if proto.name == 'http':
       
   888                 return ooberror(bundle2required)
       
   889             raise error.Abort(bundle2requiredmain,
       
   890                               hint=bundle2requiredhint)
       
   891 
       
   892     prefercompressed = True
       
   893 
       
   894     try:
       
   895         if repo.ui.configbool('server', 'disablefullbundle'):
       
   896             # Check to see if this is a full clone.
       
   897             clheads = set(repo.changelog.heads())
       
   898             changegroup = opts.get('cg', True)
       
   899             heads = set(opts.get('heads', set()))
       
   900             common = set(opts.get('common', set()))
       
   901             common.discard(nullid)
       
   902             if changegroup and not common and clheads == heads:
       
   903                 raise error.Abort(
       
   904                     _('server has pull-based clones disabled'),
       
   905                     hint=_('remove --pull if specified or upgrade Mercurial'))
       
   906 
       
   907         info, chunks = exchange.getbundlechunks(repo, 'serve',
       
   908                                                 **pycompat.strkwargs(opts))
       
   909         prefercompressed = info.get('prefercompressed', True)
       
   910     except error.Abort as exc:
       
   911         # cleanly forward Abort error to the client
       
   912         if not exchange.bundle2requested(opts.get('bundlecaps')):
       
   913             if proto.name == 'http':
       
   914                 return ooberror(str(exc) + '\n')
       
   915             raise # cannot do better for bundle1 + ssh
       
   916         # bundle2 request expect a bundle2 reply
       
   917         bundler = bundle2.bundle20(repo.ui)
       
   918         manargs = [('message', str(exc))]
       
   919         advargs = []
       
   920         if exc.hint is not None:
       
   921             advargs.append(('hint', exc.hint))
       
   922         bundler.addpart(bundle2.bundlepart('error:abort',
       
   923                                            manargs, advargs))
       
   924         chunks = bundler.getchunks()
       
   925         prefercompressed = False
       
   926 
       
   927     return streamres(gen=chunks, prefer_uncompressed=not prefercompressed)
       
   928 
       
   929 permissions['heads'] = 'pull'
       
   930 @wireprotocommand('heads')
       
   931 def heads(repo, proto):
       
   932     h = repo.heads()
       
   933     return encodelist(h) + "\n"
       
   934 
       
   935 permissions['hello'] = 'pull'
       
   936 @wireprotocommand('hello')
       
   937 def hello(repo, proto):
       
   938     '''the hello command returns a set of lines describing various
       
   939     interesting things about the server, in an RFC822-like format.
       
   940     Currently the only one defined is "capabilities", which
       
   941     consists of a line in the form:
       
   942 
       
   943     capabilities: space separated list of tokens
       
   944     '''
       
   945     return "capabilities: %s\n" % (capabilities(repo, proto))
       
   946 
       
   947 permissions['listkeys'] = 'pull'
       
   948 @wireprotocommand('listkeys', 'namespace')
       
   949 def listkeys(repo, proto, namespace):
       
   950     d = repo.listkeys(encoding.tolocal(namespace)).items()
       
   951     return pushkeymod.encodekeys(d)
       
   952 
       
   953 permissions['lookup'] = 'pull'
       
   954 @wireprotocommand('lookup', 'key')
       
   955 def lookup(repo, proto, key):
       
   956     try:
       
   957         k = encoding.tolocal(key)
       
   958         c = repo[k]
       
   959         r = c.hex()
       
   960         success = 1
       
   961     except Exception as inst:
       
   962         r = str(inst)
       
   963         success = 0
       
   964     return "%d %s\n" % (success, r)
       
   965 
       
   966 permissions['known'] = 'pull'
       
   967 @wireprotocommand('known', 'nodes *')
       
   968 def known(repo, proto, nodes, others):
       
   969     return ''.join(b and "1" or "0" for b in repo.known(decodelist(nodes)))
       
   970 
       
   971 permissions['pushkey'] = 'push'
       
   972 @wireprotocommand('pushkey', 'namespace key old new')
       
   973 def pushkey(repo, proto, namespace, key, old, new):
       
   974     # compatibility with pre-1.8 clients which were accidentally
       
   975     # sending raw binary nodes rather than utf-8-encoded hex
       
   976     if len(new) == 20 and util.escapestr(new) != new:
       
   977         # looks like it could be a binary node
       
   978         try:
       
   979             new.decode('utf-8')
       
   980             new = encoding.tolocal(new) # but cleanly decodes as UTF-8
       
   981         except UnicodeDecodeError:
       
   982             pass # binary, leave unmodified
       
   983     else:
       
   984         new = encoding.tolocal(new) # normal path
       
   985 
       
   986     if util.safehasattr(proto, 'restore'):
       
   987 
       
   988         proto.redirect()
       
   989 
       
   990         try:
       
   991             r = repo.pushkey(encoding.tolocal(namespace), encoding.tolocal(key),
       
   992                              encoding.tolocal(old), new) or False
       
   993         except error.Abort:
       
   994             r = False
       
   995 
       
   996         output = proto.restore()
       
   997 
       
   998         return '%s\n%s' % (int(r), output)
       
   999 
       
  1000     r = repo.pushkey(encoding.tolocal(namespace), encoding.tolocal(key),
       
  1001                      encoding.tolocal(old), new)
       
  1002     return '%s\n' % int(r)
       
  1003 
       
  1004 permissions['stream_out'] = 'pull'
       
  1005 @wireprotocommand('stream_out')
       
  1006 def stream(repo, proto):
       
  1007     '''If the server supports streaming clone, it advertises the "stream"
       
  1008     capability with a value representing the version and flags of the repo
       
  1009     it is serving. Client checks to see if it understands the format.
       
  1010     '''
       
  1011     return streamres_legacy(streamclone.generatev1wireproto(repo))
       
  1012 
       
  1013 permissions['unbundle'] = 'push'
       
  1014 @wireprotocommand('unbundle', 'heads')
       
  1015 def unbundle(repo, proto, heads):
       
  1016     their_heads = decodelist(heads)
       
  1017 
       
  1018     try:
       
  1019         proto.redirect()
       
  1020 
       
  1021         exchange.check_heads(repo, their_heads, 'preparing changes')
       
  1022 
       
  1023         # write bundle data to temporary file because it can be big
       
  1024         fd, tempname = tempfile.mkstemp(prefix='hg-unbundle-')
       
  1025         fp = os.fdopen(fd, pycompat.sysstr('wb+'))
       
  1026         r = 0
       
  1027         try:
       
  1028             proto.getfile(fp)
       
  1029             fp.seek(0)
       
  1030             gen = exchange.readbundle(repo.ui, fp, None)
       
  1031             if (isinstance(gen, changegroupmod.cg1unpacker)
       
  1032                 and not bundle1allowed(repo, 'push')):
       
  1033                 if proto.name == 'http':
       
  1034                     # need to special case http because stderr do not get to
       
  1035                     # the http client on failed push so we need to abuse some
       
  1036                     # other error type to make sure the message get to the
       
  1037                     # user.
       
  1038                     return ooberror(bundle2required)
       
  1039                 raise error.Abort(bundle2requiredmain,
       
  1040                                   hint=bundle2requiredhint)
       
  1041 
       
  1042             r = exchange.unbundle(repo, gen, their_heads, 'serve',
       
  1043                                   proto._client())
       
  1044             if util.safehasattr(r, 'addpart'):
       
  1045                 # The return looks streamable, we are in the bundle2 case and
       
  1046                 # should return a stream.
       
  1047                 return streamres_legacy(gen=r.getchunks())
       
  1048             return pushres(r)
       
  1049 
       
  1050         finally:
       
  1051             fp.close()
       
  1052             os.unlink(tempname)
       
  1053 
       
  1054     except (error.BundleValueError, error.Abort, error.PushRaced) as exc:
       
  1055         # handle non-bundle2 case first
       
  1056         if not getattr(exc, 'duringunbundle2', False):
       
  1057             try:
       
  1058                 raise
       
  1059             except error.Abort:
       
  1060                 # The old code we moved used util.stderr directly.
       
  1061                 # We did not change it to minimise code change.
       
  1062                 # This need to be moved to something proper.
       
  1063                 # Feel free to do it.
       
  1064                 util.stderr.write("abort: %s\n" % exc)
       
  1065                 if exc.hint is not None:
       
  1066                     util.stderr.write("(%s)\n" % exc.hint)
       
  1067                 return pushres(0)
       
  1068             except error.PushRaced:
       
  1069                 return pusherr(str(exc))
       
  1070 
       
  1071         bundler = bundle2.bundle20(repo.ui)
       
  1072         for out in getattr(exc, '_bundle2salvagedoutput', ()):
       
  1073             bundler.addpart(out)
       
  1074         try:
       
  1075             try:
       
  1076                 raise
       
  1077             except error.PushkeyFailed as exc:
       
  1078                 # check client caps
       
  1079                 remotecaps = getattr(exc, '_replycaps', None)
       
  1080                 if (remotecaps is not None
       
  1081                         and 'pushkey' not in remotecaps.get('error', ())):
       
  1082                     # no support remote side, fallback to Abort handler.
       
  1083                     raise
       
  1084                 part = bundler.newpart('error:pushkey')
       
  1085                 part.addparam('in-reply-to', exc.partid)
       
  1086                 if exc.namespace is not None:
       
  1087                     part.addparam('namespace', exc.namespace, mandatory=False)
       
  1088                 if exc.key is not None:
       
  1089                     part.addparam('key', exc.key, mandatory=False)
       
  1090                 if exc.new is not None:
       
  1091                     part.addparam('new', exc.new, mandatory=False)
       
  1092                 if exc.old is not None:
       
  1093                     part.addparam('old', exc.old, mandatory=False)
       
  1094                 if exc.ret is not None:
       
  1095                     part.addparam('ret', exc.ret, mandatory=False)
       
  1096         except error.BundleValueError as exc:
       
  1097             errpart = bundler.newpart('error:unsupportedcontent')
       
  1098             if exc.parttype is not None:
       
  1099                 errpart.addparam('parttype', exc.parttype)
       
  1100             if exc.params:
       
  1101                 errpart.addparam('params', '\0'.join(exc.params))
       
  1102         except error.Abort as exc:
       
  1103             manargs = [('message', str(exc))]
       
  1104             advargs = []
       
  1105             if exc.hint is not None:
       
  1106                 advargs.append(('hint', exc.hint))
       
  1107             bundler.addpart(bundle2.bundlepart('error:abort',
       
  1108                                                manargs, advargs))
       
  1109         except error.PushRaced as exc:
       
  1110             bundler.newpart('error:pushraced', [('message', str(exc))])
       
  1111         return streamres_legacy(gen=bundler.getchunks())