diff -r a9cffd14aa04 -r 80a2b8ae42a1 tests/sshprotoext.py --- a/tests/sshprotoext.py Sun Feb 04 14:10:56 2018 -0800 +++ b/tests/sshprotoext.py Mon Feb 05 09:14:32 2018 -0800 @@ -12,6 +12,7 @@ from mercurial import ( error, + extensions, registrar, sshpeer, wireproto, @@ -52,30 +53,26 @@ super(prehelloserver, self).serve_forever() -class extrahandshakecommandspeer(sshpeer.sshpeer): - """An ssh peer that sends extra commands as part of initial handshake.""" - def _validaterepo(self): - mode = self._ui.config(b'sshpeer', b'handshake-mode') - if mode == b'pre-no-args': - self._callstream(b'no-args') - return super(extrahandshakecommandspeer, self)._validaterepo() - elif mode == b'pre-multiple-no-args': - self._callstream(b'unknown1') - self._callstream(b'unknown2') - self._callstream(b'unknown3') - return super(extrahandshakecommandspeer, self)._validaterepo() - else: - raise error.ProgrammingError(b'unknown HANDSHAKECOMMANDMODE: %s' % - mode) - -def registercommands(): - def dummycommand(repo, proto): - raise error.ProgrammingError('this should never be called') - - wireproto.wireprotocommand(b'no-args', b'')(dummycommand) - wireproto.wireprotocommand(b'unknown1', b'')(dummycommand) - wireproto.wireprotocommand(b'unknown2', b'')(dummycommand) - wireproto.wireprotocommand(b'unknown3', b'')(dummycommand) +def performhandshake(orig, ui, stdin, stdout, stderr): + """Wrapped version of sshpeer._performhandshake to send extra commands.""" + mode = ui.config(b'sshpeer', b'handshake-mode') + if mode == b'pre-no-args': + ui.debug(b'sending no-args command\n') + stdin.write(b'no-args\n') + stdin.flush() + return orig(ui, stdin, stdout, stderr) + elif mode == b'pre-multiple-no-args': + ui.debug(b'sending unknown1 command\n') + stdin.write(b'unknown1\n') + ui.debug(b'sending unknown2 command\n') + stdin.write(b'unknown2\n') + ui.debug(b'sending unknown3 command\n') + stdin.write(b'unknown3\n') + stdin.flush() + return orig(ui, stdin, stdout, stderr) + else: + raise error.ProgrammingError(b'unknown HANDSHAKECOMMANDMODE: %s' % + mode) def extsetup(ui): # It's easier for tests to define the server behavior via environment @@ -94,7 +91,6 @@ peermode = ui.config(b'sshpeer', b'mode') if peermode == b'extra-handshake-commands': - sshpeer.sshpeer = extrahandshakecommandspeer - registercommands() + extensions.wrapfunction(sshpeer, '_performhandshake', performhandshake) elif peermode: raise error.ProgrammingError(b'unknown peer mode: %s' % peermode)