Patchwork D2034: sshpeer: move handshake outside of sshpeer

login
register
mail settings
Submitter phabricator
Date Feb. 5, 2018, 5:15 p.m.
Message ID <3af64386ec4ebac76f9f8e412697f97f@localhost.localdomain>
Download mbox | patch
Permalink /patch/27335/
State Not Applicable
Headers show

Comments

phabricator - Feb. 5, 2018, 5:15 p.m.
indygreg updated this revision to Diff 5229.
indygreg marked an inline comment as done.

REPOSITORY
  rHG Mercurial

CHANGES SINCE LAST UPDATE
  https://phab.mercurial-scm.org/D2034?vs=5192&id=5229

REVISION DETAIL
  https://phab.mercurial-scm.org/D2034

AFFECTED FILES
  mercurial/sshpeer.py
  tests/sshprotoext.py
  tests/test-check-interfaces.py
  tests/test-ssh-proto.t

CHANGE DETAILS




To: indygreg, #hg-reviewers
Cc: lothiraldan, mercurial-devel

Patch

diff --git a/tests/test-ssh-proto.t b/tests/test-ssh-proto.t
--- a/tests/test-ssh-proto.t
+++ b/tests/test-ssh-proto.t
@@ -146,7 +146,6 @@ 
 
   $ hg --config sshpeer.mode=extra-handshake-commands --config sshpeer.handshake-mode=pre-no-args --debug debugpeer ssh://user@dummy/server
   running * "*/tests/dummyssh" 'user@dummy' 'hg -R server serve --stdio' (glob)
-  devel-peer-request: no-args
   sending no-args command
   devel-peer-request: hello
   sending hello command
@@ -182,11 +181,8 @@ 
 
   $ hg --config sshpeer.mode=extra-handshake-commands --config sshpeer.handshake-mode=pre-multiple-no-args --debug debugpeer ssh://user@dummy/server
   running * "*/tests/dummyssh" 'user@dummy' 'hg -R server serve --stdio' (glob)
-  devel-peer-request: unknown1
   sending unknown1 command
-  devel-peer-request: unknown2
   sending unknown2 command
-  devel-peer-request: unknown3
   sending unknown3 command
   devel-peer-request: hello
   sending hello command
diff --git a/tests/test-check-interfaces.py b/tests/test-check-interfaces.py
--- a/tests/test-check-interfaces.py
+++ b/tests/test-check-interfaces.py
@@ -51,10 +51,6 @@ 
         pass
 
 # Facilitates testing sshpeer without requiring an SSH server.
-class testingsshpeer(sshpeer.sshpeer):
-    def _validaterepo(self, *args, **kwargs):
-        pass
-
 class badpeer(httppeer.httppeer):
     def __init__(self):
         super(badpeer, self).__init__(uimod.ui(), 'http://localhost')
@@ -69,8 +65,8 @@ 
     checkobject(badpeer())
     checkobject(httppeer.httppeer(ui, 'http://localhost'))
     checkobject(localrepo.localpeer(dummyrepo()))
-    checkobject(testingsshpeer(ui, 'ssh://localhost/foo', None, None, None,
-                               None))
+    checkobject(sshpeer.sshpeer(ui, 'ssh://localhost/foo', None, None, None,
+                               None, None))
     checkobject(bundlerepo.bundlepeer(dummyrepo()))
     checkobject(statichttprepo.statichttppeer(dummyrepo()))
     checkobject(unionrepo.unionpeer(dummyrepo()))
diff --git a/tests/sshprotoext.py b/tests/sshprotoext.py
--- a/tests/sshprotoext.py
+++ b/tests/sshprotoext.py
@@ -12,6 +12,7 @@ 
 
 from mercurial import (
     error,
+    extensions,
     registrar,
     sshpeer,
     wireproto,
@@ -52,30 +53,26 @@ 
 
         super(prehelloserver, self).serve_forever()
 
-class extrahandshakecommandspeer(sshpeer.sshpeer):
-    """An ssh peer that sends extra commands as part of initial handshake."""
-    def _validaterepo(self):
-        mode = self._ui.config(b'sshpeer', b'handshake-mode')
-        if mode == b'pre-no-args':
-            self._callstream(b'no-args')
-            return super(extrahandshakecommandspeer, self)._validaterepo()
-        elif mode == b'pre-multiple-no-args':
-            self._callstream(b'unknown1')
-            self._callstream(b'unknown2')
-            self._callstream(b'unknown3')
-            return super(extrahandshakecommandspeer, self)._validaterepo()
-        else:
-            raise error.ProgrammingError(b'unknown HANDSHAKECOMMANDMODE: %s' %
-                                         mode)
-
-def registercommands():
-    def dummycommand(repo, proto):
-        raise error.ProgrammingError('this should never be called')
-
-    wireproto.wireprotocommand(b'no-args', b'')(dummycommand)
-    wireproto.wireprotocommand(b'unknown1', b'')(dummycommand)
-    wireproto.wireprotocommand(b'unknown2', b'')(dummycommand)
-    wireproto.wireprotocommand(b'unknown3', b'')(dummycommand)
+def performhandshake(orig, ui, pipei, pipeo, pipee):
+    """Wrapped version of sshpeer._performhandshake to send extra commands."""
+    mode = ui.config(b'sshpeer', b'handshake-mode')
+    if mode == b'pre-no-args':
+        ui.debug(b'sending no-args command\n')
+        pipeo.write(b'no-args\n')
+        pipeo.flush()
+        return orig(ui, pipei, pipeo, pipee)
+    elif mode == b'pre-multiple-no-args':
+        ui.debug(b'sending unknown1 command\n')
+        pipeo.write(b'unknown1\n')
+        ui.debug(b'sending unknown2 command\n')
+        pipeo.write(b'unknown2\n')
+        ui.debug(b'sending unknown3 command\n')
+        pipeo.write(b'unknown3\n')
+        pipeo.flush()
+        return orig(ui, pipei, pipeo, pipee)
+    else:
+        raise error.ProgrammingError(b'unknown HANDSHAKECOMMANDMODE: %s' %
+                                     mode)
 
 def extsetup(ui):
     # It's easier for tests to define the server behavior via environment
@@ -94,7 +91,6 @@ 
     peermode = ui.config(b'sshpeer', b'mode')
 
     if peermode == b'extra-handshake-commands':
-        sshpeer.sshpeer = extrahandshakecommandspeer
-        registercommands()
+        extensions.wrapfunction(sshpeer, '_performhandshake', performhandshake)
     elif peermode:
         raise error.ProgrammingError(b'unknown peer mode: %s' % peermode)
diff --git a/mercurial/sshpeer.py b/mercurial/sshpeer.py
--- a/mercurial/sshpeer.py
+++ b/mercurial/sshpeer.py
@@ -157,13 +157,69 @@ 
 
     return proc, pipei, pipeo, pipee
 
+def _performhandshake(ui, pipei, pipeo, pipee):
+    def badresponse():
+        msg = _('no suitable response from remote hg')
+        hint = ui.config('ui', 'ssherrorhint')
+        raise error.RepoError(msg, hint=hint)
+
+    requestlog = ui.configbool('devel', 'debug.peer-request')
+
+    try:
+        pairsarg = '%s-%s' % ('0' * 40, '0' * 40)
+        handshake = [
+            'hello\n',
+            'between\n',
+            'pairs %d\n' % len(pairsarg),
+            pairsarg,
+        ]
+
+        if requestlog:
+            ui.debug('devel-peer-request: hello\n')
+        ui.debug('sending hello command\n')
+        if requestlog:
+            ui.debug('devel-peer-request: between\n')
+            ui.debug('devel-peer-request:   pairs: %d bytes\n' %len(pairsarg))
+        ui.debug('sending between command\n')
+
+        pipeo.write(''.join(handshake))
+        pipeo.flush()
+    except IOError:
+        badresponse()
+
+    lines = ['', 'dummy']
+    max_noise = 500
+    while lines[-1] and max_noise:
+        try:
+            l = pipei.readline()
+            _forwardoutput(ui, pipee)
+            if lines[-1] == '1\n' and l == '\n':
+                break
+            if l:
+                ui.debug('remote: ', l)
+            lines.append(l)
+            max_noise -= 1
+        except IOError:
+            badresponse()
+    else:
+        badresponse()
+
+    caps = set()
+    for l in reversed(lines):
+        if l.startswith('capabilities:'):
+            caps.update(l[:-1].split(':')[1].split())
+            break
+
+    return caps
+
 class sshpeer(wireproto.wirepeer):
-    def __init__(self, ui, url, proc, pipei, pipeo, pipee):
+    def __init__(self, ui, url, proc, pipei, pipeo, pipee, caps):
         """Create a peer from an existing SSH connection.
 
         ``proc`` is a handle on the underlying SSH process.
         ``pipei``, ``pipeo``, and ``pipee`` are handles on the stdin,
         stdout, and stderr file descriptors for that process.
+        ``caps`` is a set of capabilities supported by the remote.
         """
         self._url = url
         self._ui = ui
@@ -173,8 +229,7 @@ 
         self._pipei = pipei
         self._pipeo = pipeo
         self._pipee = pipee
-
-        self._validaterepo()
+        self._caps = caps
 
     # Begin of _basepeer interface.
 
@@ -206,61 +261,6 @@ 
 
     # End of _basewirecommands interface.
 
-    def _validaterepo(self):
-        def badresponse():
-            msg = _("no suitable response from remote hg")
-            hint = self.ui.config("ui", "ssherrorhint")
-            self._abort(error.RepoError(msg, hint=hint))
-
-        try:
-            pairsarg = '%s-%s' % ('0' * 40, '0' * 40)
-
-            handshake = [
-                'hello\n',
-                'between\n',
-                'pairs %d\n' % len(pairsarg),
-                pairsarg,
-            ]
-
-            requestlog = self.ui.configbool('devel', 'debug.peer-request')
-
-            if requestlog:
-                self.ui.debug('devel-peer-request: hello\n')
-            self.ui.debug('sending hello command\n')
-            if requestlog:
-                self.ui.debug('devel-peer-request: between\n')
-                self.ui.debug('devel-peer-request:   pairs: %d bytes\n' %
-                              len(pairsarg))
-            self.ui.debug('sending between command\n')
-
-            self._pipeo.write(''.join(handshake))
-            self._pipeo.flush()
-        except IOError:
-            badresponse()
-
-        lines = ["", "dummy"]
-        max_noise = 500
-        while lines[-1] and max_noise:
-            try:
-                l = self._pipei.readline()
-                _forwardoutput(self.ui, self._pipee)
-                if lines[-1] == "1\n" and l == "\n":
-                    break
-                if l:
-                    self.ui.debug("remote: ", l)
-                lines.append(l)
-                max_noise -= 1
-            except IOError:
-                badresponse()
-        else:
-            badresponse()
-
-        self._caps = set()
-        for l in reversed(lines):
-            if l.startswith("capabilities:"):
-                self._caps.update(l[:-1].split(":")[1].split())
-                break
-
     def _readerr(self):
         _forwardoutput(self.ui, self._pipee)
 
@@ -415,4 +415,10 @@ 
     proc, pipei, pipeo, pipee = _makeconnection(ui, sshcmd, args, remotecmd,
                                                 remotepath, sshenv)
 
-    return sshpeer(ui, path, proc, pipei, pipeo, pipee)
+    try:
+        caps = _performhandshake(ui, pipei, pipeo, pipee)
+    except Exception:
+        _cleanuppipes(ui, pipei, pipeo, pipee)
+        raise
+
+    return sshpeer(ui, path, proc, pipei, pipeo, pipee, caps)