wireprotoserver: extract SSH response handling functions
authorGregory Szorc <gregory.szorc@gmail.com>
Wed, 07 Feb 2018 21:04:54 -0800
changeset 36064 5767664d39a5
parent 36063 5a53af7d09aa
child 36065 bf676267f64f
wireprotoserver: extract SSH response handling functions The lookup/dispatch table was cute. But it isn't needed. Future refactors will benefit from the handlers for individual response types living outside the class. As part of this, I snuck in a change that changes a type compare from str to bytes. This has no effect on Python 2. But it might make Python 3 a bit happier. Differential Revision: https://phab.mercurial-scm.org/D2091
mercurial/wireprotoserver.py
tests/sshprotoext.py
--- a/mercurial/wireprotoserver.py	Sat Dec 23 15:13:37 2017 +0530
+++ b/mercurial/wireprotoserver.py	Wed Feb 07 21:04:54 2018 -0800
@@ -336,6 +336,24 @@
 
     return ''
 
+def _sshv1respondbytes(fout, value):
+    """Send a bytes response for protocol version 1."""
+    fout.write('%d\n' % len(value))
+    fout.write(value)
+    fout.flush()
+
+def _sshv1respondstream(fout, source):
+    write = fout.write
+    for chunk in source.gen:
+        write(chunk)
+    fout.flush()
+
+def _sshv1respondooberror(fout, ferr, rsp):
+    ferr.write(b'%s\n-\n' % rsp)
+    ferr.flush()
+    fout.write(b'\n')
+    fout.flush()
+
 class sshserver(baseprotocolhandler):
     def __init__(self, ui, repo):
         self._ui = ui
@@ -376,7 +394,7 @@
         return [data[k] for k in keys]
 
     def getfile(self, fpout):
-        self._sendresponse('')
+        _sshv1respondbytes(self._fout, b'')
         count = int(self._fin.readline())
         while count:
             fpout.write(self._fin.read(count))
@@ -385,51 +403,34 @@
     def redirect(self):
         pass
 
-    def _sendresponse(self, v):
-        self._fout.write("%d\n" % len(v))
-        self._fout.write(v)
-        self._fout.flush()
-
-    def _sendstream(self, source):
-        write = self._fout.write
-        for chunk in source.gen:
-            write(chunk)
-        self._fout.flush()
-
-    def _sendpushresponse(self, rsp):
-        self._sendresponse('')
-        self._sendresponse(str(rsp.res))
-
-    def _sendpusherror(self, rsp):
-        self._sendresponse(rsp.res)
-
-    def _sendooberror(self, rsp):
-        self._ui.ferr.write('%s\n-\n' % rsp.message)
-        self._ui.ferr.flush()
-        self._fout.write('\n')
-        self._fout.flush()
-
     def serve_forever(self):
         while self.serve_one():
             pass
         sys.exit(0)
 
-    _handlers = {
-        str: _sendresponse,
-        wireproto.streamres: _sendstream,
-        wireproto.streamres_legacy: _sendstream,
-        wireproto.pushres: _sendpushresponse,
-        wireproto.pusherr: _sendpusherror,
-        wireproto.ooberror: _sendooberror,
-    }
-
     def serve_one(self):
         cmd = self._fin.readline()[:-1]
         if cmd and wireproto.commands.commandavailable(cmd, self):
             rsp = wireproto.dispatch(self._repo, self, cmd)
-            self._handlers[rsp.__class__](self, rsp)
+
+            if isinstance(rsp, bytes):
+                _sshv1respondbytes(self._fout, rsp)
+            elif isinstance(rsp, wireproto.streamres):
+                _sshv1respondstream(self._fout, rsp)
+            elif isinstance(rsp, wireproto.streamres_legacy):
+                _sshv1respondstream(self._fout, rsp)
+            elif isinstance(rsp, wireproto.pushres):
+                _sshv1respondbytes(self._fout, b'')
+                _sshv1respondbytes(self._fout, bytes(rsp.res))
+            elif isinstance(rsp, wireproto.pusherr):
+                _sshv1respondbytes(self._fout, rsp.res)
+            elif isinstance(rsp, wireproto.ooberror):
+                _sshv1respondooberror(self._fout, self._ui.ferr, rsp.message)
+            else:
+                raise error.ProgrammingError('unhandled response type from '
+                                             'wire protocol command: %s' % rsp)
         elif cmd:
-            self._sendresponse("")
+            _sshv1respondbytes(self._fout, b'')
         return cmd != ''
 
     def _client(self):
--- a/tests/sshprotoext.py	Sat Dec 23 15:13:37 2017 +0530
+++ b/tests/sshprotoext.py	Wed Feb 07 21:04:54 2018 -0800
@@ -45,11 +45,11 @@
         l = self._fin.readline()
         assert l == b'hello\n'
         # Respond to unknown commands with an empty reply.
-        self._sendresponse(b'')
+        wireprotoserver._sshv1respondbytes(self._fout, b'')
         l = self._fin.readline()
         assert l == b'between\n'
         rsp = wireproto.dispatch(self._repo, self, b'between')
-        self._handlers[rsp.__class__](self, rsp)
+        wireprotoserver._sshv1respondbytes(self._fout, rsp)
 
         super(prehelloserver, self).serve_forever()