wireproto: support for receiving multiple requests
author Gregory Szorc <gregory.szorc@gmail.com>
Wed, 14 Mar 2018 16:53:30 -0700
changeset 37058 c5e9c3b47366
parent 37057 2ec1fb9de638
child 37059 bbea991635d0
wireproto: support for receiving multiple requests Now that we have request IDs on each frame and a specification that allows multiple requests to be issued simultaneously, possibly interleaved, let's teach the server to deal with that. Instead of tracking the state for *the* active command request, we now track the state of each receiving command by its request ID. The multiple states in our state machine for processing each command have been collapsed into a single state for "receiving commands." Tests have been added so our branch coverage covers all meaningful branches. However, we did lose some logical coverage. The implementation of this new feature opens up the door to a server having partial command requests when end of input is reached. We will probably want a mechanism to deal with partial requests. For now, I've tracked that as a known issue in the class docstring. I've also noted an abuse vector that becomes a little bit easier to exploit with this feature. Differential Revision: https://phab.mercurial-scm.org/D2870
mercurial/wireprotoframing.py
tests/test-http-api-httpv2.t
tests/test-wireproto-serverreactor.py
--- a/mercurial/wireprotoframing.py	Wed Mar 14 16:51:34 2018 -0700
+++ b/mercurial/wireprotoframing.py	Wed Mar 14 16:53:30 2018 -0700
@@ -327,6 +327,23 @@
 
     noop
        Indicates no additional action is required.
+
+    Known Issues
+    ------------
+
+    There are no limits to the number of partially received commands or their
+    size. A malicious client could stream command request data and exhaust the
+    server's memory.
+
+    Partially received commands are not acted upon when end of input is
+    reached. Should the server error if it receives a partial request?
+    Should the client send a message to abort a partially transmitted request
+    to facilitate graceful shutdown?
+
+    Active requests that haven't been responded to aren't tracked. This means
+    that if we receive a command and instruct its dispatch, another command
+    with its request ID can come in over the wire and there will be a race
+    between who responds to what.
     """
 
     def __init__(self, deferoutput=False):
@@ -342,14 +359,8 @@
         self._deferoutput = deferoutput
         self._state = 'idle'
         self._bufferedframegens = []
-        self._activerequestid = None
-        self._activecommand = None
-        self._activeargs = None
-        self._activedata = None
-        self._expectingargs = None
-        self._expectingdata = None
-        self._activeargname = None
-        self._activeargchunks = None
+        # request id -> dict of commands that are actively being received.
+        self._receivingcommands = {}
 
     def onframerecv(self, requestid, frametype, frameflags, payload):
         """Process a frame that has been received off the wire.
@@ -359,8 +370,7 @@
         """
         handlers = {
             'idle': self._onframeidle,
-            'command-receiving-args': self._onframereceivingargs,
-            'command-receiving-data': self._onframereceivingdata,
+            'command-receiving': self._onframecommandreceiving,
             'errored': self._onframeerrored,
         }
 
@@ -391,6 +401,8 @@
         No more frames will be received. All pending activity should be
         completed.
         """
+        # TODO should we do anything about in-flight commands?
+
         if not self._deferoutput or not self._bufferedframegens:
             return 'noop', {}
 
@@ -414,12 +426,20 @@
             'message': msg,
         }
 
-    def _makeruncommandresult(self):
+    def _makeruncommandresult(self, requestid):
+        entry = self._receivingcommands[requestid]
+        del self._receivingcommands[requestid]
+
+        if self._receivingcommands:
+            self._state = 'command-receiving'
+        else:
+            self._state = 'idle'
+
         return 'runcommand', {
-            'requestid': self._activerequestid,
-            'command': self._activecommand,
-            'args': self._activeargs,
-            'data': self._activedata.getvalue() if self._activedata else None,
+            'requestid': requestid,
+            'command': entry['command'],
+            'args': entry['args'],
+            'data': entry['data'].getvalue() if entry['data'] else None,
         }
 
     def _makewantframeresult(self):
@@ -435,34 +455,76 @@
             return self._makeerrorresult(
                 _('expected command frame; got %d') % frametype)
 
-        self._activerequestid = requestid
-        self._activecommand = payload
-        self._activeargs = {}
-        self._activedata = None
+        if requestid in self._receivingcommands:
+            self._state = 'errored'
+            return self._makeerrorresult(
+                _('request with ID %d already received') % requestid)
+
+        expectingargs = bool(frameflags & FLAG_COMMAND_NAME_HAVE_ARGS)
+        expectingdata = bool(frameflags & FLAG_COMMAND_NAME_HAVE_DATA)
+
+        self._receivingcommands[requestid] = {
+            'command': payload,
+            'args': {},
+            'data': None,
+            'expectingargs': expectingargs,
+            'expectingdata': expectingdata,
+        }
 
         if frameflags & FLAG_COMMAND_NAME_EOS:
-            return self._makeruncommandresult()
-
-        self._expectingargs = bool(frameflags & FLAG_COMMAND_NAME_HAVE_ARGS)
-        self._expectingdata = bool(frameflags & FLAG_COMMAND_NAME_HAVE_DATA)
+            return self._makeruncommandresult(requestid)
 
-        if self._expectingargs:
-            self._state = 'command-receiving-args'
-            return self._makewantframeresult()
-        elif self._expectingdata:
-            self._activedata = util.bytesio()
-            self._state = 'command-receiving-data'
+        if expectingargs or expectingdata:
+            self._state = 'command-receiving'
             return self._makewantframeresult()
         else:
             self._state = 'errored'
             return self._makeerrorresult(_('missing frame flags on '
                                            'command frame'))
 
-    def _onframereceivingargs(self, requestid, frametype, frameflags, payload):
-        if frametype != FRAME_TYPE_COMMAND_ARGUMENT:
+    def _onframecommandreceiving(self, requestid, frametype, frameflags,
+                                 payload):
+        # It could be a new command request. Process it as such.
+        if frametype == FRAME_TYPE_COMMAND_NAME:
+            return self._onframeidle(requestid, frametype, frameflags, payload)
+
+        # All other frames should be related to a command that is currently
+        # receiving.
+        if requestid not in self._receivingcommands:
             self._state = 'errored'
-            return self._makeerrorresult(_('expected command argument '
-                                           'frame; got %d') % frametype)
+            return self._makeerrorresult(
+                _('received frame for request that is not receiving: %d') %
+                  requestid)
+
+        entry = self._receivingcommands[requestid]
+
+        if frametype == FRAME_TYPE_COMMAND_ARGUMENT:
+            if not entry['expectingargs']:
+                self._state = 'errored'
+                return self._makeerrorresult(_(
+                    'received command argument frame for request that is not '
+                    'expecting arguments: %d') % requestid)
+
+            return self._handlecommandargsframe(requestid, entry, frametype,
+                                                frameflags, payload)
+
+        elif frametype == FRAME_TYPE_COMMAND_DATA:
+            if not entry['expectingdata']:
+                self._state = 'errored'
+                return self._makeerrorresult(_(
+                    'received command data frame for request that is not '
+                    'expecting data: %d') % requestid)
+
+            if entry['data'] is None:
+                entry['data'] = util.bytesio()
+
+            return self._handlecommanddataframe(requestid, entry, frametype,
+                                                frameflags, payload)
+
+    def _handlecommandargsframe(self, requestid, entry, frametype, frameflags,
+                                payload):
+        # The frame and state of command should have already been validated.
+        assert frametype == FRAME_TYPE_COMMAND_ARGUMENT
 
         offset = 0
         namesize, valuesize = ARGUMENT_FRAME_HEADER.unpack_from(payload)
@@ -483,10 +545,6 @@
         # and wait for the next frame.
         if frameflags & FLAG_COMMAND_ARGUMENT_CONTINUATION:
             raise error.ProgrammingError('not yet implemented')
-            self._activeargname = argname
-            self._activeargchunks = [argvalue]
-            self._state = 'command-arg-continuation'
-            return self._makewantframeresult()
 
         # Common case: the argument value is completely contained in this
         # frame.
@@ -496,36 +554,30 @@
             return self._makeerrorresult(_('malformed argument frame: '
                                            'partial argument value'))
 
-        self._activeargs[argname] = argvalue
+        entry['args'][argname] = argvalue
 
         if frameflags & FLAG_COMMAND_ARGUMENT_EOA:
-            if self._expectingdata:
-                self._state = 'command-receiving-data'
-                self._activedata = util.bytesio()
+            if entry['expectingdata']:
                 # TODO signal request to run a command once we don't
                 # buffer data frames.
                 return self._makewantframeresult()
             else:
-                self._state = 'waiting'
-                return self._makeruncommandresult()
+                return self._makeruncommandresult(requestid)
         else:
             return self._makewantframeresult()
 
-    def _onframereceivingdata(self, requestid, frametype, frameflags, payload):
-        if frametype != FRAME_TYPE_COMMAND_DATA:
-            self._state = 'errored'
-            return self._makeerrorresult(_('expected command data frame; '
-                                           'got %d') % frametype)
+    def _handlecommanddataframe(self, requestid, entry, frametype, frameflags,
+                                payload):
+        assert frametype == FRAME_TYPE_COMMAND_DATA
 
         # TODO support streaming data instead of buffering it.
-        self._activedata.write(payload)
+        entry['data'].write(payload)
 
         if frameflags & FLAG_COMMAND_DATA_CONTINUATION:
             return self._makewantframeresult()
         elif frameflags & FLAG_COMMAND_DATA_EOS:
-            self._activedata.seek(0)
-            self._state = 'idle'
-            return self._makeruncommandresult()
+            entry['data'].seek(0)
+            return self._makeruncommandresult(requestid)
         else:
             self._state = 'errored'
             return self._makeerrorresult(_('command data frame without '
--- a/tests/test-http-api-httpv2.t	Wed Mar 14 16:51:34 2018 -0700
+++ b/tests/test-http-api-httpv2.t	Wed Mar 14 16:53:30 2018 -0700
@@ -401,12 +401,12 @@
   s>     Server: testing stub value\r\n
   s>     Date: $HTTP_DATE$\r\n
   s>     Content-Type: text/plain\r\n
-  s>     Content-Length: 332\r\n
+  s>     Content-Length: 322\r\n
   s>     \r\n
   s>     received: 1 2 1 command1\n
-  s>     ["wantframe", {"state": "command-receiving-args"}]\n
+  s>     ["wantframe", {"state": "command-receiving"}]\n
   s>     received: 2 0 1 \x03\x00\x04\x00fooval1\n
-  s>     ["wantframe", {"state": "command-receiving-args"}]\n
+  s>     ["wantframe", {"state": "command-receiving"}]\n
   s>     received: 2 2 1 \x04\x00\x03\x00bar1val\n
   s>     ["runcommand", {"args": {"bar1": "val", "foo": "val1"}, "command": "command1", "data": null, "requestid": 1}]\n
   s>     received: <no frame>\n
--- a/tests/test-wireproto-serverreactor.py	Wed Mar 14 16:51:34 2018 -0700
+++ b/tests/test-wireproto-serverreactor.py	Wed Mar 14 16:53:30 2018 -0700
@@ -196,6 +196,19 @@
             'message': b'expected command frame; got 2',
         })
 
+    def testunexpectedcommandargumentreceiving(self):
+        """Same as above but the command is receiving."""
+        results = list(sendframes(makereactor(), [
+            ffs(b'1 command-name have-data command'),
+            ffs(b'1 command-argument eoa ignored'),
+        ]))
+
+        self.assertaction(results[1], 'error')
+        self.assertEqual(results[1][1], {
+            'message': b'received command argument frame for request that is '
+                       b'not expecting arguments: 1',
+        })
+
     def testunexpectedcommanddata(self):
         """Command argument frame when not running a command is an error."""
         result = self._sendsingleframe(makereactor(),
@@ -205,6 +218,19 @@
             'message': b'expected command frame; got 3',
         })
 
+    def testunexpectedcommanddatareceiving(self):
+        """Same as above except the command is receiving."""
+        results = list(sendframes(makereactor(), [
+            ffs(b'1 command-name have-args command'),
+            ffs(b'1 command-data eos ignored'),
+        ]))
+
+        self.assertaction(results[1], 'error')
+        self.assertEqual(results[1][1], {
+            'message': b'received command data frame for request that is not '
+                       b'expecting data: 1',
+        })
+
     def testmissingcommandframeflags(self):
         """Command name frame must have flags set."""
         result = self._sendsingleframe(makereactor(),
@@ -214,19 +240,77 @@
             'message': b'missing frame flags on command frame',
         })
 
+    def testconflictingrequestid(self):
+        """Multiple fully serviced commands with same request ID is allowed."""
+        results = list(sendframes(makereactor(), [
+            ffs(b'1 command-name eos command'),
+            ffs(b'1 command-name eos command'),
+            ffs(b'1 command-name eos command'),
+        ]))
+        for i in range(3):
+            self.assertaction(results[i], 'runcommand')
+            self.assertEqual(results[i][1], {
+                'requestid': 1,
+                'command': b'command',
+                'args': {},
+                'data': None,
+            })
+
+    def testconflictingrequestid(self):
+        """Request ID for new command matching in-flight command is illegal."""
+        results = list(sendframes(makereactor(), [
+            ffs(b'1 command-name have-args command'),
+            ffs(b'1 command-name eos command'),
+        ]))
+
+        self.assertaction(results[0], 'wantframe')
+        self.assertaction(results[1], 'error')
+        self.assertEqual(results[1][1], {
+            'message': b'request with ID 1 already received',
+        })
+
+    def testinterleavedcommands(self):
+        results = list(sendframes(makereactor(), [
+            ffs(b'1 command-name have-args command1'),
+            ffs(b'3 command-name have-args command3'),
+            ffs(br'1 command-argument 0 \x03\x00\x03\x00foobar'),
+            ffs(br'3 command-argument 0 \x03\x00\x03\x00bizbaz'),
+            ffs(br'3 command-argument eoa \x03\x00\x03\x00keyval'),
+            ffs(br'1 command-argument eoa \x04\x00\x03\x00key1val'),
+        ]))
+
+        self.assertEqual([t[0] for t in results], [
+            'wantframe',
+            'wantframe',
+            'wantframe',
+            'wantframe',
+            'runcommand',
+            'runcommand',
+        ])
+
+        self.assertEqual(results[4][1], {
+            'requestid': 3,
+            'command': 'command3',
+            'args': {b'biz': b'baz', b'key': b'val'},
+            'data': None,
+        })
+        self.assertEqual(results[5][1], {
+            'requestid': 1,
+            'command': 'command1',
+            'args': {b'foo': b'bar', b'key1': b'val'},
+            'data': None,
+        })
+
     def testmissingargumentframe(self):
+        # This test attempts to test behavior when reactor has an incomplete
+        # command request waiting on argument data. But it doesn't handle that
+        # scenario yet. So this test does nothing of value.
         frames = [
             ffs(b'1 command-name have-args command'),
-            ffs(b'1 command-name 0 ignored'),
         ]
 
         results = list(sendframes(makereactor(), frames))
-        self.assertEqual(len(results), 2)
         self.assertaction(results[0], 'wantframe')
-        self.assertaction(results[1], 'error')
-        self.assertEqual(results[1][1], {
-            'message': b'expected command argument frame; got 1',
-        })
 
     def testincompleteargumentname(self):
         """Argument frame with incomplete name."""
@@ -259,17 +343,16 @@
         })
 
     def testmissingcommanddataframe(self):
+        # The reactor doesn't currently handle partially received commands.
+        # So this test is failing to do anything with request 1.
         frames = [
             ffs(b'1 command-name have-data command1'),
-            ffs(b'1 command-name eos command2'),
+            ffs(b'3 command-name eos command2'),
         ]
         results = list(sendframes(makereactor(), frames))
         self.assertEqual(len(results), 2)
         self.assertaction(results[0], 'wantframe')
-        self.assertaction(results[1], 'error')
-        self.assertEqual(results[1][1], {
-            'message': b'expected command data frame; got 1',
-        })
+        self.assertaction(results[1], 'runcommand')
 
     def testmissingcommanddataframeflags(self):
         frames = [
@@ -284,6 +367,18 @@
             'message': b'command data frame without flags',
         })
 
+    def testframefornonreceivingrequest(self):
+        """Receiving a frame for a command that is not receiving is illegal."""
+        results = list(sendframes(makereactor(), [
+            ffs(b'1 command-name eos command1'),
+            ffs(b'3 command-name have-data command3'),
+            ffs(b'1 command-argument eoa ignored'),
+        ]))
+        self.assertaction(results[2], 'error')
+        self.assertEqual(results[2][1], {
+            'message': b'received frame for request that is not receiving: 1',
+        })
+
     def testsimpleresponse(self):
         """Bytes response to command sends result frames."""
         reactor = makereactor()