Patchwork D12583: iblt: prototype for setdiscovery

login
register
mail settings
Submitter phabricator
Date April 21, 2022, 11:25 p.m.
Message ID <differential-rev-PHID-DREV-2bjuz7q5w6ruwmdviibn-req@mercurial-scm.org>
Download mbox | patch
Permalink /patch/50931/
State New
Headers show

Comments

phabricator - April 21, 2022, 11:25 p.m.
joerg.sonnenberger created this revision.
Herald added a reviewer: hg-reviewers.
Herald added a subscriber: mercurial-patches.

REPOSITORY
  rHG Mercurial

BRANCH
  default

REVISION DETAIL
  https://phab.mercurial-scm.org/D12583

AFFECTED FILES
  mercurial/bundlerepo.py
  mercurial/debugcommands.py
  mercurial/exchange.py
  mercurial/iblt.py
  mercurial/interfaces/repository.py
  mercurial/localrepo.py
  mercurial/setdiscovery.py
  mercurial/wireprotov1peer.py
  mercurial/wireprotov1server.py

CHANGE DETAILS




To: joerg.sonnenberger, #hg-reviewers
Cc: mercurial-patches, mercurial-devel

Patch

diff --git a/mercurial/wireprotov1server.py b/mercurial/wireprotov1server.py
--- a/mercurial/wireprotov1server.py
+++ b/mercurial/wireprotov1server.py
@@ -275,6 +275,7 @@ 
     b'known',
     b'getbundle',
     b'unbundlehash',
+    b'iblt-changelog',
 ]
 
 
@@ -426,6 +427,15 @@ 
             continue
     return None
 
+@wireprotocommand(b'getestimator', b'name', permission=b'pull')
+def getestimator(repo, proto, name):
+    estimator = repo.peer().getestimator(name)
+    return wireprototypes.bytesresponse(estimator.dump())
+
+@wireprotocommand(b'getiblt', b'name size seed', permission=b'pull')
+def getiblt(repo, proto, name, size, seed):
+    inst = repo.peer().getiblt(name, int(size), int(seed))
+    return wireprototypes.bytesresponse(inst.dump())
 
 @wireprotocommand(b'getbundle', b'*', permission=b'pull')
 def getbundle(repo, proto, others):
diff --git a/mercurial/wireprotov1peer.py b/mercurial/wireprotov1peer.py
--- a/mercurial/wireprotov1peer.py
+++ b/mercurial/wireprotov1peer.py
@@ -21,6 +21,7 @@ 
     changegroup as changegroupmod,
     encoding,
     error,
+    iblt,
     pushkey as pushkeymod,
     pycompat,
     util,
@@ -503,6 +504,14 @@ 
             ret = bundle2.getunbundler(self.ui, stream)
         return ret
 
+    def getestimator(self, name):
+        d = self._call(b"getestimator", name=name)
+        return iblt.estimator.load(d)
+
+    def getiblt(self, name, size, seed):
+        d = self._call(b"getiblt", name=name, size=b'%d' % size, seed=b'%d' % seed)
+        return iblt.iblt.load(d)[0]
+
     # End of ipeercommands interface.
 
     # Begin of ipeerlegacycommands interface.
diff --git a/mercurial/setdiscovery.py b/mercurial/setdiscovery.py
--- a/mercurial/setdiscovery.py
+++ b/mercurial/setdiscovery.py
@@ -279,6 +279,70 @@ 
     'discovery', member='PartialDiscovery', default=partialdiscovery
 )
 
+import math
+iblt_sizes = [(1 << i) for i in range(5, 31)]
+iblt_sizes += [math.trunc(math.sqrt(2) * s) for s in iblt_sizes]
+iblt_sizes.sort()
+
+def round_iblt_size(size):
+    size = size + size // 4
+    for s in iblt_sizes:
+        if s >= size:
+            return s
+
+def findsetdifferences(ui, local, remote):
+    if not remote.capable(b'iblt-changelog'):
+        ui.status(b'no iblt support: %s\n' % b' '.join(list(remote.capabilities())))
+        return False, [], [], [], []
+    myestimator = local.peer().getestimator(b'changelog')
+    theirestimator = remote.getestimator(b'changelog')
+    estimated_diff = myestimator.compare(theirestimator)
+    # bail out if estimated_diff = O(len(repo)) and fallback to the classic mechanism?
+    iblt_size = round_iblt_size(estimated_diff)
+    ui.debug(b"expected difference is: %d, using IBLT size of %d\n" % (estimated_diff, iblt_size))
+
+    attempt = 0
+    while True:
+        myiblt = local.peer().getiblt(b'changelog', iblt_size, 0)
+        theiriblt = remote.getiblt(b'changelog', iblt_size, 0)
+        theiriblt.subtract(myiblt)
+        success, them_only, my_only = theiriblt.list()
+        if not success:
+            attempt += 1
+            if attempt == 3:
+                ui.debug(b'iblt extraction failed\n')
+                return False, [], [], [], []
+            iblt_size = round_iblt_size(iblt_size + 1)
+            ui.debug(b'iblt extraction failed, retrying with size %d' % iblt_size)
+            continue
+
+        ui.status(b'iblt extraction worked, %d local changes and %d remote changes found\n' % (len(my_only), len(them_only)))
+        break
+
+    has_node = local.changelog.index.has_node
+    nodelen = len(local.nullid)
+    my_only = [node[:nodelen] for node in my_only]
+
+    # first: find all parents and nodes
+    parents = set()
+    nodes = set()
+    for row in them_only:
+        node = row[:nodelen]
+        if has_node(node):
+            raise error.Abort(_(b"found already known remote change: %s") % node)
+        nodes.add(node)
+        parents.add(row[nodelen:2*nodelen])
+        parents.add(row[2*nodelen:])
+    # second: remote heads are all nodes that are not also parents
+    remoteheads = nodes - parents
+    # third: parent nodes that are not nodes themselve are the boundary
+    # of the common set. Double check that they are known locally.
+    commonheadscandidates = parents - nodes
+    commonheads = [node for node in commonheadscandidates if has_node(node)]
+    if len(commonheads) != len(commonheadscandidates):
+        raise error.Abort(_(b"found remote changes with unknown parents"))
+
+    return True, my_only, them_only, commonheads, remoteheads
 
 def findcommonheads(
     ui,
@@ -295,7 +359,6 @@ 
     will be updated with extra data about the discovery, this is useful for
     debug.
     """
-
     samplegrowth = float(ui.config(b'devel', b'discovery.grow-sample.rate'))
 
     if audit is not None:
diff --git a/mercurial/localrepo.py b/mercurial/localrepo.py
--- a/mercurial/localrepo.py
+++ b/mercurial/localrepo.py
@@ -46,6 +46,7 @@ 
     extensions,
     filelog,
     hook,
+    iblt,
     lock as lockmod,
     match as matchmod,
     mergestate as mergestatemod,
@@ -246,6 +247,7 @@ 
     b'known',
     b'getbundle',
     b'unbundle',
+    b'iblt-changelog',
 }
 legacycaps = moderncaps.union({b'changegroupsubset'})
 
@@ -344,6 +346,58 @@ 
     def clonebundles(self):
         return self._repo.tryread(bundlecaches.CB_MANIFEST_FILE)
 
+    def getestimator(self, name):
+        if name == b'changelog':
+            repo = self.local()
+            cachename = b'estimator-changelog.%d' % repo.changelog.tiprev()
+            try:
+                data = repo.cachevfs.read(cachename)
+            except (IOError, OSError):
+                data = None
+            if data:
+                return iblt.estimator.load(data)
+
+            estimator = iblt.estimator(32)
+            tonode = repo.unfiltered().changelog.node
+            for rev in repo.revs('all()'):
+                estimator.insert(tonode(rev))
+            try:
+                with repo.cachevfs.open(cachename, b'wb') as f:
+                    f.write(estimator.dump())
+            except (IOError, OSError) as inst:
+                pass
+        else:
+            raise KeyError(b'unknown getestimator key %s' % name)
+        return estimator
+
+    def getiblt(self, name, size, seed):
+        if seed != 0:
+            raise KeyError(b'unsupport getiblt seed: %s' % seed)
+        if name == b'changelog':
+            repo = self.local()
+            cachename = b'iblt-changelog.%d-%d' % (repo.changelog.tiprev(), size)
+            try:
+                data = repo.cachevfs.read(cachename)
+            except (IOError, OSError):
+                data = None
+            if data:
+                return iblt.iblt.load(data)[0]
+
+            tonode = repo.unfiltered().changelog.node
+            parents = repo.unfiltered().changelog.parents
+            inst = iblt.iblt(size, 3, 3 * len(repo.nullid))
+            for rev in repo.revs('all()'):
+                node = tonode(rev)
+                inst.insert(node + b''.join(parents(node)))
+            try:
+                with repo.cachevfs.open(cachename, b'wb') as f:
+                    f.write(inst.dump())
+            except (IOError, OSError) as inst:
+                pass
+        else:
+            raise KeyError(b'unknown getiblt key %s' % name)
+        return inst
+
     def debugwireargs(self, one, two, three=None, four=None, five=None):
         """Used to test argument passing over the wire"""
         return b"%s %s %s %s %s" % (
diff --git a/mercurial/interfaces/repository.py b/mercurial/interfaces/repository.py
--- a/mercurial/interfaces/repository.py
+++ b/mercurial/interfaces/repository.py
@@ -199,6 +199,12 @@ 
         Returns a generator of bundle data.
         """
 
+    def getestimator(name):
+        pass
+
+    def getiblt(name, size, seed):
+        pass
+
     def heads():
         """Determine all known head revisions in the peer.
 
diff --git a/mercurial/iblt.py b/mercurial/iblt.py
new file mode 100644
--- /dev/null
+++ b/mercurial/iblt.py
@@ -0,0 +1,205 @@ 
+import copy
+import hashlib
+import struct
+
+def range_size(upper):
+    if upper < 256:
+        return 1
+    if upper < 65536:
+        return 2
+    if upper < 16777216:
+        return 3
+    return 4
+
+class iblt:
+    def __init__(self, m, k, key_size):
+        self.m = m
+        self.k = k
+        self.key_size = key_size
+        self.key_xors = [0] * m
+        self.key_hash_xors = [0] * m
+        self.key_hash_size = 4
+        self.counts = [0] * m
+
+    def insert(self, key):
+        self._change(key, 1)
+
+    def remove(self, key):
+        self._change(key, -1)
+
+    def __hash(self, key):
+        hashes = hashlib.blake2b(key, digest_size=4*self.k + self.key_hash_size).digest()
+        values = [int.from_bytes(hashes[4*i:4*i+4], 'big') % self.m for i in range(self.k)]
+        # Fudge indices if they are not unique. This avoids the most common
+        # reason for the (implicit) peeling process to fail.
+        if values[0] == values[1]:
+            values[1] ^= 1
+        if values[0] == values[2] or values[1] == values[2]:
+            values[2] ^= 1
+        if values[0] == values[2] or values[1] == values[2]:
+            values[2] ^= 2
+        return values, int.from_bytes(hashes[4 * self.k:], 'big')
+
+    def _change(self, key, count):
+        indices, keyhash = self.__hash(key)
+        numkey = int.from_bytes(key, 'big')
+        for i in indices:
+            self.key_xors[i] ^= numkey
+            self.key_hash_xors[i] ^= keyhash
+            self.counts[i] += count
+
+    def list(self):
+        left = []
+        right = []
+        queue = []
+        for i in range(self.m):
+            if self.counts[i] in (1, -1):
+                queue.append(i)
+        while queue:
+            i = queue.pop()
+            c = self.counts[i]
+            if c not in (1, -1):
+                continue
+            intkey = self.key_xors[i]
+            key = intkey.to_bytes(length = self.key_size, byteorder='big')
+            indices, keyhash = self.__hash(key)
+            if self.key_hash_xors[i] != keyhash:
+                continue
+
+            for k in indices:
+                self.key_xors[k] ^= intkey
+                self.key_hash_xors[k] ^= keyhash
+                self.counts[k] -= c
+                if self.counts[k] in (1, -1):
+                    queue.append(k)
+
+            if c == 1:
+                left.append(key)
+            else:
+                right.append(key)
+        for i in range(self.m):
+            if self.key_xors[i] or self.key_hash_xors[i] or self.counts[i]:
+                return False, left, right
+        return True, left, right
+
+    def subtract(self, other):
+        assert self.m == other.m
+        assert self.k == other.k
+        assert self.key_size == other.key_size
+
+        for i in range(self.m):
+            self.key_xors[i] ^= other.key_xors[i]
+            self.key_hash_xors[i] ^= other.key_hash_xors[i]
+            self.counts[i] -= other.counts[i]
+
+    def dump(self):
+        min_count = min(self.counts)
+        max_count = max(self.counts)
+        count_size = range_size(max_count - min_count)
+        data = []
+        data.extend(self.m.to_bytes(4, 'big'))
+        data.extend(self.k.to_bytes(1, 'big'))
+        data.extend(self.key_size.to_bytes(1, 'big'))
+        data.extend(self.key_hash_size.to_bytes(1, 'big'))
+        data.extend(count_size.to_bytes(1, 'big'))
+        data.extend(min_count.to_bytes(4, 'big', signed = True))
+        for i in range(self.m):
+            data.extend((self.counts[i] - min_count).to_bytes(count_size, 'big'))
+            data.extend(self.key_hash_xors[i].to_bytes(self.key_hash_size, 'big'))
+            data.extend(self.key_xors[i].to_bytes(self.key_size, 'big'))
+        return bytes(data)
+
+    @classmethod
+    def load(cls, data):
+        self = cls.__new__(cls)
+        self.m = int.from_bytes(data[:4], 'big')
+        self.k = int.from_bytes(data[4:5], 'big')
+        self.key_size = int.from_bytes(data[5:6], 'big')
+        self.key_hash_size = int.from_bytes(data[6:7], 'big')
+        count_size = int.from_bytes(data[7:8], 'big')
+        min_count = int.from_bytes(data[8:12], 'big', signed = True)
+        pos = 12
+        self.counts = []
+        self.key_hash_xors = []
+        self.key_xors = []
+        for i in range(self.m):
+            self.counts.append(min_count + int.from_bytes(data[pos:pos+count_size], 'big'))
+            pos += count_size
+            self.key_hash_xors.append(int.from_bytes(data[pos:pos+self.key_hash_size], 'big'))
+            pos += self.key_hash_size
+            self.key_xors.append(int.from_bytes(data[pos:pos+self.key_size], 'big'))
+            pos += self.key_size
+        return self, pos
+
+    def compatible(self, other):
+        return self.m == other.m and self.k == other.k and self.key_size == other.key_size and self.key_hash_size == other.key_hash_size
+
+    def __eq__(self, other):
+        if not self.compatible(other):
+            return False
+        return self.counts == other.counts and self.key_xors == other.key_xors and self.key_hash_xors == other.key_hash_xors
+
+def ffs(x):
+    return (x&-x).bit_length() - 1
+
+class estimator:
+    def __init__(self, stratas = 32):
+        self.stratas = stratas
+        self.key_size = self.stratas // 8
+        assert self.stratas <= 256
+        self.strata_size = 120
+        self.k = 3
+        self.iblts = [iblt(self.strata_size, self.k, self.key_size) for n in range(self.stratas)]
+
+    def insert(self, key):
+        self._change(key, 1)
+
+    def remove(self, key):
+        self._change(key, -1)
+
+    def _change(self, key, count):
+        h = self.__hash(key)
+        lowest = ffs(h) if h else self.stratas - 1
+        self.iblts[lowest]._change(bytes(h.to_bytes(self.key_size, 'big')), count)
+
+    def __hash(self, key):
+        return int.from_bytes(hashlib.blake2b(key, digest_size=self.key_size).digest(), 'big')
+
+    def dump(self):
+        data = []
+        data.extend(self.stratas.to_bytes(2, 'big'))
+        data.extend(self.strata_size.to_bytes(2, 'big'))
+        for i in range(self.stratas):
+            data.extend(self.iblts[i].dump())
+        return bytes(data)
+
+    @classmethod
+    def load(cls, data):
+        self = cls.__new__(cls)
+        self.stratas = int.from_bytes(data[:2], 'big')
+        self.strata_size = int.from_bytes(data[2:4], 'big')
+        self.key_size = self.stratas // 8
+        assert self.stratas <= 256
+        self.k = 3
+        data = data[4:]
+        self.iblts = []
+        for i in range(self.stratas):
+            inst, pos = iblt.load(data)
+            self.iblts.append(inst)
+            data = data[pos:]
+        return self
+
+    def compare(self, other):
+        assert self.stratas == other.stratas and self.strata_size == other.strata_size and self.k == other.k
+        estimate = 0
+        for i in reversed(range(self.stratas)):
+            iblt = copy.deepcopy(self.iblts[i])
+            #iblt = self.iblts[i]
+            iblt.subtract(other.iblts[i])
+            success, ours, theirs = iblt.list()
+            if success:
+                estimate += len(ours) + len(theirs)
+            else:
+                estimate <<= i + 1
+                break
+        return estimate
diff --git a/mercurial/exchange.py b/mercurial/exchange.py
--- a/mercurial/exchange.py
+++ b/mercurial/exchange.py
@@ -1748,6 +1748,17 @@ 
 
     Current handle changeset discovery only, will change handle all discovery
     at some point."""
+
+    if not pullop.heads:
+        from . import setdiscovery
+        success, mychanges, theirchanges, common, rheads = setdiscovery.findsetdifferences(pullop.repo.ui, pullop.repo, pullop.remote)
+        if success:
+            tonode = pullop.repo.unfiltered().changelog.node
+            pullop.common = common
+            pullop.rheads = rheads
+            pullop.fetch = bool(theirchanges)
+            return
+
     tmp = discovery.findcommonincoming(
         pullop.repo, pullop.remote, heads=pullop.heads, force=pullop.force
     )
diff --git a/mercurial/debugcommands.py b/mercurial/debugcommands.py
--- a/mercurial/debugcommands.py
+++ b/mercurial/debugcommands.py
@@ -1001,7 +1001,7 @@ 
 @command(
     b'debugdiscovery',
     [
-        (b'', b'old', None, _(b'use old-style discovery')),
+        (b'', b'protocol', b'set', _(b'use given discovery protocol')),
         (
             b'',
             b'nonheads',
@@ -1116,7 +1116,7 @@ 
         repo = repo.filtered(b'debug-discovery-local-filter')
 
     data = {}
-    if opts.get(b'old'):
+    if opts[b'protocol'] == b'tree':
 
         def doit(pushedrevs, remoteheads, remote=remote):
             if not util.safehasattr(remote, b'branches'):
@@ -1137,7 +1137,7 @@ 
                 common = {clnode(r) for r in common}
             return common, hds
 
-    else:
+    elif opts[b'protocol'] == b'set':
 
         def doit(pushedrevs, remoteheads, remote=remote):
             nodes = None
@@ -1149,6 +1149,19 @@ 
             )
             return common, hds
 
+    elif opts[b'protocol'] == b'iblt':
+
+        def doit(pushedrevs, remoteheads, remote=remote):
+            success, _, _, common, rheads = setdiscovery.findsetdifferences(ui, repo, remote)
+            assert success
+            return common, rheads
+
+    else:
+        raise error.InputError(
+            _(b"Unknown or unsupported discovery protocol")
+        )
+
+
     remoterevs, _checkout = hg.addbranchrevs(repo, remote, branches, revs=None)
     localrevs = opts[b'rev']
 
@@ -1220,6 +1233,8 @@ 
     fm.data(**pycompat.strkwargs(data))
     # display discovery summary
     fm.plain(b"elapsed time:  %(elapsed)f seconds\n" % data)
+    fm.end()
+    return
     fm.plain(b"round-trips:           %(total-roundtrips)9d\n" % data)
     fm.plain(b"queries:               %(total-queries)9d\n" % data)
     fm.plain(b"heads summary:\n")
diff --git a/mercurial/bundlerepo.py b/mercurial/bundlerepo.py
--- a/mercurial/bundlerepo.py
+++ b/mercurial/bundlerepo.py
@@ -583,18 +583,31 @@ 
       the changes; it closes both the original "peer" and the one returned
       here.
     """
-    tmp = discovery.findcommonincoming(repo, peer, heads=onlyheads, force=force)
-    common, incoming, rheads = tmp
-    if not incoming:
-        try:
-            if bundlename:
-                os.unlink(bundlename)
-        except OSError:
-            pass
-        return repo, [], peer.close
+    success = False
+    if not onlyheads:
+        from . import setdiscovery
+        success, mychanges, theirchanges, common, rheads = setdiscovery.findsetdifferences(ui, repo, peer)
+        if success and not theirchanges:
+            try:
+                if bundlename:
+                    os.unlink(bundlename)
+            except OSError:
+                pass
+            return repo, [], peer.close
 
-    commonset = set(common)
-    rheads = [x for x in rheads if x not in commonset]
+    if not success:
+        tmp = discovery.findcommonincoming(repo, peer, heads=onlyheads, force=force)
+        common, incoming, rheads = tmp
+        if not incoming:
+            try:
+                if bundlename:
+                    os.unlink(bundlename)
+            except OSError:
+                pass
+            return repo, [], peer.close
+
+        commonset = set(common)
+        rheads = [x for x in rheads if x not in commonset]
 
     bundle = None
     bundlerepo = None