Patchwork D2089: wireproto: introduce type for raw byte responses (API)

login
register
mail settings
Submitter phabricator
Date Feb. 8, 2018, 12:27 a.m.
Message ID <differential-rev-PHID-DREV-l2knxlsx2fqnahnh3rno-req@phab.mercurial-scm.org>
Download mbox | patch
Permalink /patch/27457/
State Superseded
Headers show

Comments

phabricator - Feb. 8, 2018, 12:27 a.m.
indygreg created this revision.
Herald added a subscriber: mercurial-devel.
Herald added a reviewer: hg-reviewers.

REVISION SUMMARY
  Right now we simply return a str/bytes instance for simple
  responses. I want all wire protocol response types to be strongly
  typed. So let's invent and use a type for raw bytes responses.
  
  While I was here, I also switched a `str` to `bytes` in the
  ssh protocol handler. That should make Python 3 a bit happier.
  
  .. api::
  
    Wire protocol command handlers now return a
    wireprototypes.bytesresponse instead of a raw bytes instance.
    Protocol handlers will continue handling bytes instances. However,
    any extensions wrapping wire protocol commands will need to handle
    the new type.

REPOSITORY
  rHG Mercurial

REVISION DETAIL
  https://phab.mercurial-scm.org/D2089

AFFECTED FILES
  hgext/largefiles/proto.py
  mercurial/wireproto.py
  mercurial/wireprotoserver.py
  mercurial/wireprototypes.py
  tests/sshprotoext.py
  tests/test-wireproto.py

CHANGE DETAILS




To: indygreg, #hg-reviewers
Cc: mercurial-devel

Patch

diff --git a/tests/test-wireproto.py b/tests/test-wireproto.py
--- a/tests/test-wireproto.py
+++ b/tests/test-wireproto.py
@@ -1,8 +1,10 @@ 
 from __future__ import absolute_import, print_function
 
 from mercurial import (
+    error,
     util,
     wireproto,
+    wireprototypes,
 )
 stringio = util.stringio
 
@@ -42,7 +44,13 @@ 
         return ['batch']
 
     def _call(self, cmd, **args):
-        return wireproto.dispatch(self.serverrepo, proto(args), cmd)
+        res = wireproto.dispatch(self.serverrepo, proto(args), cmd)
+        if isinstance(res, wireprototypes.bytesresponse):
+            return res.data
+        elif isinstance(res, bytes):
+            return res
+        else:
+            raise error.Abort('dummy client does not support response type')
 
     def _callstream(self, cmd, **args):
         return stringio(self._call(cmd, **args))
diff --git a/tests/sshprotoext.py b/tests/sshprotoext.py
--- a/tests/sshprotoext.py
+++ b/tests/sshprotoext.py
@@ -74,7 +74,7 @@ 
         # Send the upgrade response.
         self._fout.write(b'upgraded %s %s\n' % (token, name))
         servercaps = wireproto.capabilities(self._repo, self._proto)
-        rsp = b'capabilities: %s' % servercaps
+        rsp = b'capabilities: %s' % servercaps.data
         self._fout.write(b'%d\n' % len(rsp))
         self._fout.write(rsp)
         self._fout.write(b'\n')
diff --git a/mercurial/wireprototypes.py b/mercurial/wireprototypes.py
--- a/mercurial/wireprototypes.py
+++ b/mercurial/wireprototypes.py
@@ -5,6 +5,11 @@ 
 
 from __future__ import absolute_import
 
+class bytesresponse(object):
+    """A wire protocol response consisting of raw bytes."""
+    def __init__(self, data):
+        self.data = data
+
 class ooberror(object):
     """wireproto reply: failure of a batch of operation
 
diff --git a/mercurial/wireprotoserver.py b/mercurial/wireprotoserver.py
--- a/mercurial/wireprotoserver.py
+++ b/mercurial/wireprotoserver.py
@@ -274,6 +274,9 @@ 
     if isinstance(rsp, bytes):
         req.respond(HTTP_OK, HGTYPE, body=rsp)
         return []
+    elif isinstance(rsp, wireprototypes.bytesresponse):
+        req.respond(HTTP_OK, HGTYPE, body=rsp.data)
+        return []
     elif isinstance(rsp, wireprototypes.streamres_legacy):
         gen = rsp.gen
         req.respond(HTTP_OK, HGTYPE)
@@ -389,6 +392,9 @@ 
         self._fout.write(v)
         self._fout.flush()
 
+    def _sendbytes(self, v):
+        self._sendresponse(v.data)
+
     def _sendstream(self, source):
         write = self._fout.write
         for chunk in source.gen:
@@ -409,7 +415,8 @@ 
         self._fout.flush()
 
     _handlers = {
-        str: _sendresponse,
+        bytes: _sendresponse,
+        wireprototypes.bytesresponse: _sendbytes,
         wireprototypes.streamres: _sendstream,
         wireprototypes.streamres_legacy: _sendstream,
         wireprototypes.pushres: _sendpushresponse,
diff --git a/mercurial/wireproto.py b/mercurial/wireproto.py
--- a/mercurial/wireproto.py
+++ b/mercurial/wireproto.py
@@ -37,6 +37,7 @@ 
 urlerr = util.urlerr
 urlreq = util.urlreq
 
+bytesresponse = wireprototypes.bytesresponse
 ooberror = wireprototypes.ooberror
 pushres = wireprototypes.pushres
 pusherr = wireprototypes.pusherr
@@ -696,16 +697,24 @@ 
             result = func(repo, proto)
         if isinstance(result, ooberror):
             return result
+
+        # For now, all batchable commands must return bytesresponse or
+        # raw bytes (for backwards compatibility).
+        assert isinstance(result, (bytesresponse, bytes))
+        if isinstance(result, bytesresponse):
+            result = result.data
         res.append(escapearg(result))
-    return ';'.join(res)
+
+    return bytesresponse(';'.join(res))
 
 @wireprotocommand('between', 'pairs')
 def between(repo, proto, pairs):
     pairs = [decodelist(p, '-') for p in pairs.split(" ")]
     r = []
     for b in repo.between(pairs):
         r.append(encodelist(b) + "\n")
-    return "".join(r)
+
+    return bytesresponse(''.join(r))
 
 @wireprotocommand('branchmap')
 def branchmap(repo, proto):
@@ -715,15 +724,17 @@ 
         branchname = urlreq.quote(encoding.fromlocal(branch))
         branchnodes = encodelist(nodes)
         heads.append('%s %s' % (branchname, branchnodes))
-    return '\n'.join(heads)
+
+    return bytesresponse('\n'.join(heads))
 
 @wireprotocommand('branches', 'nodes')
 def branches(repo, proto, nodes):
     nodes = decodelist(nodes)
     r = []
     for b in repo.branches(nodes):
         r.append(encodelist(b) + "\n")
-    return "".join(r)
+
+    return bytesresponse(''.join(r))
 
 @wireprotocommand('clonebundles', '')
 def clonebundles(repo, proto):
@@ -735,7 +746,7 @@ 
     depending on the request. e.g. you could advertise URLs for the closest
     data center given the client's IP address.
     """
-    return repo.vfs.tryread('clonebundles.manifest')
+    return bytesresponse(repo.vfs.tryread('clonebundles.manifest'))
 
 wireprotocaps = ['lookup', 'changegroupsubset', 'branchmap', 'pushkey',
                  'known', 'getbundle', 'unbundlehash', 'batch']
@@ -789,7 +800,7 @@ 
 # `_capabilities` instead.
 @wireprotocommand('capabilities')
 def capabilities(repo, proto):
-    return ' '.join(_capabilities(repo, proto))
+    return bytesresponse(' '.join(_capabilities(repo, proto)))
 
 @wireprotocommand('changegroup', 'roots')
 def changegroup(repo, proto, roots):
@@ -814,7 +825,8 @@ 
 def debugwireargs(repo, proto, one, two, others):
     # only accept optional args from the known set
     opts = options('debugwireargs', ['three', 'four'], others)
-    return repo.debugwireargs(one, two, **pycompat.strkwargs(opts))
+    return bytesresponse(repo.debugwireargs(one, two,
+                                            **pycompat.strkwargs(opts)))
 
 @wireprotocommand('getbundle', '*')
 def getbundle(repo, proto, others):
@@ -885,7 +897,7 @@ 
 @wireprotocommand('heads')
 def heads(repo, proto):
     h = repo.heads()
-    return encodelist(h) + "\n"
+    return bytesresponse(encodelist(h) + '\n')
 
 @wireprotocommand('hello')
 def hello(repo, proto):
@@ -896,12 +908,13 @@ 
 
     capabilities: space separated list of tokens
     '''
-    return "capabilities: %s\n" % (capabilities(repo, proto))
+    caps = capabilities(repo, proto).data
+    return bytesresponse('capabilities: %s\n' % caps)
 
 @wireprotocommand('listkeys', 'namespace')
 def listkeys(repo, proto, namespace):
     d = repo.listkeys(encoding.tolocal(namespace)).items()
-    return pushkeymod.encodekeys(d)
+    return bytesresponse(pushkeymod.encodekeys(d))
 
 @wireprotocommand('lookup', 'key')
 def lookup(repo, proto, key):
@@ -913,11 +926,12 @@ 
     except Exception as inst:
         r = str(inst)
         success = 0
-    return "%d %s\n" % (success, r)
+    return bytesresponse('%d %s\n' % (success, r))
 
 @wireprotocommand('known', 'nodes *')
 def known(repo, proto, nodes, others):
-    return ''.join(b and "1" or "0" for b in repo.known(decodelist(nodes)))
+    v = ''.join(b and '1' or '0' for b in repo.known(decodelist(nodes)))
+    return bytesresponse(v)
 
 @wireprotocommand('pushkey', 'namespace key old new')
 def pushkey(repo, proto, namespace, key, old, new):
@@ -938,7 +952,7 @@ 
                          encoding.tolocal(old), new) or False
 
     output = output.getvalue() if output else ''
-    return '%s\n%s' % (int(r), output)
+    return bytesresponse('%s\n%s' % (int(r), output))
 
 @wireprotocommand('stream_out')
 def stream(repo, proto):
diff --git a/hgext/largefiles/proto.py b/hgext/largefiles/proto.py
--- a/hgext/largefiles/proto.py
+++ b/hgext/largefiles/proto.py
@@ -14,6 +14,7 @@ 
     httppeer,
     util,
     wireproto,
+    wireprototypes,
 )
 
 from . import (
@@ -85,8 +86,8 @@ 
     server side.'''
     filename = lfutil.findfile(repo, sha)
     if not filename:
-        return '2\n'
-    return '0\n'
+        return wireprototypes.bytesresponse('2\n')
+    return wireprototypes.bytesresponse('0\n')
 
 def wirereposetup(ui, repo):
     class lfileswirerepository(repo.__class__):