mercurial/wireprotoserver.py
changeset 43076 2372284d9457
parent 42896 7e19b640c53e
child 43077 687b865b95ad
--- a/mercurial/wireprotoserver.py	Sat Oct 05 10:29:34 2019 -0400
+++ b/mercurial/wireprotoserver.py	Sun Oct 06 09:45:02 2019 -0400
@@ -21,9 +21,7 @@
     wireprotov1server,
     wireprotov2server,
 )
-from .interfaces import (
-    util as interfaceutil,
-)
+from .interfaces import util as interfaceutil
 from .utils import (
     cborutil,
     compression,
@@ -43,6 +41,7 @@
 SSHV1 = wireprototypes.SSHV1
 SSHV2 = wireprototypes.SSHV2
 
+
 def decodevaluefromheaders(req, headerprefix):
     """Decode a long value from multiple HTTP request headers.
 
@@ -59,6 +58,7 @@
 
     return ''.join(chunks)
 
+
 @interfaceutil.implementer(wireprototypes.baseprotocolhandler)
 class httpv1protocolhandler(object):
     def __init__(self, req, ui, checkperm):
@@ -90,8 +90,11 @@
         args = self._req.qsparams.asdictoflists()
         postlen = int(self._req.headers.get(b'X-HgArgs-Post', 0))
         if postlen:
-            args.update(urlreq.parseqs(
-                self._req.bodyfh.read(postlen), keep_blank_values=True))
+            args.update(
+                urlreq.parseqs(
+                    self._req.bodyfh.read(postlen), keep_blank_values=True
+                )
+            )
             return args
 
         argvalue = decodevaluefromheaders(self._req, b'X-HgArg')
@@ -132,13 +135,15 @@
         return 'remote:%s:%s:%s' % (
             self._req.urlscheme,
             urlreq.quote(self._req.remotehost or ''),
-            urlreq.quote(self._req.remoteuser or ''))
+            urlreq.quote(self._req.remoteuser or ''),
+        )
 
     def addcapabilities(self, repo, caps):
         caps.append(b'batch')
 
-        caps.append('httpheader=%d' %
-                    repo.ui.configint('server', 'maxhttpheaderlen'))
+        caps.append(
+            'httpheader=%d' % repo.ui.configint('server', 'maxhttpheaderlen')
+        )
         if repo.ui.configbool('experimental', 'httppostargs'):
             caps.append('httppostargs')
 
@@ -146,11 +151,13 @@
         # FUTURE advertise minrx and mintx after consulting config option
         caps.append('httpmediatype=0.1rx,0.1tx,0.2tx')
 
-        compengines = wireprototypes.supportedcompengines(repo.ui,
-            compression.SERVERROLE)
+        compengines = wireprototypes.supportedcompengines(
+            repo.ui, compression.SERVERROLE
+        )
         if compengines:
-            comptypes = ','.join(urlreq.quote(e.wireprotosupport().name)
-                                 for e in compengines)
+            comptypes = ','.join(
+                urlreq.quote(e.wireprotosupport().name) for e in compengines
+            )
             caps.append('compression=%s' % comptypes)
 
         return caps
@@ -158,6 +165,7 @@
     def checkperm(self, perm):
         return self._checkperm(perm)
 
+
 # This method exists mostly so that extensions like remotefilelog can
 # disable a kludgey legacy method only over http. As of early 2018,
 # there are no other known users, so with any luck we can discard this
@@ -165,6 +173,7 @@
 def iscmd(cmd):
     return cmd in wireprotov1server.commands
 
+
 def handlewsgirequest(rctx, req, res, checkperm):
     """Possibly process a wire protocol request.
 
@@ -212,8 +221,9 @@
         res.setbodybytes('0\n%s\n' % b'Not Found')
         return True
 
-    proto = httpv1protocolhandler(req, repo.ui,
-                                  lambda perm: checkperm(rctx, req, perm))
+    proto = httpv1protocolhandler(
+        req, repo.ui, lambda perm: checkperm(rctx, req, perm)
+    )
 
     # The permissions checker should be the only thing that can raise an
     # ErrorResponse. It is kind of a layer violation to catch an hgweb
@@ -231,6 +241,7 @@
 
     return True
 
+
 def _availableapis(repo):
     apis = set()
 
@@ -243,6 +254,7 @@
 
     return apis
 
+
 def handlewsgiapirequest(rctx, req, res, checkperm):
     """Handle requests to /api/*."""
     assert req.dispatchparts[0] == b'api'
@@ -266,8 +278,12 @@
     if req.dispatchparts == [b'api']:
         res.status = b'200 OK'
         res.headers[b'Content-Type'] = b'text/plain'
-        lines = [_('APIs can be accessed at /api/<name>, where <name> can be '
-                   'one of the following:\n')]
+        lines = [
+            _(
+                'APIs can be accessed at /api/<name>, where <name> can be '
+                'one of the following:\n'
+            )
+        ]
         if availableapis:
             lines.extend(sorted(availableapis))
         else:
@@ -280,8 +296,10 @@
     if proto not in API_HANDLERS:
         res.status = b'404 Not Found'
         res.headers[b'Content-Type'] = b'text/plain'
-        res.setbodybytes(_('Unknown API: %s\nKnown APIs: %s') % (
-            proto, b', '.join(sorted(availableapis))))
+        res.setbodybytes(
+            _('Unknown API: %s\nKnown APIs: %s')
+            % (proto, b', '.join(sorted(availableapis)))
+        )
         return
 
     if proto not in availableapis:
@@ -290,8 +308,10 @@
         res.setbodybytes(_('API %s not enabled\n') % proto)
         return
 
-    API_HANDLERS[proto]['handler'](rctx, req, res, checkperm,
-                                   req.dispatchparts[2:])
+    API_HANDLERS[proto]['handler'](
+        rctx, req, res, checkperm, req.dispatchparts[2:]
+    )
+
 
 # Maps API name to metadata so custom API can be registered.
 # Keys are:
@@ -312,6 +332,7 @@
     },
 }
 
+
 def _httpresponsetype(ui, proto, prefer_uncompressed):
     """Determine the appropriate response type and compression settings.
 
@@ -327,8 +348,9 @@
 
         # Now find an agreed upon compression format.
         compformats = wireprotov1server.clientcompressionsupport(proto)
-        for engine in wireprototypes.supportedcompengines(ui,
-                compression.SERVERROLE):
+        for engine in wireprototypes.supportedcompengines(
+            ui, compression.SERVERROLE
+        ):
             if engine.wireprotosupport().name in compformats:
                 opts = {}
                 level = ui.configint('server', '%slevel' % engine.name())
@@ -346,6 +368,7 @@
     opts = {'level': ui.configint('server', 'zliblevel')}
     return HGTYPE, util.compengines['zlib'], opts
 
+
 def processcapabilitieshandshake(repo, req, res, proto):
     """Called during a ?cmd=capabilities request.
 
@@ -394,6 +417,7 @@
 
     return True
 
+
 def _callhttp(repo, req, res, proto, cmd):
     # Avoid cycle involving hg module.
     from .hgweb import common as hgwebcommon
@@ -423,16 +447,19 @@
             res.setbodygen(bodygen)
 
     if not wireprotov1server.commands.commandavailable(cmd, proto):
-        setresponse(HTTP_OK, HGERRTYPE,
-                    _('requested wire protocol command is not available over '
-                      'HTTP'))
+        setresponse(
+            HTTP_OK,
+            HGERRTYPE,
+            _('requested wire protocol command is not available over ' 'HTTP'),
+        )
         return
 
     proto.checkperm(wireprotov1server.commands[cmd].permission)
 
     # Possibly handle a modern client wanting to switch protocols.
-    if (cmd == 'capabilities' and
-        processcapabilitieshandshake(repo, req, res, proto)):
+    if cmd == 'capabilities' and processcapabilitieshandshake(
+        repo, req, res, proto
+    ):
 
         return
 
@@ -450,7 +477,8 @@
         # This code for compression should not be streamres specific. It
         # is here because we only compress streamres at the moment.
         mediatype, engine, engineopts = _httpresponsetype(
-            repo.ui, proto, rsp.prefer_uncompressed)
+            repo.ui, proto, rsp.prefer_uncompressed
+        )
         gen = engine.compressstream(gen, engineopts)
 
         if mediatype == HGTYPE2:
@@ -469,27 +497,32 @@
     else:
         raise error.ProgrammingError('hgweb.protocol internal failure', rsp)
 
+
 def _sshv1respondbytes(fout, value):
     """Send a bytes response for protocol version 1."""
     fout.write('%d\n' % len(value))
     fout.write(value)
     fout.flush()
 
+
 def _sshv1respondstream(fout, source):
     write = fout.write
     for chunk in source.gen:
         write(chunk)
     fout.flush()
 
+
 def _sshv1respondooberror(fout, ferr, rsp):
     ferr.write(b'%s\n-\n' % rsp)
     ferr.flush()
     fout.write(b'\n')
     fout.flush()
 
+
 @interfaceutil.implementer(wireprototypes.baseprotocolhandler)
 class sshv1protocolhandler(object):
     """Handler for requests services via version 1 of SSH protocol."""
+
     def __init__(self, ui, fin, fout):
         self._ui = ui
         self._fin = fin
@@ -557,6 +590,7 @@
     def checkperm(self, perm):
         pass
 
+
 class sshv2protocolhandler(sshv1protocolhandler):
     """Protocol handler for version 2 of the SSH protocol."""
 
@@ -567,6 +601,7 @@
     def addcapabilities(self, repo, caps):
         return caps
 
+
 def _runsshserver(ui, repo, fin, fout, ev):
     # This function operates like a state machine of sorts. The following
     # states are defined:
@@ -638,9 +673,11 @@
             # handle it.
             if request.startswith(b'upgrade '):
                 if protoswitched:
-                    _sshv1respondooberror(fout, ui.ferr,
-                                          b'cannot upgrade protocols multiple '
-                                          b'times')
+                    _sshv1respondooberror(
+                        fout,
+                        ui.ferr,
+                        b'cannot upgrade protocols multiple ' b'times',
+                    )
                     state = 'shutdown'
                     continue
 
@@ -648,7 +685,8 @@
                 continue
 
             available = wireprotov1server.commands.commandavailable(
-                request, proto)
+                request, proto
+            )
 
             # This command isn't available. Send an empty response and go
             # back to waiting for a new command.
@@ -676,8 +714,10 @@
             elif isinstance(rsp, wireprototypes.ooberror):
                 _sshv1respondooberror(fout, ui.ferr, rsp.message)
             else:
-                raise error.ProgrammingError('unhandled response type from '
-                                             'wire protocol command: %s' % rsp)
+                raise error.ProgrammingError(
+                    'unhandled response type from '
+                    'wire protocol command: %s' % rsp
+                )
 
         # For now, protocol version 2 serving just goes back to version 1.
         elif state == 'protov2-serving':
@@ -741,9 +781,11 @@
                 request = fin.readline()[:-1]
 
                 if request != line:
-                    _sshv1respondooberror(fout, ui.ferr,
-                                          b'malformed handshake protocol: '
-                                          b'missing %s' % line)
+                    _sshv1respondooberror(
+                        fout,
+                        ui.ferr,
+                        b'malformed handshake protocol: ' b'missing %s' % line,
+                    )
                     ok = False
                     state = 'shutdown'
                     break
@@ -753,9 +795,12 @@
 
             request = fin.read(81)
             if request != b'%s-%s' % (b'0' * 40, b'0' * 40):
-                _sshv1respondooberror(fout, ui.ferr,
-                                      b'malformed handshake protocol: '
-                                      b'missing between argument value')
+                _sshv1respondooberror(
+                    fout,
+                    ui.ferr,
+                    b'malformed handshake protocol: '
+                    b'missing between argument value',
+                )
                 state = 'shutdown'
                 continue
 
@@ -780,8 +825,10 @@
             break
 
         else:
-            raise error.ProgrammingError('unhandled ssh server state: %s' %
-                                         state)
+            raise error.ProgrammingError(
+                'unhandled ssh server state: %s' % state
+            )
+
 
 class sshserver(object):
     def __init__(self, ui, repo, logfh=None):
@@ -792,9 +839,11 @@
         # Log write I/O to stdout and stderr if configured.
         if logfh:
             self._fout = util.makeloggingfileobject(
-                logfh, self._fout, 'o', logdata=True)
+                logfh, self._fout, 'o', logdata=True
+            )
             ui.ferr = util.makeloggingfileobject(
-                logfh, ui.ferr, 'e', logdata=True)
+                logfh, ui.ferr, 'e', logdata=True
+            )
 
     def serve_forever(self):
         self.serveuntil(threading.Event())