mercurial/wireproto.py
changeset 37293 d5d665f6615a
parent 37119 d4a2e0d5d042
child 37295 45b39c69fae0
--- a/mercurial/wireproto.py	Mon Mar 26 14:34:32 2018 -0700
+++ b/mercurial/wireproto.py	Wed Mar 28 14:05:29 2018 -0700
@@ -42,13 +42,6 @@
 urlerr = util.urlerr
 urlreq = util.urlreq
 
-bytesresponse = wireprototypes.bytesresponse
-ooberror = wireprototypes.ooberror
-pushres = wireprototypes.pushres
-pusherr = wireprototypes.pusherr
-streamres = wireprototypes.streamres
-streamres_legacy = wireprototypes.streamreslegacy
-
 bundle2requiredmain = _('incompatible Mercurial client; bundle2 required')
 bundle2requiredhint = _('see https://www.mercurial-scm.org/wiki/'
                         'IncompatibleClient')
@@ -771,17 +764,17 @@
             result = func(repo, proto, *[data[k] for k in keys])
         else:
             result = func(repo, proto)
-        if isinstance(result, ooberror):
+        if isinstance(result, wireprototypes.ooberror):
             return result
 
         # For now, all batchable commands must return bytesresponse or
         # raw bytes (for backwards compatibility).
-        assert isinstance(result, (bytesresponse, bytes))
-        if isinstance(result, bytesresponse):
+        assert isinstance(result, (wireprototypes.bytesresponse, bytes))
+        if isinstance(result, wireprototypes.bytesresponse):
             result = result.data
         res.append(escapearg(result))
 
-    return bytesresponse(';'.join(res))
+    return wireprototypes.bytesresponse(';'.join(res))
 
 @wireprotocommand('between', 'pairs', transportpolicy=POLICY_V1_ONLY,
                   permission='pull')
@@ -791,7 +784,7 @@
     for b in repo.between(pairs):
         r.append(encodelist(b) + "\n")
 
-    return bytesresponse(''.join(r))
+    return wireprototypes.bytesresponse(''.join(r))
 
 @wireprotocommand('branchmap', permission='pull')
 def branchmap(repo, proto):
@@ -802,7 +795,7 @@
         branchnodes = encodelist(nodes)
         heads.append('%s %s' % (branchname, branchnodes))
 
-    return bytesresponse('\n'.join(heads))
+    return wireprototypes.bytesresponse('\n'.join(heads))
 
 @wireprotocommand('branches', 'nodes', transportpolicy=POLICY_V1_ONLY,
                   permission='pull')
@@ -812,7 +805,7 @@
     for b in repo.branches(nodes):
         r.append(encodelist(b) + "\n")
 
-    return bytesresponse(''.join(r))
+    return wireprototypes.bytesresponse(''.join(r))
 
 @wireprotocommand('clonebundles', '', permission='pull')
 def clonebundles(repo, proto):
@@ -824,7 +817,8 @@
     depending on the request. e.g. you could advertise URLs for the closest
     data center given the client's IP address.
     """
-    return bytesresponse(repo.vfs.tryread('clonebundles.manifest'))
+    return wireprototypes.bytesresponse(
+        repo.vfs.tryread('clonebundles.manifest'))
 
 wireprotocaps = ['lookup', 'branchmap', 'pushkey',
                  'known', 'getbundle', 'unbundlehash']
@@ -868,7 +862,7 @@
 # `_capabilities` instead.
 @wireprotocommand('capabilities', permission='pull')
 def capabilities(repo, proto):
-    return bytesresponse(' '.join(_capabilities(repo, proto)))
+    return wireprototypes.bytesresponse(' '.join(_capabilities(repo, proto)))
 
 @wireprotocommand('changegroup', 'roots', transportpolicy=POLICY_V1_ONLY,
                   permission='pull')
@@ -878,7 +872,7 @@
                                   missingheads=repo.heads())
     cg = changegroupmod.makechangegroup(repo, outgoing, '01', 'serve')
     gen = iter(lambda: cg.read(32768), '')
-    return streamres(gen=gen)
+    return wireprototypes.streamres(gen=gen)
 
 @wireprotocommand('changegroupsubset', 'bases heads',
                   transportpolicy=POLICY_V1_ONLY,
@@ -890,15 +884,15 @@
                                   missingheads=heads)
     cg = changegroupmod.makechangegroup(repo, outgoing, '01', 'serve')
     gen = iter(lambda: cg.read(32768), '')
-    return streamres(gen=gen)
+    return wireprototypes.streamres(gen=gen)
 
 @wireprotocommand('debugwireargs', 'one two *',
                   permission='pull')
 def debugwireargs(repo, proto, one, two, others):
     # only accept optional args from the known set
     opts = options('debugwireargs', ['three', 'four'], others)
-    return bytesresponse(repo.debugwireargs(one, two,
-                                            **pycompat.strkwargs(opts)))
+    return wireprototypes.bytesresponse(repo.debugwireargs(
+        one, two, **pycompat.strkwargs(opts)))
 
 @wireprotocommand('getbundle', '*', permission='pull')
 def getbundle(repo, proto, others):
@@ -925,7 +919,7 @@
     if not bundle1allowed(repo, 'pull'):
         if not exchange.bundle2requested(opts.get('bundlecaps')):
             if proto.name == 'http-v1':
-                return ooberror(bundle2required)
+                return wireprototypes.ooberror(bundle2required)
             raise error.Abort(bundle2requiredmain,
                               hint=bundle2requiredhint)
 
@@ -951,7 +945,7 @@
         # cleanly forward Abort error to the client
         if not exchange.bundle2requested(opts.get('bundlecaps')):
             if proto.name == 'http-v1':
-                return ooberror(pycompat.bytestr(exc) + '\n')
+                return wireprototypes.ooberror(pycompat.bytestr(exc) + '\n')
             raise # cannot do better for bundle1 + ssh
         # bundle2 request expect a bundle2 reply
         bundler = bundle2.bundle20(repo.ui)
@@ -964,12 +958,13 @@
         chunks = bundler.getchunks()
         prefercompressed = False
 
-    return streamres(gen=chunks, prefer_uncompressed=not prefercompressed)
+    return wireprototypes.streamres(
+        gen=chunks, prefer_uncompressed=not prefercompressed)
 
 @wireprotocommand('heads', permission='pull')
 def heads(repo, proto):
     h = repo.heads()
-    return bytesresponse(encodelist(h) + '\n')
+    return wireprototypes.bytesresponse(encodelist(h) + '\n')
 
 @wireprotocommand('hello', permission='pull')
 def hello(repo, proto):
@@ -984,12 +979,12 @@
         capabilities: <token0> <token1> <token2>
     """
     caps = capabilities(repo, proto).data
-    return bytesresponse('capabilities: %s\n' % caps)
+    return wireprototypes.bytesresponse('capabilities: %s\n' % caps)
 
 @wireprotocommand('listkeys', 'namespace', permission='pull')
 def listkeys(repo, proto, namespace):
     d = sorted(repo.listkeys(encoding.tolocal(namespace)).items())
-    return bytesresponse(pushkeymod.encodekeys(d))
+    return wireprototypes.bytesresponse(pushkeymod.encodekeys(d))
 
 @wireprotocommand('lookup', 'key', permission='pull')
 def lookup(repo, proto, key):
@@ -1001,12 +996,12 @@
     except Exception as inst:
         r = stringutil.forcebytestr(inst)
         success = 0
-    return bytesresponse('%d %s\n' % (success, r))
+    return wireprototypes.bytesresponse('%d %s\n' % (success, r))
 
 @wireprotocommand('known', 'nodes *', permission='pull')
 def known(repo, proto, nodes, others):
     v = ''.join(b and '1' or '0' for b in repo.known(decodelist(nodes)))
-    return bytesresponse(v)
+    return wireprototypes.bytesresponse(v)
 
 @wireprotocommand('pushkey', 'namespace key old new', permission='push')
 def pushkey(repo, proto, namespace, key, old, new):
@@ -1027,7 +1022,7 @@
                          encoding.tolocal(old), new) or False
 
     output = output.getvalue() if output else ''
-    return bytesresponse('%d\n%s' % (int(r), output))
+    return wireprototypes.bytesresponse('%d\n%s' % (int(r), output))
 
 @wireprotocommand('stream_out', permission='pull')
 def stream(repo, proto):
@@ -1035,7 +1030,8 @@
     capability with a value representing the version and flags of the repo
     it is serving. Client checks to see if it understands the format.
     '''
-    return streamres_legacy(streamclone.generatev1wireproto(repo))
+    return wireprototypes.streamreslegacy(
+        streamclone.generatev1wireproto(repo))
 
 @wireprotocommand('unbundle', 'heads', permission='push')
 def unbundle(repo, proto, heads):
@@ -1060,7 +1056,7 @@
                         # the http client on failed push so we need to abuse
                         # some other error type to make sure the message get to
                         # the user.
-                        return ooberror(bundle2required)
+                        return wireprototypes.ooberror(bundle2required)
                     raise error.Abort(bundle2requiredmain,
                                       hint=bundle2requiredhint)
 
@@ -1069,8 +1065,9 @@
                 if util.safehasattr(r, 'addpart'):
                     # The return looks streamable, we are in the bundle2 case
                     # and should return a stream.
-                    return streamres_legacy(gen=r.getchunks())
-                return pushres(r, output.getvalue() if output else '')
+                    return wireprototypes.streamreslegacy(gen=r.getchunks())
+                return wireprototypes.pushres(
+                    r, output.getvalue() if output else '')
 
             finally:
                 fp.close()
@@ -1090,10 +1087,12 @@
                     if exc.hint is not None:
                         procutil.stderr.write("(%s)\n" % exc.hint)
                     procutil.stderr.flush()
-                    return pushres(0, output.getvalue() if output else '')
+                    return wireprototypes.pushres(
+                        0, output.getvalue() if output else '')
                 except error.PushRaced:
-                    return pusherr(pycompat.bytestr(exc),
-                                   output.getvalue() if output else '')
+                    return wireprototypes.pusherr(
+                        pycompat.bytestr(exc),
+                        output.getvalue() if output else '')
 
             bundler = bundle2.bundle20(repo.ui)
             for out in getattr(exc, '_bundle2salvagedoutput', ()):
@@ -1137,4 +1136,4 @@
             except error.PushRaced as exc:
                 bundler.newpart('error:pushraced',
                                 [('message', stringutil.forcebytestr(exc))])
-            return streamres_legacy(gen=bundler.getchunks())
+            return wireprototypes.streamreslegacy(gen=bundler.getchunks())