Patchwork D2203: wireprotoserver: move SSH server operation to a standalone function

login
register
mail settings
Submitter phabricator
Date Feb. 13, 2018, 12:46 a.m.
Message ID <differential-rev-PHID-DREV-d6apdewbcng7wrj42szb-req@phab.mercurial-scm.org>
Download mbox | patch
Permalink /patch/27752/
State Superseded
Headers show

Comments

phabricator - Feb. 13, 2018, 12:46 a.m.
indygreg created this revision.
Herald added a subscriber: mercurial-devel.
Herald added a reviewer: hg-reviewers.

REVISION SUMMARY
  The server-side processing logic will soon get a bit more complicated
  in order to handle protocol switches. We will use a state machine
  to help make the transitions clearer.
  
  To prepare for this, we move SSH server operation into a standalone
  function. We structure it as a very simple state machine. It only
  has two states for now, with one state containing the bulk of the
  logic. But things will evolve shortly.

REPOSITORY
  rHG Mercurial

REVISION DETAIL
  https://phab.mercurial-scm.org/D2203

AFFECTED FILES
  mercurial/wireprotoserver.py
  tests/sshprotoext.py
  tests/test-sshserver.py

CHANGE DETAILS




To: indygreg, #hg-reviewers
Cc: mercurial-devel
phabricator - Feb. 15, 2018, 1:49 a.m.
durin42 accepted this revision.
durin42 added a comment.
This revision is now accepted and ready to land.


  LGTM, but needs rebased

REPOSITORY
  rHG Mercurial

REVISION DETAIL
  https://phab.mercurial-scm.org/D2203

To: indygreg, #hg-reviewers, durin42
Cc: durin42, mercurial-devel

Patch

diff --git a/tests/test-sshserver.py b/tests/test-sshserver.py
--- a/tests/test-sshserver.py
+++ b/tests/test-sshserver.py
@@ -23,8 +23,11 @@ 
 
     def assertparse(self, cmd, input, expected):
         server = mockserver(input)
+        proto = wireprotoserver.sshv1protocolhandler(server._ui,
+                                                     server._fin,
+                                                     server._fout)
         _func, spec = wireproto.commands[cmd]
-        self.assertEqual(server._proto.getargs(spec), expected)
+        self.assertEqual(proto.getargs(spec), expected)
 
 def mockserver(inbytes):
     ui = mockui(inbytes)
diff --git a/tests/sshprotoext.py b/tests/sshprotoext.py
--- a/tests/sshprotoext.py
+++ b/tests/sshprotoext.py
@@ -48,7 +48,9 @@ 
         wireprotoserver._sshv1respondbytes(self._fout, b'')
         l = self._fin.readline()
         assert l == b'between\n'
-        rsp = wireproto.dispatch(self._repo, self._proto, b'between')
+        proto = wireprotoserver.sshv1protocolhandler(self._ui, self._fin,
+                                                     self._fout)
+        rsp = wireproto.dispatch(self._repo, proto, b'between')
         wireprotoserver._sshv1respondbytes(self._fout, rsp.data)
 
         super(prehelloserver, self).serve_forever()
@@ -72,8 +74,10 @@ 
         self._fin.read(81)
 
         # Send the upgrade response.
+        proto = wireprotoserver.sshv1protocolhandler(self._ui, self._fin,
+                                                     self._fout)
         self._fout.write(b'upgraded %s %s\n' % (token, name))
-        servercaps = wireproto.capabilities(self._repo, self._proto)
+        servercaps = wireproto.capabilities(self._repo, proto)
         rsp = b'capabilities: %s' % servercaps.data
         self._fout.write(b'%d\n' % len(rsp))
         self._fout.write(rsp)
diff --git a/mercurial/wireprotoserver.py b/mercurial/wireprotoserver.py
--- a/mercurial/wireprotoserver.py
+++ b/mercurial/wireprotoserver.py
@@ -409,6 +409,56 @@ 
         client = encoding.environ.get('SSH_CLIENT', '').split(' ', 1)[0]
         return 'remote:ssh:' + client
 
+def _runsshserver(ui, repo, fin, fout):
+    state = 'protov1-serving'
+    proto = sshv1protocolhandler(ui, fin, fout)
+
+    while True:
+        if state == 'protov1-serving':
+            # Commands are issued on new lines.
+            request = fin.readline()[:-1]
+
+            # Empty lines signal to terminate the connection.
+            if not request:
+                state = 'shutdown'
+                continue
+
+            available = wireproto.commands.commandavailable(request, proto)
+
+            # This command isn't available. Send an empty response and go
+            # back to waiting for a new command.
+            if not available:
+                _sshv1respondbytes(fout, b'')
+                continue
+
+            rsp = wireproto.dispatch(repo, proto, request)
+
+            if isinstance(rsp, bytes):
+                _sshv1respondbytes(fout, rsp)
+            elif isinstance(rsp, wireprototypes.bytesresponse):
+                _sshv1respondbytes(fout, rsp.data)
+            elif isinstance(rsp, wireprototypes.streamres):
+                _sshv1respondstream(fout, rsp)
+            elif isinstance(rsp, wireprototypes.streamreslegacy):
+                _sshv1respondstream(fout, rsp)
+            elif isinstance(rsp, wireprototypes.pushres):
+                _sshv1respondbytes(fout, b'')
+                _sshv1respondbytes(fout, bytes(rsp.res))
+            elif isinstance(rsp, wireprototypes.pusherr):
+                _sshv1respondbytes(fout, rsp.res)
+            elif isinstance(rsp, wireprototypes.ooberror):
+                _sshv1respondooberror(fout, ui.ferr, rsp.message)
+            else:
+                raise error.ProgrammingError('unhandled response type from '
+                                             'wire protocol command: %s' % rsp)
+
+        elif state == 'shutdown':
+            break
+
+        else:
+            raise error.ProgrammingError('unhandled ssh server state: %s' %
+                                         state)
+
 class sshserver(object):
     def __init__(self, ui, repo):
         self._ui = ui
@@ -423,36 +473,8 @@ 
         util.setbinary(self._fin)
         util.setbinary(self._fout)
 
-        self._proto = sshv1protocolhandler(self._ui, self._fin, self._fout)
-
     def serve_forever(self):
-        while self.serve_one():
-            pass
+        _runsshserver(self._ui, self._repo, self._fin, self._fout)
         sys.exit(0)
 
-    def serve_one(self):
-        cmd = self._fin.readline()[:-1]
-        if cmd and wireproto.commands.commandavailable(cmd, self._proto):
-            rsp = wireproto.dispatch(self._repo, self._proto, cmd)
 
-            if isinstance(rsp, bytes):
-                _sshv1respondbytes(self._fout, rsp)
-            elif isinstance(rsp, wireprototypes.bytesresponse):
-                _sshv1respondbytes(self._fout, rsp.data)
-            elif isinstance(rsp, wireprototypes.streamres):
-                _sshv1respondstream(self._fout, rsp)
-            elif isinstance(rsp, wireprototypes.streamreslegacy):
-                _sshv1respondstream(self._fout, rsp)
-            elif isinstance(rsp, wireprototypes.pushres):
-                _sshv1respondbytes(self._fout, b'')
-                _sshv1respondbytes(self._fout, bytes(rsp.res))
-            elif isinstance(rsp, wireprototypes.pusherr):
-                _sshv1respondbytes(self._fout, rsp.res)
-            elif isinstance(rsp, wireprototypes.ooberror):
-                _sshv1respondooberror(self._fout, self._ui.ferr, rsp.message)
-            else:
-                raise error.ProgrammingError('unhandled response type from '
-                                             'wire protocol command: %s' % rsp)
-        elif cmd:
-            _sshv1respondbytes(self._fout, b'')
-        return cmd != ''