tests/badserverext.py
changeset 41466 4d5aae86c9bd
parent 41464 d343d9ac173e
child 43076 2372284d9457
--- a/tests/badserverext.py	Tue Jan 29 14:06:46 2019 -0800
+++ b/tests/badserverext.py	Wed Jan 30 12:12:25 2019 -0800
@@ -75,7 +75,7 @@
         object.__setattr__(self, '_closeaftersendbytes', closeaftersendbytes)
 
     def __getattribute__(self, name):
-        if name in ('makefile',):
+        if name in ('makefile', 'sendall', '_writelog'):
             return object.__getattribute__(self, name)
 
         return getattr(object.__getattribute__(self, '_orig'), name)
@@ -86,6 +86,13 @@
     def __setattr__(self, name, value):
         setattr(object.__getattribute__(self, '_orig'), name, value)
 
+    def _writelog(self, msg):
+        msg = msg.replace(b'\r', b'\\r').replace(b'\n', b'\\n')
+
+        object.__getattribute__(self, '_logfp').write(msg)
+        object.__getattribute__(self, '_logfp').write(b'\n')
+        object.__getattribute__(self, '_logfp').flush()
+
     def makefile(self, mode, bufsize):
         f = object.__getattribute__(self, '_orig').makefile(mode, bufsize)
 
@@ -99,6 +106,38 @@
                                closeafterrecvbytes=closeafterrecvbytes,
                                closeaftersendbytes=closeaftersendbytes)
 
+    def sendall(self, data, flags=0):
+        remaining = object.__getattribute__(self, '_closeaftersendbytes')
+
+        # No read limit. Call original function.
+        if not remaining:
+            result = object.__getattribute__(self, '_orig').sendall(data, flags)
+            self._writelog(b'sendall(%d) -> %s' % (len(data), data))
+            return result
+
+        if len(data) > remaining:
+            newdata = data[0:remaining]
+        else:
+            newdata = data
+
+        remaining -= len(newdata)
+
+        result = object.__getattribute__(self, '_orig').sendall(newdata, flags)
+
+        self._writelog(b'sendall(%d from %d) -> (%d) %s' % (
+            len(newdata), len(data), remaining, newdata))
+
+        object.__setattr__(self, '_closeaftersendbytes', remaining)
+
+        if remaining <= 0:
+            self._writelog(b'write limit reached; closing socket')
+            object.__getattribute__(self, '_orig').shutdown(socket.SHUT_RDWR)
+
+            raise Exception('connection closed after sending N bytes')
+
+        return result
+
+
 # We can't adjust __class__ on socket._fileobject, so define a proxy.
 class fileobjectproxy(object):
     __slots__ = (