@@ -24,7 +24,7 @@
def assertparse(self, cmd, input, expected):
server = mockserver(input)
_func, spec = wireproto.commands[cmd]
- self.assertEqual(server.getargs(spec), expected)
+ self.assertEqual(server._proto.getargs(spec), expected)
def mockserver(inbytes):
ui = mockui(inbytes)
@@ -48,7 +48,7 @@
wireprotoserver._sshv1respondbytes(self._fout, b'')
l = self._fin.readline()
assert l == b'between\n'
- rsp = wireproto.dispatch(self._repo, self, b'between')
+ rsp = wireproto.dispatch(self._repo, self._proto, b'between')
wireprotoserver._sshv1respondbytes(self._fout, rsp)
super(prehelloserver, self).serve_forever()
@@ -73,7 +73,7 @@
# Send the upgrade response.
self._fout.write(b'upgraded %s %s\n' % (token, name))
- servercaps = wireproto.capabilities(self._repo, self)
+ servercaps = wireproto.capabilities(self._repo, self._proto)
rsp = b'capabilities: %s' % servercaps
self._fout.write(b'%d\n' % len(rsp))
self._fout.write(rsp)
@@ -354,19 +354,12 @@
fout.write(b'\n')
fout.flush()
-class sshserver(baseprotocolhandler):
- def __init__(self, ui, repo):
+class sshv1protocolhandler(baseprotocolhandler):
+ """Handler for requests services via version 1 of SSH protocol."""
+ def __init__(self, ui, fin, fout):
self._ui = ui
- self._repo = repo
- self._fin = ui.fin
- self._fout = ui.fout
-
- hook.redirect(True)
- ui.fout = repo.ui.fout = ui.ferr
-
- # Prevent insertion/deletion of CRs
- util.setbinary(self._fin)
- util.setbinary(self._fout)
+ self._fin = fin
+ self._fout = fout
@property
def name(self):
@@ -403,15 +396,35 @@
def redirect(self):
pass
+ def _client(self):
+ client = encoding.environ.get('SSH_CLIENT', '').split(' ', 1)[0]
+ return 'remote:ssh:' + client
+
+class sshserver(object):
+ def __init__(self, ui, repo):
+ self._ui = ui
+ self._repo = repo
+ self._fin = ui.fin
+ self._fout = ui.fout
+
+ hook.redirect(True)
+ ui.fout = repo.ui.fout = ui.ferr
+
+ # Prevent insertion/deletion of CRs
+ util.setbinary(self._fin)
+ util.setbinary(self._fout)
+
+ self._proto = sshv1protocolhandler(self._ui, self._fin, self._fout)
+
def serve_forever(self):
while self.serve_one():
pass
sys.exit(0)
def serve_one(self):
cmd = self._fin.readline()[:-1]
- if cmd and wireproto.commands.commandavailable(cmd, self):
- rsp = wireproto.dispatch(self._repo, self, cmd)
+ if cmd and wireproto.commands.commandavailable(cmd, self._proto):
+ rsp = wireproto.dispatch(self._repo, self._proto, cmd)
if isinstance(rsp, bytes):
_sshv1respondbytes(self._fout, rsp)
@@ -432,7 +445,3 @@
elif cmd:
_sshv1respondbytes(self._fout, b'')
return cmd != ''
-
- def _client(self):
- client = encoding.environ.get('SSH_CLIENT', '').split(' ', 1)[0]
- return 'remote:ssh:' + client