tests/sshprotoext.py
changeset 35938 80a2b8ae42a1
parent 35937 a9cffd14aa04
child 35976 48a3a9283f09
--- a/tests/sshprotoext.py	Sun Feb 04 14:10:56 2018 -0800
+++ b/tests/sshprotoext.py	Mon Feb 05 09:14:32 2018 -0800
@@ -12,6 +12,7 @@
 
 from mercurial import (
     error,
+    extensions,
     registrar,
     sshpeer,
     wireproto,
@@ -52,30 +53,26 @@
 
         super(prehelloserver, self).serve_forever()
 
-class extrahandshakecommandspeer(sshpeer.sshpeer):
-    """An ssh peer that sends extra commands as part of initial handshake."""
-    def _validaterepo(self):
-        mode = self._ui.config(b'sshpeer', b'handshake-mode')
-        if mode == b'pre-no-args':
-            self._callstream(b'no-args')
-            return super(extrahandshakecommandspeer, self)._validaterepo()
-        elif mode == b'pre-multiple-no-args':
-            self._callstream(b'unknown1')
-            self._callstream(b'unknown2')
-            self._callstream(b'unknown3')
-            return super(extrahandshakecommandspeer, self)._validaterepo()
-        else:
-            raise error.ProgrammingError(b'unknown HANDSHAKECOMMANDMODE: %s' %
-                                         mode)
-
-def registercommands():
-    def dummycommand(repo, proto):
-        raise error.ProgrammingError('this should never be called')
-
-    wireproto.wireprotocommand(b'no-args', b'')(dummycommand)
-    wireproto.wireprotocommand(b'unknown1', b'')(dummycommand)
-    wireproto.wireprotocommand(b'unknown2', b'')(dummycommand)
-    wireproto.wireprotocommand(b'unknown3', b'')(dummycommand)
+def performhandshake(orig, ui, stdin, stdout, stderr):
+    """Wrapped version of sshpeer._performhandshake to send extra commands."""
+    mode = ui.config(b'sshpeer', b'handshake-mode')
+    if mode == b'pre-no-args':
+        ui.debug(b'sending no-args command\n')
+        stdin.write(b'no-args\n')
+        stdin.flush()
+        return orig(ui, stdin, stdout, stderr)
+    elif mode == b'pre-multiple-no-args':
+        ui.debug(b'sending unknown1 command\n')
+        stdin.write(b'unknown1\n')
+        ui.debug(b'sending unknown2 command\n')
+        stdin.write(b'unknown2\n')
+        ui.debug(b'sending unknown3 command\n')
+        stdin.write(b'unknown3\n')
+        stdin.flush()
+        return orig(ui, stdin, stdout, stderr)
+    else:
+        raise error.ProgrammingError(b'unknown HANDSHAKECOMMANDMODE: %s' %
+                                     mode)
 
 def extsetup(ui):
     # It's easier for tests to define the server behavior via environment
@@ -94,7 +91,6 @@
     peermode = ui.config(b'sshpeer', b'mode')
 
     if peermode == b'extra-handshake-commands':
-        sshpeer.sshpeer = extrahandshakecommandspeer
-        registercommands()
+        extensions.wrapfunction(sshpeer, '_performhandshake', performhandshake)
     elif peermode:
         raise error.ProgrammingError(b'unknown peer mode: %s' % peermode)