mercurial/wireprotoframing.py
changeset 40138 b5bf3dd6ec5b
parent 40136 3a6d6c54bd81
child 40328 2c55716f8a1c
--- a/mercurial/wireprotoframing.py	Mon Oct 08 15:19:32 2018 -0700
+++ b/mercurial/wireprotoframing.py	Mon Oct 08 17:24:28 2018 -0700
@@ -368,17 +368,46 @@
 def createcommandresponseokframe(stream, requestid):
     overall = b''.join(cborutil.streamencode({b'status': b'ok'}))
 
+    if stream.streamsettingssent:
+        overall = stream.encode(overall)
+        encoded = True
+
+        if not overall:
+            return None
+    else:
+        encoded = False
+
     return stream.makeframe(requestid=requestid,
                             typeid=FRAME_TYPE_COMMAND_RESPONSE,
                             flags=FLAG_COMMAND_RESPONSE_CONTINUATION,
-                            payload=overall)
+                            payload=overall,
+                            encoded=encoded)
 
-def createcommandresponseeosframe(stream, requestid):
+def createcommandresponseeosframes(stream, requestid,
+                                   maxframesize=DEFAULT_MAX_FRAME_SIZE):
     """Create an empty payload frame representing command end-of-stream."""
-    return stream.makeframe(requestid=requestid,
-                            typeid=FRAME_TYPE_COMMAND_RESPONSE,
-                            flags=FLAG_COMMAND_RESPONSE_EOS,
-                            payload=b'')
+    payload = stream.flush()
+
+    offset = 0
+    while True:
+        chunk = payload[offset:offset + maxframesize]
+        offset += len(chunk)
+
+        done = offset == len(payload)
+
+        if done:
+            flags = FLAG_COMMAND_RESPONSE_EOS
+        else:
+            flags = FLAG_COMMAND_RESPONSE_CONTINUATION
+
+        yield stream.makeframe(requestid=requestid,
+                               typeid=FRAME_TYPE_COMMAND_RESPONSE,
+                               flags=flags,
+                               payload=chunk,
+                               encoded=payload != b'')
+
+        if done:
+            break
 
 def createalternatelocationresponseframe(stream, requestid, location):
     data = {
@@ -395,10 +424,19 @@
         if value is not None:
             data[b'location'][pycompat.bytestr(a)] = value
 
+    payload = b''.join(cborutil.streamencode(data))
+
+    if stream.streamsettingssent:
+        payload = stream.encode(payload)
+        encoded = True
+    else:
+        encoded = False
+
     return stream.makeframe(requestid=requestid,
                             typeid=FRAME_TYPE_COMMAND_RESPONSE,
                             flags=FLAG_COMMAND_RESPONSE_CONTINUATION,
-                            payload=b''.join(cborutil.streamencode(data)))
+                            payload=payload,
+                            encoded=encoded)
 
 def createcommanderrorresponse(stream, requestid, message, args=None):
     # TODO should this be using a list of {'msg': ..., 'args': {}} so atom
@@ -519,6 +557,8 @@
                 yield frame
             return
 
+        data = self._stream.encode(data)
+
         # There is a ton of potential to do more complicated things here.
         # Our immediate goal is to coalesce small chunks into big frames,
         # not achieve the fewest number of frames possible. So we go with
@@ -548,7 +588,8 @@
                     self._requestid,
                     typeid=FRAME_TYPE_COMMAND_RESPONSE,
                     flags=FLAG_COMMAND_RESPONSE_CONTINUATION,
-                    payload=chunk)
+                    payload=chunk,
+                    encoded=True)
 
                 if offset == len(data):
                     return
@@ -583,7 +624,8 @@
             self._requestid,
             typeid=FRAME_TYPE_COMMAND_RESPONSE,
             flags=FLAG_COMMAND_RESPONSE_CONTINUATION,
-            payload=payload)
+            payload=payload,
+            encoded=True)
 
 # TODO consider defining encoders/decoders using the util.compressionengine
 # mechanism.
@@ -776,7 +818,9 @@
 
     def __init__(self, streamid, active=False):
         super(outputstream, self).__init__(streamid, active=active)
+        self.streamsettingssent = False
         self._encoder = None
+        self._encodername = None
 
     def setencoder(self, ui, name):
         """Set the encoder for this stream.
@@ -787,6 +831,7 @@
             raise error.Abort(_('unknown stream encoder: %s') % name)
 
         self._encoder = STREAM_ENCODERS[name][0](ui)
+        self._encodername = name
 
     def encode(self, data):
         if not self._encoder:
@@ -806,6 +851,45 @@
 
         self._encoder.finish()
 
+    def makeframe(self, requestid, typeid, flags, payload,
+                  encoded=False):
+        """Create a frame to be sent out over this stream.
+
+        Only returns the frame instance. Does not actually send it.
+        """
+        streamflags = 0
+        if not self._active:
+            streamflags |= STREAM_FLAG_BEGIN_STREAM
+            self._active = True
+
+        if encoded:
+            if not self.streamsettingssent:
+                raise error.ProgrammingError(
+                    b'attempting to send encoded frame without sending stream '
+                    b'settings')
+
+            streamflags |= STREAM_FLAG_ENCODING_APPLIED
+
+        if (typeid == FRAME_TYPE_STREAM_SETTINGS
+            and flags & FLAG_STREAM_ENCODING_SETTINGS_EOS):
+            self.streamsettingssent = True
+
+        return makeframe(requestid, self.streamid, streamflags, typeid, flags,
+                         payload)
+
+    def makestreamsettingsframe(self, requestid):
+        """Create a stream settings frame for this stream.
+
+        Returns frame data or None if no stream settings frame is needed or has
+        already been sent.
+        """
+        if not self._encoder or self.streamsettingssent:
+            return None
+
+        payload = b''.join(cborutil.streamencode(self._encodername))
+        return self.makeframe(requestid, FRAME_TYPE_STREAM_SETTINGS,
+                              FLAG_STREAM_ENCODING_SETTINGS_EOS, payload)
+
 def ensureserverstream(stream):
     if stream.streamid % 2:
         raise error.ProgrammingError('server should only write to even '
@@ -995,7 +1079,9 @@
                         yield frame
 
                     if emitted:
-                        yield createcommandresponseeosframe(stream, requestid)
+                        for frame in createcommandresponseeosframes(
+                            stream, requestid):
+                            yield frame
                     break
 
                 except error.WireprotoCommandError as e:
@@ -1022,6 +1108,10 @@
                                 'alternatelocationresponse seen after initial '
                                 'output object')
 
+                        frame = stream.makestreamsettingsframe(requestid)
+                        if frame:
+                            yield frame
+
                         yield createalternatelocationresponseframe(
                             stream, requestid, o)
 
@@ -1034,7 +1124,16 @@
                             'object follows alternatelocationresponse')
 
                     if not emitted:
-                        yield createcommandresponseokframe(stream, requestid)
+                        # Frame is optional.
+                        frame = stream.makestreamsettingsframe(requestid)
+                        if frame:
+                            yield frame
+
+                        # May be None if empty frame (due to encoding).
+                        frame = createcommandresponseokframe(stream, requestid)
+                        if frame:
+                            yield frame
+
                         emitted = True
 
                     # Objects emitted by command functions can be serializable
@@ -1121,13 +1220,25 @@
         return self._handlesendframes(sendframes())
 
     def makeoutputstream(self):
-        """Create a stream to be used for sending data to the client."""
+        """Create a stream to be used for sending data to the client.
+
+        If this is called before protocol settings frames are received, we
+        don't know what stream encodings are supported by the client and
+        we will default to identity.
+        """
         streamid = self._nextoutgoingstreamid
         self._nextoutgoingstreamid += 2
 
         s = outputstream(streamid)
         self._outgoingstreams[streamid] = s
 
+        # Always use the *server's* preferred encoder over the client's,
+        # as servers have more to lose from sub-optimal encoders being used.
+        for name in STREAM_ENCODERS_ORDER:
+            if name in self._sendersettings['contentencodings']:
+                s.setencoder(self._ui, name)
+                break
+
         return s
 
     def _makeerrorresult(self, msg):