lock: move acquirefn call to inside the lock
authorSiddharth Agarwal <sid0@fb.com>
Tue, 22 Sep 2015 14:09:42 -0700
changeset 26321 db4c192cb9b3
parent 26320 3ac7acb99b04
child 26322 2cd19782d2d4
lock: move acquirefn call to inside the lock We're going to need to call it again as part of reinitialization after a subprocess inherits the lock.
mercurial/localrepo.py
mercurial/lock.py
tests/test-lock.py
--- a/mercurial/localrepo.py	Tue Sep 22 13:25:41 2015 -0700
+++ b/mercurial/localrepo.py	Tue Sep 22 14:09:42 2015 -0700
@@ -1210,7 +1210,8 @@
 
     def _lock(self, vfs, lockname, wait, releasefn, acquirefn, desc):
         try:
-            l = lockmod.lock(vfs, lockname, 0, releasefn=releasefn, desc=desc)
+            l = lockmod.lock(vfs, lockname, 0, releasefn=releasefn,
+                             acquirefn=acquirefn, desc=desc)
         except error.LockHeld as inst:
             if not wait:
                 raise
@@ -1219,10 +1220,9 @@
             # default to 600 seconds timeout
             l = lockmod.lock(vfs, lockname,
                              int(self.ui.config("ui", "timeout", "600")),
-                             releasefn=releasefn, desc=desc)
+                             releasefn=releasefn, acquirefn=acquirefn,
+                             desc=desc)
             self.ui.warn(_("got lock after %s seconds\n") % l.delay)
-        if acquirefn:
-            acquirefn()
         return l
 
     def _afterlock(self, callback):
--- a/mercurial/lock.py	Tue Sep 22 13:25:41 2015 -0700
+++ b/mercurial/lock.py	Tue Sep 22 14:09:42 2015 -0700
@@ -38,16 +38,20 @@
 
     _host = None
 
-    def __init__(self, vfs, file, timeout=-1, releasefn=None, desc=None):
+    def __init__(self, vfs, file, timeout=-1, releasefn=None, acquirefn=None,
+                 desc=None):
         self.vfs = vfs
         self.f = file
         self.held = 0
         self.timeout = timeout
         self.releasefn = releasefn
+        self.acquirefn = acquirefn
         self.desc = desc
         self.postrelease  = []
         self.pid = os.getpid()
         self.delay = self.lock()
+        if self.acquirefn:
+            self.acquirefn()
 
     def __del__(self):
         if self.held:
--- a/tests/test-lock.py	Tue Sep 22 13:25:41 2015 -0700
+++ b/tests/test-lock.py	Tue Sep 22 14:09:42 2015 -0700
@@ -15,23 +15,38 @@
 class teststate(object):
     def __init__(self, testcase):
         self._testcase = testcase
+        self._acquirecalled = False
         self._releasecalled = False
         self._postreleasecalled = False
         d = tempfile.mkdtemp(dir=os.getcwd())
         self.vfs = scmutil.vfs(d, audit=False)
 
     def makelock(self, *args, **kwargs):
-        l = lock.lock(self.vfs, testlockname, releasefn=self.releasefn, *args,
-                      **kwargs)
+        l = lock.lock(self.vfs, testlockname, releasefn=self.releasefn,
+                      acquirefn=self.acquirefn, *args, **kwargs)
         l.postrelease.append(self.postreleasefn)
         return l
 
+    def acquirefn(self):
+        self._acquirecalled = True
+
     def releasefn(self):
         self._releasecalled = True
 
     def postreleasefn(self):
         self._postreleasecalled = True
 
+    def assertacquirecalled(self, called):
+        self._testcase.assertEqual(
+            self._acquirecalled, called,
+            'expected acquire to be %s but was actually %s' % (
+                self._tocalled(called),
+                self._tocalled(self._acquirecalled),
+            ))
+
+    def resetacquirefn(self):
+        self._acquirecalled = False
+
     def assertreleasecalled(self, called):
         self._testcase.assertEqual(
             self._releasecalled, called,
@@ -73,6 +88,7 @@
     def testlock(self):
         state = teststate(self)
         lock = state.makelock()
+        state.assertacquirecalled(True)
         lock.release()
         state.assertreleasecalled(True)
         state.assertpostreleasecalled(True)
@@ -81,7 +97,13 @@
     def testrecursivelock(self):
         state = teststate(self)
         lock = state.makelock()
+        state.assertacquirecalled(True)
+
+        state.resetacquirefn()
         lock.lock()
+        # recursive lock should not call acquirefn again
+        state.assertacquirecalled(False)
+
         lock.release() # brings lock refcount down from 2 to 1
         state.assertreleasecalled(False)
         state.assertpostreleasecalled(False)
@@ -95,6 +117,7 @@
     def testlockfork(self):
         state = teststate(self)
         lock = state.makelock()
+        state.assertacquirecalled(True)
         lock.lock()
         # fake a fork
         lock.pid += 1