tests/test-stdio.py
changeset 45097 dff208398ede
parent 45096 e9e452eafbfb
child 45104 eb26a9cf7821
--- a/tests/test-stdio.py	Thu Jul 09 12:52:42 2020 +0200
+++ b/tests/test-stdio.py	Thu Jul 09 12:52:04 2020 +0200
@@ -10,6 +10,7 @@
 import signal
 import subprocess
 import sys
+import tempfile
 import unittest
 
 from mercurial import pycompat
@@ -41,7 +42,9 @@
 
 signal.signal(signal.SIGINT, lambda *x: None)
 dispatch.initstdio()
-procutil.{stream}.write(b'x' * 1048576)
+write_result = procutil.{stream}.write(b'x' * 1048576)
+with open({write_result_fn}, 'w') as write_result_f:
+    write_result_f.write(str(write_result))
 '''
 
 
@@ -109,6 +112,7 @@
         rwpair_generator,
         check_output,
         python_args=[],
+        post_child_check=None,
     ):
         assert stream in ('stdout', 'stderr')
         with rwpair_generator() as (stream_receiver, child_stream), open(
@@ -130,6 +134,8 @@
             finally:
                 retcode = proc.wait()
             self.assertEqual(retcode, 0)
+            if post_child_check is not None:
+                post_child_check()
 
     def _test_buffering(
         self, stream, rwpair_generator, expected_output, python_args=[]
@@ -194,13 +200,39 @@
                 _readall(stream_receiver, 131072, buf), b'x' * 1048576
             )
 
-        self._test(
-            TEST_LARGE_WRITE_CHILD_SCRIPT.format(stream=stream),
-            stream,
-            rwpair_generator,
-            check_output,
-            python_args,
-        )
+        def post_child_check():
+            with open(write_result_fn, 'r') as write_result_f:
+                write_result_str = write_result_f.read()
+            if pycompat.ispy3:
+                # On Python 3, we test that the correct number of bytes is
+                # claimed to have been written.
+                expected_write_result_str = '1048576'
+            else:
+                # On Python 2, we only check that the large write does not
+                # crash.
+                expected_write_result_str = 'None'
+            self.assertEqual(write_result_str, expected_write_result_str)
+
+        try:
+            # tempfile.mktemp() is unsafe in general, as a malicious process
+            # could create the file before we do. But in tests, we're running
+            # in a controlled environment.
+            write_result_fn = tempfile.mktemp()
+            self._test(
+                TEST_LARGE_WRITE_CHILD_SCRIPT.format(
+                    stream=stream, write_result_fn=repr(write_result_fn)
+                ),
+                stream,
+                rwpair_generator,
+                check_output,
+                python_args,
+                post_child_check=post_child_check,
+            )
+        finally:
+            try:
+                os.unlink(write_result_fn)
+            except OSError:
+                pass
 
     def test_large_write_stdout_devnull(self):
         self._test_large_write('stdout', _devnull)