mercurial/wireproto.py
changeset 36074 2f7290555c96
parent 36073 cd6ab329c5c7
child 36221 62bca1c50e96
--- a/mercurial/wireproto.py	Wed Feb 07 16:29:05 2018 -0800
+++ b/mercurial/wireproto.py	Wed Feb 07 20:27:36 2018 -0800
@@ -37,6 +37,7 @@
 urlerr = util.urlerr
 urlreq = util.urlreq
 
+bytesresponse = wireprototypes.bytesresponse
 ooberror = wireprototypes.ooberror
 pushres = wireprototypes.pushres
 pusherr = wireprototypes.pusherr
@@ -696,8 +697,15 @@
             result = func(repo, proto)
         if isinstance(result, ooberror):
             return result
+
+        # For now, all batchable commands must return bytesresponse or
+        # raw bytes (for backwards compatibility).
+        assert isinstance(result, (bytesresponse, bytes))
+        if isinstance(result, bytesresponse):
+            result = result.data
         res.append(escapearg(result))
-    return ';'.join(res)
+
+    return bytesresponse(';'.join(res))
 
 @wireprotocommand('between', 'pairs')
 def between(repo, proto, pairs):
@@ -705,7 +713,8 @@
     r = []
     for b in repo.between(pairs):
         r.append(encodelist(b) + "\n")
-    return "".join(r)
+
+    return bytesresponse(''.join(r))
 
 @wireprotocommand('branchmap')
 def branchmap(repo, proto):
@@ -715,7 +724,8 @@
         branchname = urlreq.quote(encoding.fromlocal(branch))
         branchnodes = encodelist(nodes)
         heads.append('%s %s' % (branchname, branchnodes))
-    return '\n'.join(heads)
+
+    return bytesresponse('\n'.join(heads))
 
 @wireprotocommand('branches', 'nodes')
 def branches(repo, proto, nodes):
@@ -723,7 +733,8 @@
     r = []
     for b in repo.branches(nodes):
         r.append(encodelist(b) + "\n")
-    return "".join(r)
+
+    return bytesresponse(''.join(r))
 
 @wireprotocommand('clonebundles', '')
 def clonebundles(repo, proto):
@@ -735,7 +746,7 @@
     depending on the request. e.g. you could advertise URLs for the closest
     data center given the client's IP address.
     """
-    return repo.vfs.tryread('clonebundles.manifest')
+    return bytesresponse(repo.vfs.tryread('clonebundles.manifest'))
 
 wireprotocaps = ['lookup', 'changegroupsubset', 'branchmap', 'pushkey',
                  'known', 'getbundle', 'unbundlehash', 'batch']
@@ -789,7 +800,7 @@
 # `_capabilities` instead.
 @wireprotocommand('capabilities')
 def capabilities(repo, proto):
-    return ' '.join(_capabilities(repo, proto))
+    return bytesresponse(' '.join(_capabilities(repo, proto)))
 
 @wireprotocommand('changegroup', 'roots')
 def changegroup(repo, proto, roots):
@@ -814,7 +825,8 @@
 def debugwireargs(repo, proto, one, two, others):
     # only accept optional args from the known set
     opts = options('debugwireargs', ['three', 'four'], others)
-    return repo.debugwireargs(one, two, **pycompat.strkwargs(opts))
+    return bytesresponse(repo.debugwireargs(one, two,
+                                            **pycompat.strkwargs(opts)))
 
 @wireprotocommand('getbundle', '*')
 def getbundle(repo, proto, others):
@@ -885,7 +897,7 @@
 @wireprotocommand('heads')
 def heads(repo, proto):
     h = repo.heads()
-    return encodelist(h) + "\n"
+    return bytesresponse(encodelist(h) + '\n')
 
 @wireprotocommand('hello')
 def hello(repo, proto):
@@ -896,12 +908,13 @@
 
     capabilities: space separated list of tokens
     '''
-    return "capabilities: %s\n" % (capabilities(repo, proto))
+    caps = capabilities(repo, proto).data
+    return bytesresponse('capabilities: %s\n' % caps)
 
 @wireprotocommand('listkeys', 'namespace')
 def listkeys(repo, proto, namespace):
     d = repo.listkeys(encoding.tolocal(namespace)).items()
-    return pushkeymod.encodekeys(d)
+    return bytesresponse(pushkeymod.encodekeys(d))
 
 @wireprotocommand('lookup', 'key')
 def lookup(repo, proto, key):
@@ -913,11 +926,12 @@
     except Exception as inst:
         r = str(inst)
         success = 0
-    return "%d %s\n" % (success, r)
+    return bytesresponse('%d %s\n' % (success, r))
 
 @wireprotocommand('known', 'nodes *')
 def known(repo, proto, nodes, others):
-    return ''.join(b and "1" or "0" for b in repo.known(decodelist(nodes)))
+    v = ''.join(b and '1' or '0' for b in repo.known(decodelist(nodes)))
+    return bytesresponse(v)
 
 @wireprotocommand('pushkey', 'namespace key old new')
 def pushkey(repo, proto, namespace, key, old, new):
@@ -938,7 +952,7 @@
                          encoding.tolocal(old), new) or False
 
     output = output.getvalue() if output else ''
-    return '%s\n%s' % (int(r), output)
+    return bytesresponse('%s\n%s' % (int(r), output))
 
 @wireprotocommand('stream_out')
 def stream(repo, proto):