Patchwork D3306: patch: make extract() a context manager (API)

login
register
mail settings
Submitter phabricator
Date April 13, 2018, 6:27 a.m.
Message ID <differential-rev-PHID-DREV-tooek3g3gjjmbyhmxamc-req@phab.mercurial-scm.org>
Download mbox | patch
Permalink /patch/30849/
State Superseded
Headers show

Comments

phabricator - April 13, 2018, 6:27 a.m.
indygreg created this revision.
Herald added a subscriber: mercurial-devel.
Herald added a reviewer: hg-reviewers.

REVISION SUMMARY
  Previously, this function was creating a temporary file and relying
  on callers to unlink it. Yuck.
  
  We convert the function to a context manager and tie the lifetime of
  the temporary file to that of the context manager. This changed
  indentation not only from the context manager, but also from the
  elination of try blocks. It was just easier to split the heart of
  extract() into its own function.
  
  The single consumer of this function has been refactored to use it as
  a context manager. Code for cleaning up the file in tryimportone()
  has also been removed.
  
  .. api::
  
    ``patch.extract()`` is now a context manager. Callers no longer have
    to worry about deleting the temporary file it creates, as the file is
    tied to the lifetime of the context manager.

REPOSITORY
  rHG Mercurial

REVISION DETAIL
  https://phab.mercurial-scm.org/D3306

AFFECTED FILES
  mercurial/cmdutil.py
  mercurial/commands.py
  mercurial/patch.py

CHANGE DETAILS




To: indygreg, #hg-reviewers
Cc: mercurial-devel

Patch

diff --git a/mercurial/patch.py b/mercurial/patch.py
--- a/mercurial/patch.py
+++ b/mercurial/patch.py
@@ -9,6 +9,7 @@ 
 from __future__ import absolute_import, print_function
 
 import collections
+import contextlib
 import copy
 import difflib
 import email
@@ -192,6 +193,7 @@ 
                   ('Node ID', 'nodeid'),
                  ]
 
+@contextlib.contextmanager
 def extract(ui, fileobj):
     '''extract patch from data read from fileobj.
 
@@ -209,6 +211,16 @@ 
     Any item can be missing from the dictionary. If filename is missing,
     fileobj did not contain a patch. Caller must unlink filename when done.'''
 
+    fd, tmpname = tempfile.mkstemp(prefix='hg-patch-')
+    tmpfp = os.fdopen(fd, r'wb')
+    try:
+        yield _extract(ui, fileobj, tmpname, tmpfp)
+    finally:
+        tmpfp.close()
+        os.unlink(tmpname)
+
+def _extract(ui, fileobj, tmpname, tmpfp):
+
     # attempt to detect the start of a patch
     # (this heuristic is borrowed from quilt)
     diffre = re.compile(br'^(?:Index:[ \t]|diff[ \t]-|RCS file: |'
@@ -218,86 +230,80 @@ 
                         re.MULTILINE | re.DOTALL)
 
     data = {}
-    fd, tmpname = tempfile.mkstemp(prefix='hg-patch-')
-    tmpfp = os.fdopen(fd, r'wb')
-    try:
-        msg = pycompat.emailparser().parse(fileobj)
+
+    msg = pycompat.emailparser().parse(fileobj)
 
-        subject = msg[r'Subject'] and mail.headdecode(msg[r'Subject'])
-        data['user'] = msg[r'From'] and mail.headdecode(msg[r'From'])
-        if not subject and not data['user']:
-            # Not an email, restore parsed headers if any
-            subject = '\n'.join(': '.join(map(encoding.strtolocal, h))
-                                for h in msg.items()) + '\n'
+    subject = msg[r'Subject'] and mail.headdecode(msg[r'Subject'])
+    data['user'] = msg[r'From'] and mail.headdecode(msg[r'From'])
+    if not subject and not data['user']:
+        # Not an email, restore parsed headers if any
+        subject = '\n'.join(': '.join(map(encoding.strtolocal, h))
+                            for h in msg.items()) + '\n'
 
-        # should try to parse msg['Date']
-        parents = []
+    # should try to parse msg['Date']
+    parents = []
 
-        if subject:
-            if subject.startswith('[PATCH'):
-                pend = subject.find(']')
-                if pend >= 0:
-                    subject = subject[pend + 1:].lstrip()
-            subject = re.sub(br'\n[ \t]+', ' ', subject)
-            ui.debug('Subject: %s\n' % subject)
-        if data['user']:
-            ui.debug('From: %s\n' % data['user'])
-        diffs_seen = 0
-        ok_types = ('text/plain', 'text/x-diff', 'text/x-patch')
-        message = ''
-        for part in msg.walk():
-            content_type = pycompat.bytestr(part.get_content_type())
-            ui.debug('Content-Type: %s\n' % content_type)
-            if content_type not in ok_types:
-                continue
-            payload = part.get_payload(decode=True)
-            m = diffre.search(payload)
-            if m:
-                hgpatch = False
-                hgpatchheader = False
-                ignoretext = False
+    if subject:
+        if subject.startswith('[PATCH'):
+            pend = subject.find(']')
+            if pend >= 0:
+                subject = subject[pend + 1:].lstrip()
+        subject = re.sub(br'\n[ \t]+', ' ', subject)
+        ui.debug('Subject: %s\n' % subject)
+    if data['user']:
+        ui.debug('From: %s\n' % data['user'])
+    diffs_seen = 0
+    ok_types = ('text/plain', 'text/x-diff', 'text/x-patch')
+    message = ''
+    for part in msg.walk():
+        content_type = pycompat.bytestr(part.get_content_type())
+        ui.debug('Content-Type: %s\n' % content_type)
+        if content_type not in ok_types:
+            continue
+        payload = part.get_payload(decode=True)
+        m = diffre.search(payload)
+        if m:
+            hgpatch = False
+            hgpatchheader = False
+            ignoretext = False
 
-                ui.debug('found patch at byte %d\n' % m.start(0))
-                diffs_seen += 1
-                cfp = stringio()
-                for line in payload[:m.start(0)].splitlines():
-                    if line.startswith('# HG changeset patch') and not hgpatch:
-                        ui.debug('patch generated by hg export\n')
-                        hgpatch = True
-                        hgpatchheader = True
-                        # drop earlier commit message content
-                        cfp.seek(0)
-                        cfp.truncate()
-                        subject = None
-                    elif hgpatchheader:
-                        if line.startswith('# User '):
-                            data['user'] = line[7:]
-                            ui.debug('From: %s\n' % data['user'])
-                        elif line.startswith("# Parent "):
-                            parents.append(line[9:].lstrip())
-                        elif line.startswith("# "):
-                            for header, key in patchheadermap:
-                                prefix = '# %s ' % header
-                                if line.startswith(prefix):
-                                    data[key] = line[len(prefix):]
-                        else:
-                            hgpatchheader = False
-                    elif line == '---':
-                        ignoretext = True
-                    if not hgpatchheader and not ignoretext:
-                        cfp.write(line)
-                        cfp.write('\n')
-                message = cfp.getvalue()
-                if tmpfp:
-                    tmpfp.write(payload)
-                    if not payload.endswith('\n'):
-                        tmpfp.write('\n')
-            elif not diffs_seen and message and content_type == 'text/plain':
-                message += '\n' + payload
-    except: # re-raises
-        tmpfp.close()
-        os.unlink(tmpname)
-        raise
+            ui.debug('found patch at byte %d\n' % m.start(0))
+            diffs_seen += 1
+            cfp = stringio()
+            for line in payload[:m.start(0)].splitlines():
+                if line.startswith('# HG changeset patch') and not hgpatch:
+                    ui.debug('patch generated by hg export\n')
+                    hgpatch = True
+                    hgpatchheader = True
+                    # drop earlier commit message content
+                    cfp.seek(0)
+                    cfp.truncate()
+                    subject = None
+                elif hgpatchheader:
+                    if line.startswith('# User '):
+                        data['user'] = line[7:]
+                        ui.debug('From: %s\n' % data['user'])
+                    elif line.startswith("# Parent "):
+                        parents.append(line[9:].lstrip())
+                    elif line.startswith("# "):
+                        for header, key in patchheadermap:
+                            prefix = '# %s ' % header
+                            if line.startswith(prefix):
+                                data[key] = line[len(prefix):]
+                    else:
+                        hgpatchheader = False
+                elif line == '---':
+                    ignoretext = True
+                if not hgpatchheader and not ignoretext:
+                    cfp.write(line)
+                    cfp.write('\n')
+            message = cfp.getvalue()
+            if tmpfp:
+                tmpfp.write(payload)
+                if not payload.endswith('\n'):
+                    tmpfp.write('\n')
+        elif not diffs_seen and message and content_type == 'text/plain':
+            message += '\n' + payload
 
     if subject and not message.startswith(subject):
         message = '%s\n%s' % (subject, message)
@@ -310,8 +316,7 @@ 
 
     if diffs_seen:
         data['filename'] = tmpname
-    else:
-        os.unlink(tmpname)
+
     return data
 
 class patchmeta(object):
diff --git a/mercurial/commands.py b/mercurial/commands.py
--- a/mercurial/commands.py
+++ b/mercurial/commands.py
@@ -3089,11 +3089,10 @@ 
 
             haspatch = False
             for hunk in patch.split(patchfile):
-                patchdata = patch.extract(ui, hunk)
-
-                msg, node, rej = cmdutil.tryimportone(ui, repo, patchdata,
-                                                      parents, opts,
-                                                      msgs, hg.clean)
+                with patch.extract(ui, hunk) as patchdata:
+                    msg, node, rej = cmdutil.tryimportone(ui, repo, patchdata,
+                                                          parents, opts,
+                                                          msgs, hg.clean)
                 if msg:
                     haspatch = True
                     ui.note(msg + '\n')
diff --git a/mercurial/cmdutil.py b/mercurial/cmdutil.py
--- a/mercurial/cmdutil.py
+++ b/mercurial/cmdutil.py
@@ -1379,141 +1379,140 @@ 
     strip = opts["strip"]
     prefix = opts["prefix"]
     sim = float(opts.get('similarity') or 0)
+
     if not tmpname:
-        return (None, None, False)
+        return None, None, False
 
     rejects = False
 
-    try:
-        cmdline_message = logmessage(ui, opts)
-        if cmdline_message:
-            # pickup the cmdline msg
-            message = cmdline_message
-        elif message:
-            # pickup the patch msg
-            message = message.strip()
-        else:
-            # launch the editor
-            message = None
-        ui.debug('message:\n%s\n' % (message or ''))
-
-        if len(parents) == 1:
-            parents.append(repo[nullid])
-        if opts.get('exact'):
-            if not nodeid or not p1:
-                raise error.Abort(_('not a Mercurial patch'))
+    cmdline_message = logmessage(ui, opts)
+    if cmdline_message:
+        # pickup the cmdline msg
+        message = cmdline_message
+    elif message:
+        # pickup the patch msg
+        message = message.strip()
+    else:
+        # launch the editor
+        message = None
+    ui.debug('message:\n%s\n' % (message or ''))
+
+    if len(parents) == 1:
+        parents.append(repo[nullid])
+    if opts.get('exact'):
+        if not nodeid or not p1:
+            raise error.Abort(_('not a Mercurial patch'))
+        p1 = repo[p1]
+        p2 = repo[p2 or nullid]
+    elif p2:
+        try:
             p1 = repo[p1]
-            p2 = repo[p2 or nullid]
-        elif p2:
-            try:
-                p1 = repo[p1]
-                p2 = repo[p2]
-                # Without any options, consider p2 only if the
-                # patch is being applied on top of the recorded
-                # first parent.
-                if p1 != parents[0]:
-                    p1 = parents[0]
-                    p2 = repo[nullid]
-            except error.RepoError:
-                p1, p2 = parents
-            if p2.node() == nullid:
-                ui.warn(_("warning: import the patch as a normal revision\n"
-                          "(use --exact to import the patch as a merge)\n"))
+            p2 = repo[p2]
+            # Without any options, consider p2 only if the
+            # patch is being applied on top of the recorded
+            # first parent.
+            if p1 != parents[0]:
+                p1 = parents[0]
+                p2 = repo[nullid]
+        except error.RepoError:
+            p1, p2 = parents
+        if p2.node() == nullid:
+            ui.warn(_("warning: import the patch as a normal revision\n"
+                      "(use --exact to import the patch as a merge)\n"))
+    else:
+        p1, p2 = parents
+
+    n = None
+    if update:
+        if p1 != parents[0]:
+            updatefunc(repo, p1.node())
+        if p2 != parents[1]:
+            repo.setparents(p1.node(), p2.node())
+
+        if opts.get('exact') or importbranch:
+            repo.dirstate.setbranch(branch or 'default')
+
+        partial = opts.get('partial', False)
+        files = set()
+        try:
+            patch.patch(ui, repo, tmpname, strip=strip, prefix=prefix,
+                        files=files, eolmode=None, similarity=sim / 100.0)
+        except error.PatchError as e:
+            if not partial:
+                raise error.Abort(pycompat.bytestr(e))
+            if partial:
+                rejects = True
+
+        files = list(files)
+        if nocommit:
+            if message:
+                msgs.append(message)
         else:
-            p1, p2 = parents
-
-        n = None
-        if update:
-            if p1 != parents[0]:
-                updatefunc(repo, p1.node())
-            if p2 != parents[1]:
-                repo.setparents(p1.node(), p2.node())
-
-            if opts.get('exact') or importbranch:
-                repo.dirstate.setbranch(branch or 'default')
-
-            partial = opts.get('partial', False)
+            if opts.get('exact') or p2:
+                # If you got here, you either use --force and know what
+                # you are doing or used --exact or a merge patch while
+                # being updated to its first parent.
+                m = None
+            else:
+                m = scmutil.matchfiles(repo, files or [])
+            editform = mergeeditform(repo[None], 'import.normal')
+            if opts.get('exact'):
+                editor = None
+            else:
+                editor = getcommiteditor(editform=editform,
+                                         **pycompat.strkwargs(opts))
+            extra = {}
+            for idfunc in extrapreimport:
+                extrapreimportmap[idfunc](repo, patchdata, extra, opts)
+            overrides = {}
+            if partial:
+                overrides[('ui', 'allowemptycommit')] = True
+            with repo.ui.configoverride(overrides, 'import'):
+                n = repo.commit(message, user,
+                                date, match=m,
+                                editor=editor, extra=extra)
+                for idfunc in extrapostimport:
+                    extrapostimportmap[idfunc](repo[n])
+    else:
+        if opts.get('exact') or importbranch:
+            branch = branch or 'default'
+        else:
+            branch = p1.branch()
+        store = patch.filestore()
+        try:
             files = set()
             try:
-                patch.patch(ui, repo, tmpname, strip=strip, prefix=prefix,
-                            files=files, eolmode=None, similarity=sim / 100.0)
+                patch.patchrepo(ui, repo, p1, store, tmpname, strip, prefix,
+                                files, eolmode=None)
             except error.PatchError as e:
-                if not partial:
-                    raise error.Abort(pycompat.bytestr(e))
-                if partial:
-                    rejects = True
-
-            files = list(files)
-            if nocommit:
-                if message:
-                    msgs.append(message)
+                raise error.Abort(stringutil.forcebytestr(e))
+            if opts.get('exact'):
+                editor = None
             else:
-                if opts.get('exact') or p2:
-                    # If you got here, you either use --force and know what
-                    # you are doing or used --exact or a merge patch while
-                    # being updated to its first parent.
-                    m = None
-                else:
-                    m = scmutil.matchfiles(repo, files or [])
-                editform = mergeeditform(repo[None], 'import.normal')
-                if opts.get('exact'):
-                    editor = None
-                else:
-                    editor = getcommiteditor(editform=editform,
-                                             **pycompat.strkwargs(opts))
-                extra = {}
-                for idfunc in extrapreimport:
-                    extrapreimportmap[idfunc](repo, patchdata, extra, opts)
-                overrides = {}
-                if partial:
-                    overrides[('ui', 'allowemptycommit')] = True
-                with repo.ui.configoverride(overrides, 'import'):
-                    n = repo.commit(message, user,
-                                    date, match=m,
-                                    editor=editor, extra=extra)
-                    for idfunc in extrapostimport:
-                        extrapostimportmap[idfunc](repo[n])
-        else:
-            if opts.get('exact') or importbranch:
-                branch = branch or 'default'
-            else:
-                branch = p1.branch()
-            store = patch.filestore()
-            try:
-                files = set()
-                try:
-                    patch.patchrepo(ui, repo, p1, store, tmpname, strip, prefix,
-                                    files, eolmode=None)
-                except error.PatchError as e:
-                    raise error.Abort(stringutil.forcebytestr(e))
-                if opts.get('exact'):
-                    editor = None
-                else:
-                    editor = getcommiteditor(editform='import.bypass')
-                memctx = context.memctx(repo, (p1.node(), p2.node()),
-                                            message,
-                                            files=files,
-                                            filectxfn=store,
-                                            user=user,
-                                            date=date,
-                                            branch=branch,
-                                            editor=editor)
-                n = memctx.commit()
-            finally:
-                store.close()
-        if opts.get('exact') and nocommit:
-            # --exact with --no-commit is still useful in that it does merge
-            # and branch bits
-            ui.warn(_("warning: can't check exact import with --no-commit\n"))
-        elif opts.get('exact') and hex(n) != nodeid:
-            raise error.Abort(_('patch is damaged or loses information'))
-        msg = _('applied to working directory')
-        if n:
-            # i18n: refers to a short changeset id
-            msg = _('created %s') % short(n)
-        return (msg, n, rejects)
-    finally:
-        os.unlink(tmpname)
+                editor = getcommiteditor(editform='import.bypass')
+            memctx = context.memctx(repo, (p1.node(), p2.node()),
+                                        message,
+                                        files=files,
+                                        filectxfn=store,
+                                        user=user,
+                                        date=date,
+                                        branch=branch,
+                                        editor=editor)
+            n = memctx.commit()
+        finally:
+            store.close()
+    if opts.get('exact') and nocommit:
+        # --exact with --no-commit is still useful in that it does merge
+        # and branch bits
+        ui.warn(_("warning: can't check exact import with --no-commit\n"))
+    elif opts.get('exact') and hex(n) != nodeid:
+        raise error.Abort(_('patch is damaged or loses information'))
+    msg = _('applied to working directory')
+    if n:
+        # i18n: refers to a short changeset id
+        msg = _('created %s') % short(n)
+    return msg, n, rejects
+
 
 # facility to let extensions include additional data in an exported patch
 # list of identifiers to be executed in order