cborutil: implement support for streaming encoding, bytestring decoding
authorGregory Szorc <gregory.szorc@gmail.com>
Sat, 14 Apr 2018 16:36:15 -0700
changeset 37711 65a23cc8e75b
parent 37710 0a5fe2a08e82
child 37712 a728e3695325
cborutil: implement support for streaming encoding, bytestring decoding The vendored cbor2 package is... a bit disappointing. On the encoding side, it insists that you pass it something with a write() to send data to. That means if you want to emit data to a generator, you have to construct an e.g. io.BytesIO(), write() to it, then get the data back out. There can be non-trivial overhead involved. The encoder also doesn't support indefinite types - bytestrings, arrays, and maps that don't have a known length. Again, this is really unfortunate because it requires you to buffer the entire source and destination in memory to encode large things. On the decoding side, it supports reading indefinite length types. But it buffers them completely before returning. More sadness. This commit implements "streaming" encoders for various CBOR types. Encoding emits a generator of hunks. So you can efficiently stream encoded data elsewhere. It also implements support for emitting indefinite length bytestrings, arrays, and maps. On the decoding side, we only implement support for decoding an indefinite length bytestring from a file object. It will emit a generator of raw chunks from the source. I didn't want to reinvent so many wheels. But profiling the wire protocol revealed that the overhead of constructing io.BytesIO() instances to temporarily hold results has a non-trivial overhead. We're talking >15% of execution time for operations like "transfer the fulltexts of all files in a revision." So I can justify this effort. Fortunately, CBOR is a relatively straightforward format. And we have a reference implementation in the repo we can test against. Differential Revision: https://phab.mercurial-scm.org/D3303
contrib/import-checker.py
mercurial/utils/cborutil.py
tests/test-cbor.py
--- a/contrib/import-checker.py	Sun Apr 15 22:28:03 2018 -0400
+++ b/contrib/import-checker.py	Sat Apr 14 16:36:15 2018 -0700
@@ -36,6 +36,8 @@
     'mercurial.pure.parsers',
     # third-party imports should be directly imported
     'mercurial.thirdparty',
+    'mercurial.thirdparty.cbor',
+    'mercurial.thirdparty.cbor.cbor2',
     'mercurial.thirdparty.zope',
     'mercurial.thirdparty.zope.interface',
 )
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/mercurial/utils/cborutil.py	Sat Apr 14 16:36:15 2018 -0700
@@ -0,0 +1,258 @@
+# cborutil.py - CBOR extensions
+#
+# Copyright 2018 Gregory Szorc <gregory.szorc@gmail.com>
+#
+# This software may be used and distributed according to the terms of the
+# GNU General Public License version 2 or any later version.
+
+from __future__ import absolute_import
+
+import struct
+
+from ..thirdparty.cbor.cbor2 import (
+    decoder as decodermod,
+)
+
+# Very short very of RFC 7049...
+#
+# Each item begins with a byte. The 3 high bits of that byte denote the
+# "major type." The lower 5 bits denote the "subtype." Each major type
+# has its own encoding mechanism.
+#
+# Most types have lengths. However, bytestring, string, array, and map
+# can be indefinite length. These are denotes by a subtype with value 31.
+# Sub-components of those types then come afterwards and are terminated
+# by a "break" byte.
+
+MAJOR_TYPE_UINT = 0
+MAJOR_TYPE_NEGINT = 1
+MAJOR_TYPE_BYTESTRING = 2
+MAJOR_TYPE_STRING = 3
+MAJOR_TYPE_ARRAY = 4
+MAJOR_TYPE_MAP = 5
+MAJOR_TYPE_SEMANTIC = 6
+MAJOR_TYPE_SPECIAL = 7
+
+SUBTYPE_MASK = 0b00011111
+
+SUBTYPE_HALF_FLOAT = 25
+SUBTYPE_SINGLE_FLOAT = 26
+SUBTYPE_DOUBLE_FLOAT = 27
+SUBTYPE_INDEFINITE = 31
+
+# Indefinite types begin with their major type ORd with information value 31.
+BEGIN_INDEFINITE_BYTESTRING = struct.pack(
+    r'>B', MAJOR_TYPE_BYTESTRING << 5 | SUBTYPE_INDEFINITE)
+BEGIN_INDEFINITE_ARRAY = struct.pack(
+    r'>B', MAJOR_TYPE_ARRAY << 5 | SUBTYPE_INDEFINITE)
+BEGIN_INDEFINITE_MAP = struct.pack(
+    r'>B', MAJOR_TYPE_MAP << 5 | SUBTYPE_INDEFINITE)
+
+ENCODED_LENGTH_1 = struct.Struct(r'>B')
+ENCODED_LENGTH_2 = struct.Struct(r'>BB')
+ENCODED_LENGTH_3 = struct.Struct(r'>BH')
+ENCODED_LENGTH_4 = struct.Struct(r'>BL')
+ENCODED_LENGTH_5 = struct.Struct(r'>BQ')
+
+# The break ends an indefinite length item.
+BREAK = b'\xff'
+BREAK_INT = 255
+
+def encodelength(majortype, length):
+    """Obtain a value encoding the major type and its length."""
+    if length < 24:
+        return ENCODED_LENGTH_1.pack(majortype << 5 | length)
+    elif length < 256:
+        return ENCODED_LENGTH_2.pack(majortype << 5 | 24, length)
+    elif length < 65536:
+        return ENCODED_LENGTH_3.pack(majortype << 5 | 25, length)
+    elif length < 4294967296:
+        return ENCODED_LENGTH_4.pack(majortype << 5 | 26, length)
+    else:
+        return ENCODED_LENGTH_5.pack(majortype << 5 | 27, length)
+
+def streamencodebytestring(v):
+    yield encodelength(MAJOR_TYPE_BYTESTRING, len(v))
+    yield v
+
+def streamencodebytestringfromiter(it):
+    """Convert an iterator of chunks to an indefinite bytestring.
+
+    Given an input that is iterable and each element in the iterator is
+    representable as bytes, emit an indefinite length bytestring.
+    """
+    yield BEGIN_INDEFINITE_BYTESTRING
+
+    for chunk in it:
+        yield encodelength(MAJOR_TYPE_BYTESTRING, len(chunk))
+        yield chunk
+
+    yield BREAK
+
+def streamencodeindefinitebytestring(source, chunksize=65536):
+    """Given a large source buffer, emit as an indefinite length bytestring.
+
+    This is a generator of chunks constituting the encoded CBOR data.
+    """
+    yield BEGIN_INDEFINITE_BYTESTRING
+
+    i = 0
+    l = len(source)
+
+    while True:
+        chunk = source[i:i + chunksize]
+        i += len(chunk)
+
+        yield encodelength(MAJOR_TYPE_BYTESTRING, len(chunk))
+        yield chunk
+
+        if i >= l:
+            break
+
+    yield BREAK
+
+def streamencodeint(v):
+    if v >= 18446744073709551616 or v < -18446744073709551616:
+        raise ValueError('big integers not supported')
+
+    if v >= 0:
+        yield encodelength(MAJOR_TYPE_UINT, v)
+    else:
+        yield encodelength(MAJOR_TYPE_NEGINT, abs(v) - 1)
+
+def streamencodearray(l):
+    """Encode a known size iterable to an array."""
+
+    yield encodelength(MAJOR_TYPE_ARRAY, len(l))
+
+    for i in l:
+        for chunk in streamencode(i):
+            yield chunk
+
+def streamencodearrayfromiter(it):
+    """Encode an iterator of items to an indefinite length array."""
+
+    yield BEGIN_INDEFINITE_ARRAY
+
+    for i in it:
+        for chunk in streamencode(i):
+            yield chunk
+
+    yield BREAK
+
+def streamencodeset(s):
+    # https://www.iana.org/assignments/cbor-tags/cbor-tags.xhtml defines
+    # semantic tag 258 for finite sets.
+    yield encodelength(MAJOR_TYPE_SEMANTIC, 258)
+
+    for chunk in streamencodearray(sorted(s)):
+        yield chunk
+
+def streamencodemap(d):
+    """Encode dictionary to a generator.
+
+    Does not supporting indefinite length dictionaries.
+    """
+    yield encodelength(MAJOR_TYPE_MAP, len(d))
+
+    for key, value in sorted(d.iteritems()):
+        for chunk in streamencode(key):
+            yield chunk
+        for chunk in streamencode(value):
+            yield chunk
+
+def streamencodemapfromiter(it):
+    """Given an iterable of (key, value), encode to an indefinite length map."""
+    yield BEGIN_INDEFINITE_MAP
+
+    for key, value in it:
+        for chunk in streamencode(key):
+            yield chunk
+        for chunk in streamencode(value):
+            yield chunk
+
+    yield BREAK
+
+def streamencodebool(b):
+    # major type 7, simple value 20 and 21.
+    yield b'\xf5' if b else b'\xf4'
+
+def streamencodenone(v):
+    # major type 7, simple value 22.
+    yield b'\xf6'
+
+STREAM_ENCODERS = {
+    bytes: streamencodebytestring,
+    int: streamencodeint,
+    list: streamencodearray,
+    tuple: streamencodearray,
+    dict: streamencodemap,
+    set: streamencodeset,
+    bool: streamencodebool,
+    type(None): streamencodenone,
+}
+
+def streamencode(v):
+    """Encode a value in a streaming manner.
+
+    Given an input object, encode it to CBOR recursively.
+
+    Returns a generator of CBOR encoded bytes. There is no guarantee
+    that each emitted chunk fully decodes to a value or sub-value.
+
+    Encoding is deterministic - unordered collections are sorted.
+    """
+    fn = STREAM_ENCODERS.get(v.__class__)
+
+    if not fn:
+        raise ValueError('do not know how to encode %s' % type(v))
+
+    return fn(v)
+
+def readindefinitebytestringtoiter(fh, expectheader=True):
+    """Read an indefinite bytestring to a generator.
+
+    Receives an object with a ``read(X)`` method to read N bytes.
+
+    If ``expectheader`` is True, it is expected that the first byte read
+    will represent an indefinite length bytestring. Otherwise, we
+    expect the first byte to be part of the first bytestring chunk.
+    """
+    read = fh.read
+    decodeuint = decodermod.decode_uint
+    byteasinteger = decodermod.byte_as_integer
+
+    if expectheader:
+        initial = decodermod.byte_as_integer(read(1))
+
+        majortype = initial >> 5
+        subtype = initial & SUBTYPE_MASK
+
+        if majortype != MAJOR_TYPE_BYTESTRING:
+            raise decodermod.CBORDecodeError(
+                'expected major type %d; got %d' % (MAJOR_TYPE_BYTESTRING,
+                                                    majortype))
+
+        if subtype != SUBTYPE_INDEFINITE:
+            raise decodermod.CBORDecodeError(
+                'expected indefinite subtype; got %d' % subtype)
+
+    # The indefinite bytestring is composed of chunks of normal bytestrings.
+    # Read chunks until we hit a BREAK byte.
+
+    while True:
+        # We need to sniff for the BREAK byte.
+        initial = byteasinteger(read(1))
+
+        if initial == BREAK_INT:
+            break
+
+        length = decodeuint(fh, initial & SUBTYPE_MASK)
+        chunk = read(length)
+
+        if len(chunk) != length:
+            raise decodermod.CBORDecodeError(
+                'failed to read bytestring chunk: got %d bytes; expected %d' % (
+                    len(chunk), length))
+
+        yield chunk
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/tests/test-cbor.py	Sat Apr 14 16:36:15 2018 -0700
@@ -0,0 +1,210 @@
+from __future__ import absolute_import
+
+import io
+import unittest
+
+from mercurial.thirdparty import (
+    cbor,
+)
+from mercurial.utils import (
+    cborutil,
+)
+
+def loadit(it):
+    return cbor.loads(b''.join(it))
+
+class BytestringTests(unittest.TestCase):
+    def testsimple(self):
+        self.assertEqual(
+            list(cborutil.streamencode(b'foobar')),
+            [b'\x46', b'foobar'])
+
+        self.assertEqual(
+            loadit(cborutil.streamencode(b'foobar')),
+            b'foobar')
+
+    def testlong(self):
+        source = b'x' * 1048576
+
+        self.assertEqual(loadit(cborutil.streamencode(source)), source)
+
+    def testfromiter(self):
+        # This is the example from RFC 7049 Section 2.2.2.
+        source = [b'\xaa\xbb\xcc\xdd', b'\xee\xff\x99']
+
+        self.assertEqual(
+            list(cborutil.streamencodebytestringfromiter(source)),
+            [
+                b'\x5f',
+                b'\x44',
+                b'\xaa\xbb\xcc\xdd',
+                b'\x43',
+                b'\xee\xff\x99',
+                b'\xff',
+            ])
+
+        self.assertEqual(
+            loadit(cborutil.streamencodebytestringfromiter(source)),
+            b''.join(source))
+
+    def testfromiterlarge(self):
+        source = [b'a' * 16, b'b' * 128, b'c' * 1024, b'd' * 1048576]
+
+        self.assertEqual(
+            loadit(cborutil.streamencodebytestringfromiter(source)),
+            b''.join(source))
+
+    def testindefinite(self):
+        source = b'\x00\x01\x02\x03' + b'\xff' * 16384
+
+        it = cborutil.streamencodeindefinitebytestring(source, chunksize=2)
+
+        self.assertEqual(next(it), b'\x5f')
+        self.assertEqual(next(it), b'\x42')
+        self.assertEqual(next(it), b'\x00\x01')
+        self.assertEqual(next(it), b'\x42')
+        self.assertEqual(next(it), b'\x02\x03')
+        self.assertEqual(next(it), b'\x42')
+        self.assertEqual(next(it), b'\xff\xff')
+
+        dest = b''.join(cborutil.streamencodeindefinitebytestring(
+            source, chunksize=42))
+        self.assertEqual(cbor.loads(dest), b''.join(source))
+
+    def testreadtoiter(self):
+        source = io.BytesIO(b'\x5f\x44\xaa\xbb\xcc\xdd\x43\xee\xff\x99\xff')
+
+        it = cborutil.readindefinitebytestringtoiter(source)
+        self.assertEqual(next(it), b'\xaa\xbb\xcc\xdd')
+        self.assertEqual(next(it), b'\xee\xff\x99')
+
+        with self.assertRaises(StopIteration):
+            next(it)
+
+class IntTests(unittest.TestCase):
+    def testsmall(self):
+        self.assertEqual(list(cborutil.streamencode(0)), [b'\x00'])
+        self.assertEqual(list(cborutil.streamencode(1)), [b'\x01'])
+        self.assertEqual(list(cborutil.streamencode(2)), [b'\x02'])
+        self.assertEqual(list(cborutil.streamencode(3)), [b'\x03'])
+        self.assertEqual(list(cborutil.streamencode(4)), [b'\x04'])
+
+    def testnegativesmall(self):
+        self.assertEqual(list(cborutil.streamencode(-1)), [b'\x20'])
+        self.assertEqual(list(cborutil.streamencode(-2)), [b'\x21'])
+        self.assertEqual(list(cborutil.streamencode(-3)), [b'\x22'])
+        self.assertEqual(list(cborutil.streamencode(-4)), [b'\x23'])
+        self.assertEqual(list(cborutil.streamencode(-5)), [b'\x24'])
+
+    def testrange(self):
+        for i in range(-70000, 70000, 10):
+            self.assertEqual(
+                b''.join(cborutil.streamencode(i)),
+                cbor.dumps(i))
+
+class ArrayTests(unittest.TestCase):
+    def testempty(self):
+        self.assertEqual(list(cborutil.streamencode([])), [b'\x80'])
+        self.assertEqual(loadit(cborutil.streamencode([])), [])
+
+    def testbasic(self):
+        source = [b'foo', b'bar', 1, -10]
+
+        self.assertEqual(list(cborutil.streamencode(source)), [
+            b'\x84', b'\x43', b'foo', b'\x43', b'bar', b'\x01', b'\x29'])
+
+    def testemptyfromiter(self):
+        self.assertEqual(b''.join(cborutil.streamencodearrayfromiter([])),
+                         b'\x9f\xff')
+
+    def testfromiter1(self):
+        source = [b'foo']
+
+        self.assertEqual(list(cborutil.streamencodearrayfromiter(source)), [
+            b'\x9f',
+            b'\x43', b'foo',
+            b'\xff',
+        ])
+
+        dest = b''.join(cborutil.streamencodearrayfromiter(source))
+        self.assertEqual(cbor.loads(dest), source)
+
+    def testtuple(self):
+        source = (b'foo', None, 42)
+
+        self.assertEqual(cbor.loads(b''.join(cborutil.streamencode(source))),
+                         list(source))
+
+class SetTests(unittest.TestCase):
+    def testempty(self):
+        self.assertEqual(list(cborutil.streamencode(set())), [
+            b'\xd9\x01\x02',
+            b'\x80',
+        ])
+
+    def testset(self):
+        source = {b'foo', None, 42}
+
+        self.assertEqual(cbor.loads(b''.join(cborutil.streamencode(source))),
+                         source)
+
+class BoolTests(unittest.TestCase):
+    def testbasic(self):
+        self.assertEqual(list(cborutil.streamencode(True)),  [b'\xf5'])
+        self.assertEqual(list(cborutil.streamencode(False)), [b'\xf4'])
+
+        self.assertIs(loadit(cborutil.streamencode(True)), True)
+        self.assertIs(loadit(cborutil.streamencode(False)), False)
+
+class NoneTests(unittest.TestCase):
+    def testbasic(self):
+        self.assertEqual(list(cborutil.streamencode(None)), [b'\xf6'])
+
+        self.assertIs(loadit(cborutil.streamencode(None)), None)
+
+class MapTests(unittest.TestCase):
+    def testempty(self):
+        self.assertEqual(list(cborutil.streamencode({})), [b'\xa0'])
+        self.assertEqual(loadit(cborutil.streamencode({})), {})
+
+    def testemptyindefinite(self):
+        self.assertEqual(list(cborutil.streamencodemapfromiter([])), [
+            b'\xbf', b'\xff'])
+
+        self.assertEqual(loadit(cborutil.streamencodemapfromiter([])), {})
+
+    def testone(self):
+        source = {b'foo': b'bar'}
+        self.assertEqual(list(cborutil.streamencode(source)), [
+            b'\xa1', b'\x43', b'foo', b'\x43', b'bar'])
+
+        self.assertEqual(loadit(cborutil.streamencode(source)), source)
+
+    def testmultiple(self):
+        source = {
+            b'foo': b'bar',
+            b'baz': b'value1',
+        }
+
+        self.assertEqual(loadit(cborutil.streamencode(source)), source)
+
+        self.assertEqual(
+            loadit(cborutil.streamencodemapfromiter(source.items())),
+            source)
+
+    def testcomplex(self):
+        source = {
+            b'key': 1,
+            2: -10,
+        }
+
+        self.assertEqual(loadit(cborutil.streamencode(source)),
+                         source)
+
+        self.assertEqual(
+            loadit(cborutil.streamencodemapfromiter(source.items())),
+            source)
+
+if __name__ == '__main__':
+    import silenttestrunner
+    silenttestrunner.main(__name__)