cmdserver: protect pipe server streams against corruption caused by direct io
authorYuya Nishihara <yuya@tcha.org>
Sat, 15 Nov 2014 13:50:43 +0900
changeset 23324 69f86b937035
parent 23323 bc374458688b
child 23325 4165cfd67519
cmdserver: protect pipe server streams against corruption caused by direct io Because pipe-mode server uses stdio as IPC channel, other modules should not touch stdio directly and use ui instead. However, this strategy is brittle because several Python functions read and write stdio implicitly. print 'hello' # should use ui.write() # => ch = 'h', size = 1701604463 'ello', data = '\n' This patch adds protection for such mistakes. Both stdio files and low-level file descriptors are redirected to /dev/null while command server uses them.
mercurial/commandserver.py
tests/test-commandserver.t
--- a/mercurial/commandserver.py	Sat Nov 15 13:04:41 2014 +0900
+++ b/mercurial/commandserver.py	Sat Nov 15 13:50:43 2014 +0900
@@ -7,7 +7,7 @@
 
 from i18n import _
 import struct
-import os, errno, traceback, SocketServer
+import sys, os, errno, traceback, SocketServer
 import dispatch, encoding, util
 
 logfile = None
@@ -248,6 +248,29 @@
 
         return 0
 
+def _protectio(ui):
+    """ duplicates streams and redirect original to null if ui uses stdio """
+    ui.flush()
+    newfiles = []
+    nullfd = os.open(os.devnull, os.O_RDWR)
+    for f, sysf, mode in [(ui.fin, sys.stdin, 'rb'),
+                          (ui.fout, sys.stdout, 'wb')]:
+        if f is sysf:
+            newfd = os.dup(f.fileno())
+            os.dup2(nullfd, f.fileno())
+            f = os.fdopen(newfd, mode)
+        newfiles.append(f)
+    os.close(nullfd)
+    return tuple(newfiles)
+
+def _restoreio(ui, fin, fout):
+    """ restores streams from duplicated ones """
+    ui.flush()
+    for f, uif in [(fin, ui.fin), (fout, ui.fout)]:
+        if f is not uif:
+            os.dup2(f.fileno(), uif.fileno())
+            f.close()
+
 class pipeservice(object):
     def __init__(self, ui, repo, opts):
         self.ui = ui
@@ -258,8 +281,14 @@
 
     def run(self):
         ui = self.ui
-        sv = server(ui, self.repo, ui.fin, ui.fout)
-        return sv.serve()
+        # redirect stdio to null device so that broken extensions or in-process
+        # hooks will never cause corruption of channel protocol.
+        fin, fout = _protectio(ui)
+        try:
+            sv = server(ui, self.repo, fin, fout)
+            return sv.serve()
+        finally:
+            _restoreio(ui, fin, fout)
 
 class _requesthandler(SocketServer.StreamRequestHandler):
     def handle(self):
--- a/tests/test-commandserver.t	Sat Nov 15 13:04:41 2014 +0900
+++ b/tests/test-commandserver.t	Sat Nov 15 13:50:43 2014 +0900
@@ -492,6 +492,7 @@
   foo
 
   $ cat <<EOF > dbgui.py
+  > import os, sys
   > from mercurial import cmdutil, commands
   > cmdtable = {}
   > command = cmdutil.command(cmdtable)
@@ -501,6 +502,14 @@
   > @command("debugprompt", norepo=True)
   > def debugprompt(ui):
   >     ui.write("%s\\n" % ui.prompt("prompt:"))
+  > @command("debugreadstdin", norepo=True)
+  > def debugreadstdin(ui):
+  >     ui.write("read: %r\n" % sys.stdin.read(1))
+  > @command("debugwritestdout", norepo=True)
+  > def debugwritestdout(ui):
+  >     os.write(1, "low-level stdout fd and\n")
+  >     sys.stdout.write("stdout should be redirected to /dev/null\n")
+  >     sys.stdout.flush()
   > EOF
   $ cat <<EOF >> .hg/hgrc
   > [extensions]
@@ -518,10 +527,15 @@
   ...     runcommand(server, ['debugprompt', '--config',
   ...                         'ui.interactive=True'],
   ...                input=cStringIO.StringIO('5678\n'))
+  ...     runcommand(server, ['debugreadstdin'])
+  ...     runcommand(server, ['debugwritestdout'])
   *** runcommand debuggetpass --config ui.interactive=True
   password: 1234
   *** runcommand debugprompt --config ui.interactive=True
   prompt: 5678
+  *** runcommand debugreadstdin
+  read: ''
+  *** runcommand debugwritestdout
 
 
 run commandserver in commandserver, which is silly but should work: