ui: proxy protect/restorestdio() calls to update internal flag
authorYuya Nishihara <yuya@tcha.org>
Wed, 26 Sep 2018 21:41:52 +0900
changeset 41285 cf8677cd7286
parent 41284 b0e3f2d7c143
child 41286 00b314c42094
ui: proxy protect/restorestdio() calls to update internal flag It should be better to manage the redirection flag solely by the ui class.
mercurial/chgserver.py
mercurial/ui.py
mercurial/wireprotoserver.py
tests/test-sshserver.py
--- a/mercurial/chgserver.py	Wed Sep 26 21:29:13 2018 +0900
+++ b/mercurial/chgserver.py	Wed Sep 26 21:41:52 2018 +0900
@@ -200,7 +200,7 @@
         def _runsystem(self, cmd, environ, cwd, out):
             # fallback to the original system method if
             #  a. the output stream is not stdout (e.g. stderr, cStringIO),
-            #  b. or stdout is redirected by protectstdio(),
+            #  b. or stdout is redirected by protectfinout(),
             # because the chg client is not aware of these situations and
             # will behave differently (i.e. write to stdout).
             if (out is not self.fout
--- a/mercurial/ui.py	Wed Sep 26 21:29:13 2018 +0900
+++ b/mercurial/ui.py	Wed Sep 26 21:41:52 2018 +0900
@@ -1080,14 +1080,38 @@
             return False
         return procutil.isatty(fh)
 
+    def protectfinout(self):
+        """Duplicate ui streams and redirect original if they are stdio
+
+        Returns (fin, fout) which point to the original ui fds, but may be
+        copy of them. The returned streams can be considered "owned" in that
+        print(), exec(), etc. never reach to them.
+        """
+        if self._finoutredirected:
+            # if already redirected, protectstdio() would just create another
+            # nullfd pair, which is equivalent to returning self._fin/_fout.
+            return self._fin, self._fout
+        fin, fout = procutil.protectstdio(self._fin, self._fout)
+        self._finoutredirected = (fin, fout) != (self._fin, self._fout)
+        return fin, fout
+
+    def restorefinout(self, fin, fout):
+        """Restore ui streams from possibly duplicated (fin, fout)"""
+        if (fin, fout) == (self._fin, self._fout):
+            return
+        procutil.restorestdio(self._fin, self._fout, fin, fout)
+        # protectfinout() won't create more than one duplicated streams,
+        # so we can just turn the redirection flag off.
+        self._finoutredirected = False
+
     @contextlib.contextmanager
     def protectedfinout(self):
         """Run code block with protected standard streams"""
-        fin, fout = procutil.protectstdio(self._fin, self._fout)
+        fin, fout = self.protectfinout()
         try:
             yield fin, fout
         finally:
-            procutil.restorestdio(self._fin, self._fout, fin, fout)
+            self.restorefinout(fin, fout)
 
     def disablepager(self):
         self._disablepager = True
--- a/mercurial/wireprotoserver.py	Wed Sep 26 21:29:13 2018 +0900
+++ b/mercurial/wireprotoserver.py	Wed Sep 26 21:41:52 2018 +0900
@@ -24,7 +24,6 @@
 from .utils import (
     cborutil,
     interfaceutil,
-    procutil,
 )
 
 stringio = util.stringio
@@ -782,9 +781,7 @@
     def __init__(self, ui, repo, logfh=None):
         self._ui = ui
         self._repo = repo
-        self._fin, self._fout = procutil.protectstdio(ui.fin, ui.fout)
-        # TODO: manage the redirection flag internally by ui
-        ui._finoutredirected = (self._fin, self._fout) != (ui.fin, ui.fout)
+        self._fin, self._fout = ui.protectfinout()
 
         # Log write I/O to stdout and stderr if configured.
         if logfh:
@@ -795,8 +792,7 @@
 
     def serve_forever(self):
         self.serveuntil(threading.Event())
-        procutil.restorestdio(self._ui.fin, self._ui.fout,
-                              self._fin, self._fout)
+        self._ui.restorefinout(self._fin, self._fout)
         sys.exit(0)
 
     def serveuntil(self, ev):
--- a/tests/test-sshserver.py	Wed Sep 26 21:29:13 2018 +0900
+++ b/tests/test-sshserver.py	Wed Sep 26 21:41:52 2018 +0900
@@ -47,6 +47,12 @@
         self.fout = io.BytesIO()
         self.ferr = io.BytesIO()
 
+    def protectfinout(self):
+        return self.fin, self.fout
+
+    def restorefinout(self, fin, fout):
+        pass
+
 if __name__ == '__main__':
     # Don't call into msvcrt to set BytesIO to binary mode
     procutil.setbinary = lambda fp: True