Patchwork [4,of,4] atomictempfile: add context manager support

login
register
mail settings
Submitter Martijn Pieters
Date June 23, 2016, 4:44 p.m.
Message ID <529c6d580929ac1a639e.1466700241@mjpieters-mbp.dhcp.thefacebook.com>
Download mbox | patch
Permalink /patch/15587/
State Superseded
Headers show

Comments

Martijn Pieters - June 23, 2016, 4:44 p.m.
# HG changeset patch
# User Martijn Pieters <mjpieters@fb.com>
# Date 1466700112 -3600
#      Thu Jun 23 17:41:52 2016 +0100
# Node ID 529c6d580929ac1a639e03cdb39dba80662c79ff
# Parent  ce8e19a9c8df0f3992183f4eaf180f993d915149
atomictempfile: add context manager support

Close the file (moving it in place) on clean context exit, discard when there
has been an exception.

Patch

diff --git a/mercurial/util.py b/mercurial/util.py
--- a/mercurial/util.py
+++ b/mercurial/util.py
@@ -1516,6 +1516,15 @@ 
         if safehasattr(self, '_fp'): # constructor actually did something
             self.discard()
 
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exctype, excvalue, traceback):
+        if exctype is not None:
+            self.discard()
+        else:
+            self.close()
+
 def makedirs(name, mode=None, notindexed=False):
     """recursive directory creation with parent mode inheritance
 
diff --git a/tests/test-atomictempfile.py b/tests/test-atomictempfile.py
--- a/tests/test-atomictempfile.py
+++ b/tests/test-atomictempfile.py
@@ -96,6 +96,24 @@ 
         self.assertTrue(file.read(), b'foobar\n')
         file.discard()
 
+    def test_contextmanager_success(self):
+        """When the context closes, the file is closed"""
+        with atomictempfile('foo') as f:
+            self.assertFalse(os.path.isfile('foo'))
+            f.write(b'argh\n')
+        self.assertTrue(os.path.isfile('foo'))
+
+    def test_contextmanager_failure(self):
+        """On exception, the file is discarded"""
+        try:
+            with atomictempfile('foo') as f:
+                self.assertFalse(os.path.isfile('foo'))
+                f.write(b'argh\n')
+                raise ValueError
+        except ValueError:
+            pass
+        self.assertFalse(os.path.isfile('foo'))
+
 if __name__ == '__main__':
     import silenttestrunner
     silenttestrunner.main(__name__)