mercurial/patch.py
changeset 37621 5537d8f5e989
parent 37574 a1bcc7ff0eac
child 37730 8d730f96e792
--- a/mercurial/patch.py	Thu Apr 12 23:06:27 2018 -0700
+++ b/mercurial/patch.py	Thu Apr 12 23:14:38 2018 -0700
@@ -9,6 +9,7 @@
 from __future__ import absolute_import, print_function
 
 import collections
+import contextlib
 import copy
 import difflib
 import email
@@ -192,6 +193,7 @@
                   ('Node ID', 'nodeid'),
                  ]
 
+@contextlib.contextmanager
 def extract(ui, fileobj):
     '''extract patch from data read from fileobj.
 
@@ -209,6 +211,16 @@
     Any item can be missing from the dictionary. If filename is missing,
     fileobj did not contain a patch. Caller must unlink filename when done.'''
 
+    fd, tmpname = tempfile.mkstemp(prefix='hg-patch-')
+    tmpfp = os.fdopen(fd, r'wb')
+    try:
+        yield _extract(ui, fileobj, tmpname, tmpfp)
+    finally:
+        tmpfp.close()
+        os.unlink(tmpname)
+
+def _extract(ui, fileobj, tmpname, tmpfp):
+
     # attempt to detect the start of a patch
     # (this heuristic is borrowed from quilt)
     diffre = re.compile(br'^(?:Index:[ \t]|diff[ \t]-|RCS file: |'
@@ -218,86 +230,80 @@
                         re.MULTILINE | re.DOTALL)
 
     data = {}
-    fd, tmpname = tempfile.mkstemp(prefix='hg-patch-')
-    tmpfp = os.fdopen(fd, r'wb')
-    try:
-        msg = pycompat.emailparser().parse(fileobj)
+
+    msg = pycompat.emailparser().parse(fileobj)
 
-        subject = msg[r'Subject'] and mail.headdecode(msg[r'Subject'])
-        data['user'] = msg[r'From'] and mail.headdecode(msg[r'From'])
-        if not subject and not data['user']:
-            # Not an email, restore parsed headers if any
-            subject = '\n'.join(': '.join(map(encoding.strtolocal, h))
-                                for h in msg.items()) + '\n'
+    subject = msg[r'Subject'] and mail.headdecode(msg[r'Subject'])
+    data['user'] = msg[r'From'] and mail.headdecode(msg[r'From'])
+    if not subject and not data['user']:
+        # Not an email, restore parsed headers if any
+        subject = '\n'.join(': '.join(map(encoding.strtolocal, h))
+                            for h in msg.items()) + '\n'
 
-        # should try to parse msg['Date']
-        parents = []
+    # should try to parse msg['Date']
+    parents = []
 
-        if subject:
-            if subject.startswith('[PATCH'):
-                pend = subject.find(']')
-                if pend >= 0:
-                    subject = subject[pend + 1:].lstrip()
-            subject = re.sub(br'\n[ \t]+', ' ', subject)
-            ui.debug('Subject: %s\n' % subject)
-        if data['user']:
-            ui.debug('From: %s\n' % data['user'])
-        diffs_seen = 0
-        ok_types = ('text/plain', 'text/x-diff', 'text/x-patch')
-        message = ''
-        for part in msg.walk():
-            content_type = pycompat.bytestr(part.get_content_type())
-            ui.debug('Content-Type: %s\n' % content_type)
-            if content_type not in ok_types:
-                continue
-            payload = part.get_payload(decode=True)
-            m = diffre.search(payload)
-            if m:
-                hgpatch = False
-                hgpatchheader = False
-                ignoretext = False
+    if subject:
+        if subject.startswith('[PATCH'):
+            pend = subject.find(']')
+            if pend >= 0:
+                subject = subject[pend + 1:].lstrip()
+        subject = re.sub(br'\n[ \t]+', ' ', subject)
+        ui.debug('Subject: %s\n' % subject)
+    if data['user']:
+        ui.debug('From: %s\n' % data['user'])
+    diffs_seen = 0
+    ok_types = ('text/plain', 'text/x-diff', 'text/x-patch')
+    message = ''
+    for part in msg.walk():
+        content_type = pycompat.bytestr(part.get_content_type())
+        ui.debug('Content-Type: %s\n' % content_type)
+        if content_type not in ok_types:
+            continue
+        payload = part.get_payload(decode=True)
+        m = diffre.search(payload)
+        if m:
+            hgpatch = False
+            hgpatchheader = False
+            ignoretext = False
 
-                ui.debug('found patch at byte %d\n' % m.start(0))
-                diffs_seen += 1
-                cfp = stringio()
-                for line in payload[:m.start(0)].splitlines():
-                    if line.startswith('# HG changeset patch') and not hgpatch:
-                        ui.debug('patch generated by hg export\n')
-                        hgpatch = True
-                        hgpatchheader = True
-                        # drop earlier commit message content
-                        cfp.seek(0)
-                        cfp.truncate()
-                        subject = None
-                    elif hgpatchheader:
-                        if line.startswith('# User '):
-                            data['user'] = line[7:]
-                            ui.debug('From: %s\n' % data['user'])
-                        elif line.startswith("# Parent "):
-                            parents.append(line[9:].lstrip())
-                        elif line.startswith("# "):
-                            for header, key in patchheadermap:
-                                prefix = '# %s ' % header
-                                if line.startswith(prefix):
-                                    data[key] = line[len(prefix):]
-                        else:
-                            hgpatchheader = False
-                    elif line == '---':
-                        ignoretext = True
-                    if not hgpatchheader and not ignoretext:
-                        cfp.write(line)
-                        cfp.write('\n')
-                message = cfp.getvalue()
-                if tmpfp:
-                    tmpfp.write(payload)
-                    if not payload.endswith('\n'):
-                        tmpfp.write('\n')
-            elif not diffs_seen and message and content_type == 'text/plain':
-                message += '\n' + payload
-    except: # re-raises
-        tmpfp.close()
-        os.unlink(tmpname)
-        raise
+            ui.debug('found patch at byte %d\n' % m.start(0))
+            diffs_seen += 1
+            cfp = stringio()
+            for line in payload[:m.start(0)].splitlines():
+                if line.startswith('# HG changeset patch') and not hgpatch:
+                    ui.debug('patch generated by hg export\n')
+                    hgpatch = True
+                    hgpatchheader = True
+                    # drop earlier commit message content
+                    cfp.seek(0)
+                    cfp.truncate()
+                    subject = None
+                elif hgpatchheader:
+                    if line.startswith('# User '):
+                        data['user'] = line[7:]
+                        ui.debug('From: %s\n' % data['user'])
+                    elif line.startswith("# Parent "):
+                        parents.append(line[9:].lstrip())
+                    elif line.startswith("# "):
+                        for header, key in patchheadermap:
+                            prefix = '# %s ' % header
+                            if line.startswith(prefix):
+                                data[key] = line[len(prefix):]
+                    else:
+                        hgpatchheader = False
+                elif line == '---':
+                    ignoretext = True
+                if not hgpatchheader and not ignoretext:
+                    cfp.write(line)
+                    cfp.write('\n')
+            message = cfp.getvalue()
+            if tmpfp:
+                tmpfp.write(payload)
+                if not payload.endswith('\n'):
+                    tmpfp.write('\n')
+        elif not diffs_seen and message and content_type == 'text/plain':
+            message += '\n' + payload
 
     if subject and not message.startswith(subject):
         message = '%s\n%s' % (subject, message)
@@ -310,8 +316,7 @@
 
     if diffs_seen:
         data['filename'] = tmpname
-    else:
-        os.unlink(tmpname)
+
     return data
 
 class patchmeta(object):