wireprotoserver: ensure that output stream gets flushed on exception stable
authorArseniy Alekseyev <aalekseyev@janestreet.com>
Thu, 04 Apr 2024 14:15:32 +0100
branchstable
changeset 51572 13c004b54cbe
parent 51571 74230abb2504
child 51600 ee1b648e4453
child 51608 3e0f86f09f26
wireprotoserver: ensure that output stream gets flushed on exception Previously flush was happening due to Python finalizer being run on `BufferedWriter`. With upgrade to Python 3.11 this started randomly failing. My guess is that the finalizer on the raw `FileIO` object may be running before the finalizer of `BufferedWriter` has a chance to run. At any rate, since we're not relying on finalizers in the happy case we should also not rely on them in case of exception.
mercurial/wireprotoserver.py
tests/sshprotoext.py
tests/test-sshserver.py
--- a/mercurial/wireprotoserver.py	Mon Apr 15 16:33:37 2024 +0100
+++ b/mercurial/wireprotoserver.py	Thu Apr 04 14:15:32 2024 +0100
@@ -527,24 +527,34 @@
     def __init__(self, ui, repo, logfh=None, accesshidden=False):
         self._ui = ui
         self._repo = repo
-        self._fin, self._fout = ui.protectfinout()
         self._accesshidden = accesshidden
-
-        # Log write I/O to stdout and stderr if configured.
-        if logfh:
-            self._fout = util.makeloggingfileobject(
-                logfh, self._fout, b'o', logdata=True
-            )
-            ui.ferr = util.makeloggingfileobject(
-                logfh, ui.ferr, b'e', logdata=True
-            )
+        self._logfh = logfh
 
     def serve_forever(self):
         self.serveuntil(threading.Event())
-        self._ui.restorefinout(self._fin, self._fout)
 
     def serveuntil(self, ev):
         """Serve until a threading.Event is set."""
-        _runsshserver(
-            self._ui, self._repo, self._fin, self._fout, ev, self._accesshidden
-        )
+        with self._ui.protectedfinout() as (fin, fout):
+            if self._logfh:
+                # Log write I/O to stdout and stderr if configured.
+                fout = util.makeloggingfileobject(
+                    self._logfh,
+                    fout,
+                    b'o',
+                    logdata=True,
+                )
+                self._ui.ferr = util.makeloggingfileobject(
+                    self._logfh,
+                    self._ui.ferr,
+                    b'e',
+                    logdata=True,
+                )
+            _runsshserver(
+                self._ui,
+                self._repo,
+                fin,
+                fout,
+                ev,
+                self._accesshidden,
+            )
--- a/tests/sshprotoext.py	Mon Apr 15 16:33:37 2024 +0100
+++ b/tests/sshprotoext.py	Thu Apr 04 14:15:32 2024 +0100
@@ -30,7 +30,7 @@
 
     def serve_forever(self):
         for i in range(10):
-            self._fout.write(b'banner: line %d\n' % i)
+            self._ui.fout.write(b'banner: line %d\n' % i)
 
         super(bannerserver, self).serve_forever()
 
@@ -45,17 +45,16 @@
     """
 
     def serve_forever(self):
-        l = self._fin.readline()
+        ui = self._ui
+        l = ui.fin.readline()
         assert l == b'hello\n'
         # Respond to unknown commands with an empty reply.
-        wireprotoserver._sshv1respondbytes(self._fout, b'')
-        l = self._fin.readline()
+        wireprotoserver._sshv1respondbytes(ui.fout, b'')
+        l = ui.fin.readline()
         assert l == b'between\n'
-        proto = wireprotoserver.sshv1protocolhandler(
-            self._ui, self._fin, self._fout
-        )
+        proto = wireprotoserver.sshv1protocolhandler(ui, ui.fin, ui.fout)
         rsp = wireprotov1server.dispatch(self._repo, proto, b'between')
-        wireprotoserver._sshv1respondbytes(self._fout, rsp.data)
+        wireprotoserver._sshv1respondbytes(ui.fout, rsp.data)
 
         super(prehelloserver, self).serve_forever()
 
--- a/tests/test-sshserver.py	Mon Apr 15 16:33:37 2024 +0100
+++ b/tests/test-sshserver.py	Thu Apr 04 14:15:32 2024 +0100
@@ -25,9 +25,8 @@
 
     def assertparse(self, cmd, input, expected):
         server = mockserver(input)
-        proto = wireprotoserver.sshv1protocolhandler(
-            server._ui, server._fin, server._fout
-        )
+        ui = server._ui
+        proto = wireprotoserver.sshv1protocolhandler(ui, ui.fin, ui.fout)
         _func, spec = wireprotov1server.commands[cmd]
         self.assertEqual(proto.getargs(spec), expected)
 
@@ -35,6 +34,9 @@
 def mockserver(inbytes):
     ui = mockui(inbytes)
     repo = mockrepo(ui)
+    # note: this test unfortunately doesn't really test anything about
+    # `sshserver` class anymore: the entirety of logic of that class lives
+    # in `serveuntil`, and that function is not even called by this test.
     return wireprotoserver.sshserver(ui, repo)