tests/test-cbor.py
changeset 43076 2372284d9457
parent 41549 1ea1bba1c5be
child 43996 6dbb18e1ac8d
equal deleted inserted replaced
43075:57875cf423c9 43076:2372284d9457
     3 import os
     3 import os
     4 import sys
     4 import sys
     5 import unittest
     5 import unittest
     6 
     6 
     7 # TODO migrate to canned cbor test strings and stop using thirdparty.cbor
     7 # TODO migrate to canned cbor test strings and stop using thirdparty.cbor
     8 tpp = os.path.normpath(os.path.join(os.path.dirname(__file__),
     8 tpp = os.path.normpath(
     9                                     '..', 'mercurial', 'thirdparty'))
     9     os.path.join(os.path.dirname(__file__), '..', 'mercurial', 'thirdparty')
       
    10 )
    10 if not os.path.exists(tpp):
    11 if not os.path.exists(tpp):
    11     # skip, not in a repo
    12     # skip, not in a repo
    12     sys.exit(80)
    13     sys.exit(80)
    13 sys.path[0:0] = [tpp]
    14 sys.path[0:0] = [tpp]
    14 import cbor
    15 import cbor
       
    16 
    15 del sys.path[0]
    17 del sys.path[0]
    16 
    18 
    17 from mercurial.utils import (
    19 from mercurial.utils import cborutil
    18     cborutil,
    20 
    19 )
       
    20 
    21 
    21 class TestCase(unittest.TestCase):
    22 class TestCase(unittest.TestCase):
    22     if not getattr(unittest.TestCase, 'assertRaisesRegex', False):
    23     if not getattr(unittest.TestCase, 'assertRaisesRegex', False):
    23         # Python 3.7 deprecates the regex*p* version, but 2.7 lacks
    24         # Python 3.7 deprecates the regex*p* version, but 2.7 lacks
    24         # the regex version.
    25         # the regex version.
    25         assertRaisesRegex = (# camelcase-required
    26         assertRaisesRegex = (  # camelcase-required
    26             unittest.TestCase.assertRaisesRegexp)
    27             unittest.TestCase.assertRaisesRegexp
       
    28         )
       
    29 
    27 
    30 
    28 def loadit(it):
    31 def loadit(it):
    29     return cbor.loads(b''.join(it))
    32     return cbor.loads(b''.join(it))
    30 
    33 
       
    34 
    31 class BytestringTests(TestCase):
    35 class BytestringTests(TestCase):
    32     def testsimple(self):
    36     def testsimple(self):
    33         self.assertEqual(
    37         self.assertEqual(
    34             list(cborutil.streamencode(b'foobar')),
    38             list(cborutil.streamencode(b'foobar')), [b'\x46', b'foobar']
    35             [b'\x46', b'foobar'])
    39         )
    36 
    40 
    37         self.assertEqual(
    41         self.assertEqual(loadit(cborutil.streamencode(b'foobar')), b'foobar')
    38             loadit(cborutil.streamencode(b'foobar')),
    42 
    39             b'foobar')
    43         self.assertEqual(cborutil.decodeall(b'\x46foobar'), [b'foobar'])
    40 
    44 
    41         self.assertEqual(cborutil.decodeall(b'\x46foobar'),
    45         self.assertEqual(
    42                          [b'foobar'])
    46             cborutil.decodeall(b'\x46foobar\x45fizbi'), [b'foobar', b'fizbi']
    43 
    47         )
    44         self.assertEqual(cborutil.decodeall(b'\x46foobar\x45fizbi'),
       
    45                          [b'foobar', b'fizbi'])
       
    46 
    48 
    47     def testlong(self):
    49     def testlong(self):
    48         source = b'x' * 1048576
    50         source = b'x' * 1048576
    49 
    51 
    50         self.assertEqual(loadit(cborutil.streamencode(source)), source)
    52         self.assertEqual(loadit(cborutil.streamencode(source)), source)
    63                 b'\x44',
    65                 b'\x44',
    64                 b'\xaa\xbb\xcc\xdd',
    66                 b'\xaa\xbb\xcc\xdd',
    65                 b'\x43',
    67                 b'\x43',
    66                 b'\xee\xff\x99',
    68                 b'\xee\xff\x99',
    67                 b'\xff',
    69                 b'\xff',
    68             ])
    70             ],
       
    71         )
    69 
    72 
    70         self.assertEqual(
    73         self.assertEqual(
    71             loadit(cborutil.streamencodebytestringfromiter(source)),
    74             loadit(cborutil.streamencodebytestringfromiter(source)),
    72             b''.join(source))
    75             b''.join(source),
    73 
    76         )
    74         self.assertEqual(cborutil.decodeall(b'\x5f\x44\xaa\xbb\xcc\xdd'
    77 
    75                                             b'\x43\xee\xff\x99\xff'),
    78         self.assertEqual(
    76                          [b'\xaa\xbb\xcc\xdd', b'\xee\xff\x99', b''])
    79             cborutil.decodeall(
       
    80                 b'\x5f\x44\xaa\xbb\xcc\xdd' b'\x43\xee\xff\x99\xff'
       
    81             ),
       
    82             [b'\xaa\xbb\xcc\xdd', b'\xee\xff\x99', b''],
       
    83         )
    77 
    84 
    78         for i, chunk in enumerate(
    85         for i, chunk in enumerate(
    79             cborutil.decodeall(b'\x5f\x44\xaa\xbb\xcc\xdd'
    86             cborutil.decodeall(
    80                                b'\x43\xee\xff\x99\xff')):
    87                 b'\x5f\x44\xaa\xbb\xcc\xdd' b'\x43\xee\xff\x99\xff'
       
    88             )
       
    89         ):
    81             self.assertIsInstance(chunk, cborutil.bytestringchunk)
    90             self.assertIsInstance(chunk, cborutil.bytestringchunk)
    82 
    91 
    83             if i == 0:
    92             if i == 0:
    84                 self.assertTrue(chunk.isfirst)
    93                 self.assertTrue(chunk.isfirst)
    85             else:
    94             else:
    93     def testfromiterlarge(self):
   102     def testfromiterlarge(self):
    94         source = [b'a' * 16, b'b' * 128, b'c' * 1024, b'd' * 1048576]
   103         source = [b'a' * 16, b'b' * 128, b'c' * 1024, b'd' * 1048576]
    95 
   104 
    96         self.assertEqual(
   105         self.assertEqual(
    97             loadit(cborutil.streamencodebytestringfromiter(source)),
   106             loadit(cborutil.streamencodebytestringfromiter(source)),
    98             b''.join(source))
   107             b''.join(source),
       
   108         )
    99 
   109 
   100     def testindefinite(self):
   110     def testindefinite(self):
   101         source = b'\x00\x01\x02\x03' + b'\xff' * 16384
   111         source = b'\x00\x01\x02\x03' + b'\xff' * 16384
   102 
   112 
   103         it = cborutil.streamencodeindefinitebytestring(source, chunksize=2)
   113         it = cborutil.streamencodeindefinitebytestring(source, chunksize=2)
   108         self.assertEqual(next(it), b'\x42')
   118         self.assertEqual(next(it), b'\x42')
   109         self.assertEqual(next(it), b'\x02\x03')
   119         self.assertEqual(next(it), b'\x02\x03')
   110         self.assertEqual(next(it), b'\x42')
   120         self.assertEqual(next(it), b'\x42')
   111         self.assertEqual(next(it), b'\xff\xff')
   121         self.assertEqual(next(it), b'\xff\xff')
   112 
   122 
   113         dest = b''.join(cborutil.streamencodeindefinitebytestring(
   123         dest = b''.join(
   114             source, chunksize=42))
   124             cborutil.streamencodeindefinitebytestring(source, chunksize=42)
       
   125         )
   115         self.assertEqual(cbor.loads(dest), source)
   126         self.assertEqual(cbor.loads(dest), source)
   116 
   127 
   117         self.assertEqual(b''.join(cborutil.decodeall(dest)), source)
   128         self.assertEqual(b''.join(cborutil.decodeall(dest)), source)
   118 
   129 
   119         for chunk in cborutil.decodeall(dest):
   130         for chunk in cborutil.decodeall(dest):
   138             elif len(source) < 65536:
   149             elif len(source) < 65536:
   139                 hlen = 3
   150                 hlen = 3
   140             elif len(source) < 1048576:
   151             elif len(source) < 1048576:
   141                 hlen = 5
   152                 hlen = 5
   142 
   153 
   143             self.assertEqual(cborutil.decodeitem(encoded),
   154             self.assertEqual(
   144                              (True, source, hlen + len(source),
   155                 cborutil.decodeitem(encoded),
   145                               cborutil.SPECIAL_NONE))
   156                 (True, source, hlen + len(source), cborutil.SPECIAL_NONE),
       
   157             )
   146 
   158 
   147     def testpartialdecode(self):
   159     def testpartialdecode(self):
   148         encoded = b''.join(cborutil.streamencode(b'foobar'))
   160         encoded = b''.join(cborutil.streamencode(b'foobar'))
   149 
   161 
   150         self.assertEqual(cborutil.decodeitem(encoded[0:1]),
   162         self.assertEqual(
   151                          (False, None, -6, cborutil.SPECIAL_NONE))
   163             cborutil.decodeitem(encoded[0:1]),
   152         self.assertEqual(cborutil.decodeitem(encoded[0:2]),
   164             (False, None, -6, cborutil.SPECIAL_NONE),
   153                          (False, None, -5, cborutil.SPECIAL_NONE))
   165         )
   154         self.assertEqual(cborutil.decodeitem(encoded[0:3]),
   166         self.assertEqual(
   155                          (False, None, -4, cborutil.SPECIAL_NONE))
   167             cborutil.decodeitem(encoded[0:2]),
   156         self.assertEqual(cborutil.decodeitem(encoded[0:4]),
   168             (False, None, -5, cborutil.SPECIAL_NONE),
   157                          (False, None, -3, cborutil.SPECIAL_NONE))
   169         )
   158         self.assertEqual(cborutil.decodeitem(encoded[0:5]),
   170         self.assertEqual(
   159                          (False, None, -2, cborutil.SPECIAL_NONE))
   171             cborutil.decodeitem(encoded[0:3]),
   160         self.assertEqual(cborutil.decodeitem(encoded[0:6]),
   172             (False, None, -4, cborutil.SPECIAL_NONE),
   161                          (False, None, -1, cborutil.SPECIAL_NONE))
   173         )
   162         self.assertEqual(cborutil.decodeitem(encoded[0:7]),
   174         self.assertEqual(
   163                          (True, b'foobar', 7, cborutil.SPECIAL_NONE))
   175             cborutil.decodeitem(encoded[0:4]),
       
   176             (False, None, -3, cborutil.SPECIAL_NONE),
       
   177         )
       
   178         self.assertEqual(
       
   179             cborutil.decodeitem(encoded[0:5]),
       
   180             (False, None, -2, cborutil.SPECIAL_NONE),
       
   181         )
       
   182         self.assertEqual(
       
   183             cborutil.decodeitem(encoded[0:6]),
       
   184             (False, None, -1, cborutil.SPECIAL_NONE),
       
   185         )
       
   186         self.assertEqual(
       
   187             cborutil.decodeitem(encoded[0:7]),
       
   188             (True, b'foobar', 7, cborutil.SPECIAL_NONE),
       
   189         )
   164 
   190 
   165     def testpartialdecodevariouslengths(self):
   191     def testpartialdecodevariouslengths(self):
   166         lens = [
   192         lens = [
   167             2,
   193             2,
   168             3,
   194             3,
   190         ]
   216         ]
   191 
   217 
   192         for size in lens:
   218         for size in lens:
   193             if size < 24:
   219             if size < 24:
   194                 hlen = 1
   220                 hlen = 1
   195             elif size < 2**8:
   221             elif size < 2 ** 8:
   196                 hlen = 2
   222                 hlen = 2
   197             elif size < 2**16:
   223             elif size < 2 ** 16:
   198                 hlen = 3
   224                 hlen = 3
   199             elif size < 2**32:
   225             elif size < 2 ** 32:
   200                 hlen = 5
   226                 hlen = 5
   201             else:
   227             else:
   202                 assert False
   228                 assert False
   203 
   229 
   204             source = b'x' * size
   230             source = b'x' * size
   205             encoded = b''.join(cborutil.streamencode(source))
   231             encoded = b''.join(cborutil.streamencode(source))
   206 
   232 
   207             res = cborutil.decodeitem(encoded[0:1])
   233             res = cborutil.decodeitem(encoded[0:1])
   208 
   234 
   209             if hlen > 1:
   235             if hlen > 1:
   210                 self.assertEqual(res, (False, None, -(hlen - 1),
   236                 self.assertEqual(
   211                                        cborutil.SPECIAL_NONE))
   237                     res, (False, None, -(hlen - 1), cborutil.SPECIAL_NONE)
       
   238                 )
   212             else:
   239             else:
   213                 self.assertEqual(res, (False, None, -(size + hlen - 1),
   240                 self.assertEqual(
   214                                        cborutil.SPECIAL_NONE))
   241                     res,
       
   242                     (False, None, -(size + hlen - 1), cborutil.SPECIAL_NONE),
       
   243                 )
   215 
   244 
   216             # Decoding partial header reports remaining header size.
   245             # Decoding partial header reports remaining header size.
   217             for i in range(hlen - 1):
   246             for i in range(hlen - 1):
   218                 self.assertEqual(cborutil.decodeitem(encoded[0:i + 1]),
   247                 self.assertEqual(
   219                                  (False, None, -(hlen - i - 1),
   248                     cborutil.decodeitem(encoded[0 : i + 1]),
   220                                   cborutil.SPECIAL_NONE))
   249                     (False, None, -(hlen - i - 1), cborutil.SPECIAL_NONE),
       
   250                 )
   221 
   251 
   222             # Decoding complete header reports item size.
   252             # Decoding complete header reports item size.
   223             self.assertEqual(cborutil.decodeitem(encoded[0:hlen]),
   253             self.assertEqual(
   224                              (False, None, -size, cborutil.SPECIAL_NONE))
   254                 cborutil.decodeitem(encoded[0:hlen]),
       
   255                 (False, None, -size, cborutil.SPECIAL_NONE),
       
   256             )
   225 
   257 
   226             # Decoding single byte after header reports item size - 1
   258             # Decoding single byte after header reports item size - 1
   227             self.assertEqual(cborutil.decodeitem(encoded[0:hlen + 1]),
   259             self.assertEqual(
   228                              (False, None, -(size - 1), cborutil.SPECIAL_NONE))
   260                 cborutil.decodeitem(encoded[0 : hlen + 1]),
       
   261                 (False, None, -(size - 1), cborutil.SPECIAL_NONE),
       
   262             )
   229 
   263 
   230             # Decoding all but the last byte reports -1 needed.
   264             # Decoding all but the last byte reports -1 needed.
   231             self.assertEqual(cborutil.decodeitem(encoded[0:hlen + size - 1]),
   265             self.assertEqual(
   232                              (False, None, -1, cborutil.SPECIAL_NONE))
   266                 cborutil.decodeitem(encoded[0 : hlen + size - 1]),
       
   267                 (False, None, -1, cborutil.SPECIAL_NONE),
       
   268             )
   233 
   269 
   234             # Decoding last byte retrieves value.
   270             # Decoding last byte retrieves value.
   235             self.assertEqual(cborutil.decodeitem(encoded[0:hlen + size]),
   271             self.assertEqual(
   236                              (True, source, hlen + size, cborutil.SPECIAL_NONE))
   272                 cborutil.decodeitem(encoded[0 : hlen + size]),
       
   273                 (True, source, hlen + size, cborutil.SPECIAL_NONE),
       
   274             )
   237 
   275 
   238     def testindefinitepartialdecode(self):
   276     def testindefinitepartialdecode(self):
   239         encoded = b''.join(cborutil.streamencodebytestringfromiter(
   277         encoded = b''.join(
   240             [b'foobar', b'biz']))
   278             cborutil.streamencodebytestringfromiter([b'foobar', b'biz'])
       
   279         )
   241 
   280 
   242         # First item should be begin of bytestring special.
   281         # First item should be begin of bytestring special.
   243         self.assertEqual(cborutil.decodeitem(encoded[0:1]),
   282         self.assertEqual(
   244                          (True, None, 1,
   283             cborutil.decodeitem(encoded[0:1]),
   245                           cborutil.SPECIAL_START_INDEFINITE_BYTESTRING))
   284             (True, None, 1, cborutil.SPECIAL_START_INDEFINITE_BYTESTRING),
       
   285         )
   246 
   286 
   247         # Second item should be the first chunk. But only available when
   287         # Second item should be the first chunk. But only available when
   248         # we give it 7 bytes (1 byte header + 6 byte chunk).
   288         # we give it 7 bytes (1 byte header + 6 byte chunk).
   249         self.assertEqual(cborutil.decodeitem(encoded[1:2]),
   289         self.assertEqual(
   250                          (False, None, -6, cborutil.SPECIAL_NONE))
   290             cborutil.decodeitem(encoded[1:2]),
   251         self.assertEqual(cborutil.decodeitem(encoded[1:3]),
   291             (False, None, -6, cborutil.SPECIAL_NONE),
   252                          (False, None, -5, cborutil.SPECIAL_NONE))
   292         )
   253         self.assertEqual(cborutil.decodeitem(encoded[1:4]),
   293         self.assertEqual(
   254                          (False, None, -4, cborutil.SPECIAL_NONE))
   294             cborutil.decodeitem(encoded[1:3]),
   255         self.assertEqual(cborutil.decodeitem(encoded[1:5]),
   295             (False, None, -5, cborutil.SPECIAL_NONE),
   256                          (False, None, -3, cborutil.SPECIAL_NONE))
   296         )
   257         self.assertEqual(cborutil.decodeitem(encoded[1:6]),
   297         self.assertEqual(
   258                          (False, None, -2, cborutil.SPECIAL_NONE))
   298             cborutil.decodeitem(encoded[1:4]),
   259         self.assertEqual(cborutil.decodeitem(encoded[1:7]),
   299             (False, None, -4, cborutil.SPECIAL_NONE),
   260                          (False, None, -1, cborutil.SPECIAL_NONE))
   300         )
   261 
   301         self.assertEqual(
   262         self.assertEqual(cborutil.decodeitem(encoded[1:8]),
   302             cborutil.decodeitem(encoded[1:5]),
   263                          (True, b'foobar', 7, cborutil.SPECIAL_NONE))
   303             (False, None, -3, cborutil.SPECIAL_NONE),
       
   304         )
       
   305         self.assertEqual(
       
   306             cborutil.decodeitem(encoded[1:6]),
       
   307             (False, None, -2, cborutil.SPECIAL_NONE),
       
   308         )
       
   309         self.assertEqual(
       
   310             cborutil.decodeitem(encoded[1:7]),
       
   311             (False, None, -1, cborutil.SPECIAL_NONE),
       
   312         )
       
   313 
       
   314         self.assertEqual(
       
   315             cborutil.decodeitem(encoded[1:8]),
       
   316             (True, b'foobar', 7, cborutil.SPECIAL_NONE),
       
   317         )
   264 
   318 
   265         # Third item should be second chunk. But only available when
   319         # Third item should be second chunk. But only available when
   266         # we give it 4 bytes (1 byte header + 3 byte chunk).
   320         # we give it 4 bytes (1 byte header + 3 byte chunk).
   267         self.assertEqual(cborutil.decodeitem(encoded[8:9]),
   321         self.assertEqual(
   268                          (False, None, -3, cborutil.SPECIAL_NONE))
   322             cborutil.decodeitem(encoded[8:9]),
   269         self.assertEqual(cborutil.decodeitem(encoded[8:10]),
   323             (False, None, -3, cborutil.SPECIAL_NONE),
   270                          (False, None, -2, cborutil.SPECIAL_NONE))
   324         )
   271         self.assertEqual(cborutil.decodeitem(encoded[8:11]),
   325         self.assertEqual(
   272                          (False, None, -1, cborutil.SPECIAL_NONE))
   326             cborutil.decodeitem(encoded[8:10]),
   273 
   327             (False, None, -2, cborutil.SPECIAL_NONE),
   274         self.assertEqual(cborutil.decodeitem(encoded[8:12]),
   328         )
   275                          (True, b'biz', 4, cborutil.SPECIAL_NONE))
   329         self.assertEqual(
       
   330             cborutil.decodeitem(encoded[8:11]),
       
   331             (False, None, -1, cborutil.SPECIAL_NONE),
       
   332         )
       
   333 
       
   334         self.assertEqual(
       
   335             cborutil.decodeitem(encoded[8:12]),
       
   336             (True, b'biz', 4, cborutil.SPECIAL_NONE),
       
   337         )
   276 
   338 
   277         # Fourth item should be end of indefinite stream marker.
   339         # Fourth item should be end of indefinite stream marker.
   278         self.assertEqual(cborutil.decodeitem(encoded[12:13]),
   340         self.assertEqual(
   279                          (True, None, 1, cborutil.SPECIAL_INDEFINITE_BREAK))
   341             cborutil.decodeitem(encoded[12:13]),
       
   342             (True, None, 1, cborutil.SPECIAL_INDEFINITE_BREAK),
       
   343         )
   280 
   344 
   281         # Now test the behavior when going through the decoder.
   345         # Now test the behavior when going through the decoder.
   282 
   346 
   283         self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:1]),
   347         self.assertEqual(
   284                          (False, 1, 0))
   348             cborutil.sansiodecoder().decode(encoded[0:1]), (False, 1, 0)
   285         self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:2]),
   349         )
   286                          (False, 1, 6))
   350         self.assertEqual(
   287         self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:3]),
   351             cborutil.sansiodecoder().decode(encoded[0:2]), (False, 1, 6)
   288                          (False, 1, 5))
   352         )
   289         self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:4]),
   353         self.assertEqual(
   290                          (False, 1, 4))
   354             cborutil.sansiodecoder().decode(encoded[0:3]), (False, 1, 5)
   291         self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:5]),
   355         )
   292                          (False, 1, 3))
   356         self.assertEqual(
   293         self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:6]),
   357             cborutil.sansiodecoder().decode(encoded[0:4]), (False, 1, 4)
   294                          (False, 1, 2))
   358         )
   295         self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:7]),
   359         self.assertEqual(
   296                          (False, 1, 1))
   360             cborutil.sansiodecoder().decode(encoded[0:5]), (False, 1, 3)
   297         self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:8]),
   361         )
   298                          (True, 8, 0))
   362         self.assertEqual(
   299 
   363             cborutil.sansiodecoder().decode(encoded[0:6]), (False, 1, 2)
   300         self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:9]),
   364         )
   301                          (True, 8, 3))
   365         self.assertEqual(
   302         self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:10]),
   366             cborutil.sansiodecoder().decode(encoded[0:7]), (False, 1, 1)
   303                          (True, 8, 2))
   367         )
   304         self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:11]),
   368         self.assertEqual(
   305                          (True, 8, 1))
   369             cborutil.sansiodecoder().decode(encoded[0:8]), (True, 8, 0)
   306         self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:12]),
   370         )
   307                          (True, 12, 0))
   371 
   308 
   372         self.assertEqual(
   309         self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:13]),
   373             cborutil.sansiodecoder().decode(encoded[0:9]), (True, 8, 3)
   310                          (True, 13, 0))
   374         )
       
   375         self.assertEqual(
       
   376             cborutil.sansiodecoder().decode(encoded[0:10]), (True, 8, 2)
       
   377         )
       
   378         self.assertEqual(
       
   379             cborutil.sansiodecoder().decode(encoded[0:11]), (True, 8, 1)
       
   380         )
       
   381         self.assertEqual(
       
   382             cborutil.sansiodecoder().decode(encoded[0:12]), (True, 12, 0)
       
   383         )
       
   384 
       
   385         self.assertEqual(
       
   386             cborutil.sansiodecoder().decode(encoded[0:13]), (True, 13, 0)
       
   387         )
   311 
   388 
   312         decoder = cborutil.sansiodecoder()
   389         decoder = cborutil.sansiodecoder()
   313         decoder.decode(encoded[0:8])
   390         decoder.decode(encoded[0:8])
   314         values = decoder.getavailable()
   391         values = decoder.getavailable()
   315         self.assertEqual(values, [b'foobar'])
   392         self.assertEqual(values, [b'foobar'])
   316         self.assertTrue(values[0].isfirst)
   393         self.assertTrue(values[0].isfirst)
   317         self.assertFalse(values[0].islast)
   394         self.assertFalse(values[0].islast)
   318 
   395 
   319         self.assertEqual(decoder.decode(encoded[8:12]),
   396         self.assertEqual(decoder.decode(encoded[8:12]), (True, 4, 0))
   320                          (True, 4, 0))
       
   321         values = decoder.getavailable()
   397         values = decoder.getavailable()
   322         self.assertEqual(values, [b'biz'])
   398         self.assertEqual(values, [b'biz'])
   323         self.assertFalse(values[0].isfirst)
   399         self.assertFalse(values[0].isfirst)
   324         self.assertFalse(values[0].islast)
   400         self.assertFalse(values[0].islast)
   325 
   401 
   326         self.assertEqual(decoder.decode(encoded[12:]),
   402         self.assertEqual(decoder.decode(encoded[12:]), (True, 1, 0))
   327                          (True, 1, 0))
       
   328         values = decoder.getavailable()
   403         values = decoder.getavailable()
   329         self.assertEqual(values, [b''])
   404         self.assertEqual(values, [b''])
   330         self.assertFalse(values[0].isfirst)
   405         self.assertFalse(values[0].isfirst)
   331         self.assertTrue(values[0].islast)
   406         self.assertTrue(values[0].islast)
   332 
   407 
       
   408 
   333 class StringTests(TestCase):
   409 class StringTests(TestCase):
   334     def testdecodeforbidden(self):
   410     def testdecodeforbidden(self):
   335         encoded = b'\x63foo'
   411         encoded = b'\x63foo'
   336         with self.assertRaisesRegex(cborutil.CBORDecodeError,
   412         with self.assertRaisesRegex(
   337                                     'string major type not supported'):
   413             cborutil.CBORDecodeError, 'string major type not supported'
       
   414         ):
   338             cborutil.decodeall(encoded)
   415             cborutil.decodeall(encoded)
       
   416 
   339 
   417 
   340 class IntTests(TestCase):
   418 class IntTests(TestCase):
   341     def testsmall(self):
   419     def testsmall(self):
   342         self.assertEqual(list(cborutil.streamencode(0)), [b'\x00'])
   420         self.assertEqual(list(cborutil.streamencode(0)), [b'\x00'])
   343         self.assertEqual(cborutil.decodeall(b'\x00'), [0])
   421         self.assertEqual(cborutil.decodeall(b'\x00'), [0])
   353 
   431 
   354         self.assertEqual(list(cborutil.streamencode(4)), [b'\x04'])
   432         self.assertEqual(list(cborutil.streamencode(4)), [b'\x04'])
   355         self.assertEqual(cborutil.decodeall(b'\x04'), [4])
   433         self.assertEqual(cborutil.decodeall(b'\x04'), [4])
   356 
   434 
   357         # Multiple value decode works.
   435         # Multiple value decode works.
   358         self.assertEqual(cborutil.decodeall(b'\x00\x01\x02\x03\x04'),
   436         self.assertEqual(
   359                          [0, 1, 2, 3, 4])
   437             cborutil.decodeall(b'\x00\x01\x02\x03\x04'), [0, 1, 2, 3, 4]
       
   438         )
   360 
   439 
   361     def testnegativesmall(self):
   440     def testnegativesmall(self):
   362         self.assertEqual(list(cborutil.streamencode(-1)), [b'\x20'])
   441         self.assertEqual(list(cborutil.streamencode(-1)), [b'\x20'])
   363         self.assertEqual(cborutil.decodeall(b'\x20'), [-1])
   442         self.assertEqual(cborutil.decodeall(b'\x20'), [-1])
   364 
   443 
   373 
   452 
   374         self.assertEqual(list(cborutil.streamencode(-5)), [b'\x24'])
   453         self.assertEqual(list(cborutil.streamencode(-5)), [b'\x24'])
   375         self.assertEqual(cborutil.decodeall(b'\x24'), [-5])
   454         self.assertEqual(cborutil.decodeall(b'\x24'), [-5])
   376 
   455 
   377         # Multiple value decode works.
   456         # Multiple value decode works.
   378         self.assertEqual(cborutil.decodeall(b'\x20\x21\x22\x23\x24'),
   457         self.assertEqual(
   379                          [-1, -2, -3, -4, -5])
   458             cborutil.decodeall(b'\x20\x21\x22\x23\x24'), [-1, -2, -3, -4, -5]
       
   459         )
   380 
   460 
   381     def testrange(self):
   461     def testrange(self):
   382         for i in range(-70000, 70000, 10):
   462         for i in range(-70000, 70000, 10):
   383             encoded = b''.join(cborutil.streamencode(i))
   463             encoded = b''.join(cborutil.streamencode(i))
   384 
   464 
   386             self.assertEqual(cborutil.decodeall(encoded), [i])
   466             self.assertEqual(cborutil.decodeall(encoded), [i])
   387 
   467 
   388     def testdecodepartialubyte(self):
   468     def testdecodepartialubyte(self):
   389         encoded = b''.join(cborutil.streamencode(250))
   469         encoded = b''.join(cborutil.streamencode(250))
   390 
   470 
   391         self.assertEqual(cborutil.decodeitem(encoded[0:1]),
   471         self.assertEqual(
   392                          (False, None, -1, cborutil.SPECIAL_NONE))
   472             cborutil.decodeitem(encoded[0:1]),
   393         self.assertEqual(cborutil.decodeitem(encoded[0:2]),
   473             (False, None, -1, cborutil.SPECIAL_NONE),
   394                          (True, 250, 2, cborutil.SPECIAL_NONE))
   474         )
       
   475         self.assertEqual(
       
   476             cborutil.decodeitem(encoded[0:2]),
       
   477             (True, 250, 2, cborutil.SPECIAL_NONE),
       
   478         )
   395 
   479 
   396     def testdecodepartialbyte(self):
   480     def testdecodepartialbyte(self):
   397         encoded = b''.join(cborutil.streamencode(-42))
   481         encoded = b''.join(cborutil.streamencode(-42))
   398         self.assertEqual(cborutil.decodeitem(encoded[0:1]),
   482         self.assertEqual(
   399                          (False, None, -1, cborutil.SPECIAL_NONE))
   483             cborutil.decodeitem(encoded[0:1]),
   400         self.assertEqual(cborutil.decodeitem(encoded[0:2]),
   484             (False, None, -1, cborutil.SPECIAL_NONE),
   401                          (True, -42, 2, cborutil.SPECIAL_NONE))
   485         )
       
   486         self.assertEqual(
       
   487             cborutil.decodeitem(encoded[0:2]),
       
   488             (True, -42, 2, cborutil.SPECIAL_NONE),
       
   489         )
   402 
   490 
   403     def testdecodepartialushort(self):
   491     def testdecodepartialushort(self):
   404         encoded = b''.join(cborutil.streamencode(2**15))
   492         encoded = b''.join(cborutil.streamencode(2 ** 15))
   405 
   493 
   406         self.assertEqual(cborutil.decodeitem(encoded[0:1]),
   494         self.assertEqual(
   407                          (False, None, -2, cborutil.SPECIAL_NONE))
   495             cborutil.decodeitem(encoded[0:1]),
   408         self.assertEqual(cborutil.decodeitem(encoded[0:2]),
   496             (False, None, -2, cborutil.SPECIAL_NONE),
   409                          (False, None, -1, cborutil.SPECIAL_NONE))
   497         )
   410         self.assertEqual(cborutil.decodeitem(encoded[0:5]),
   498         self.assertEqual(
   411                          (True, 2**15, 3, cborutil.SPECIAL_NONE))
   499             cborutil.decodeitem(encoded[0:2]),
       
   500             (False, None, -1, cborutil.SPECIAL_NONE),
       
   501         )
       
   502         self.assertEqual(
       
   503             cborutil.decodeitem(encoded[0:5]),
       
   504             (True, 2 ** 15, 3, cborutil.SPECIAL_NONE),
       
   505         )
   412 
   506 
   413     def testdecodepartialshort(self):
   507     def testdecodepartialshort(self):
   414         encoded = b''.join(cborutil.streamencode(-1024))
   508         encoded = b''.join(cborutil.streamencode(-1024))
   415 
   509 
   416         self.assertEqual(cborutil.decodeitem(encoded[0:1]),
   510         self.assertEqual(
   417                          (False, None, -2, cborutil.SPECIAL_NONE))
   511             cborutil.decodeitem(encoded[0:1]),
   418         self.assertEqual(cborutil.decodeitem(encoded[0:2]),
   512             (False, None, -2, cborutil.SPECIAL_NONE),
   419                          (False, None, -1, cborutil.SPECIAL_NONE))
   513         )
   420         self.assertEqual(cborutil.decodeitem(encoded[0:3]),
   514         self.assertEqual(
   421                          (True, -1024, 3, cborutil.SPECIAL_NONE))
   515             cborutil.decodeitem(encoded[0:2]),
       
   516             (False, None, -1, cborutil.SPECIAL_NONE),
       
   517         )
       
   518         self.assertEqual(
       
   519             cborutil.decodeitem(encoded[0:3]),
       
   520             (True, -1024, 3, cborutil.SPECIAL_NONE),
       
   521         )
   422 
   522 
   423     def testdecodepartialulong(self):
   523     def testdecodepartialulong(self):
   424         encoded = b''.join(cborutil.streamencode(2**28))
   524         encoded = b''.join(cborutil.streamencode(2 ** 28))
   425 
   525 
   426         self.assertEqual(cborutil.decodeitem(encoded[0:1]),
   526         self.assertEqual(
   427                          (False, None, -4, cborutil.SPECIAL_NONE))
   527             cborutil.decodeitem(encoded[0:1]),
   428         self.assertEqual(cborutil.decodeitem(encoded[0:2]),
   528             (False, None, -4, cborutil.SPECIAL_NONE),
   429                          (False, None, -3, cborutil.SPECIAL_NONE))
   529         )
   430         self.assertEqual(cborutil.decodeitem(encoded[0:3]),
   530         self.assertEqual(
   431                          (False, None, -2, cborutil.SPECIAL_NONE))
   531             cborutil.decodeitem(encoded[0:2]),
   432         self.assertEqual(cborutil.decodeitem(encoded[0:4]),
   532             (False, None, -3, cborutil.SPECIAL_NONE),
   433                          (False, None, -1, cborutil.SPECIAL_NONE))
   533         )
   434         self.assertEqual(cborutil.decodeitem(encoded[0:5]),
   534         self.assertEqual(
   435                          (True, 2**28, 5, cborutil.SPECIAL_NONE))
   535             cborutil.decodeitem(encoded[0:3]),
       
   536             (False, None, -2, cborutil.SPECIAL_NONE),
       
   537         )
       
   538         self.assertEqual(
       
   539             cborutil.decodeitem(encoded[0:4]),
       
   540             (False, None, -1, cborutil.SPECIAL_NONE),
       
   541         )
       
   542         self.assertEqual(
       
   543             cborutil.decodeitem(encoded[0:5]),
       
   544             (True, 2 ** 28, 5, cborutil.SPECIAL_NONE),
       
   545         )
   436 
   546 
   437     def testdecodepartiallong(self):
   547     def testdecodepartiallong(self):
   438         encoded = b''.join(cborutil.streamencode(-1048580))
   548         encoded = b''.join(cborutil.streamencode(-1048580))
   439 
   549 
   440         self.assertEqual(cborutil.decodeitem(encoded[0:1]),
   550         self.assertEqual(
   441                          (False, None, -4, cborutil.SPECIAL_NONE))
   551             cborutil.decodeitem(encoded[0:1]),
   442         self.assertEqual(cborutil.decodeitem(encoded[0:2]),
   552             (False, None, -4, cborutil.SPECIAL_NONE),
   443                          (False, None, -3, cborutil.SPECIAL_NONE))
   553         )
   444         self.assertEqual(cborutil.decodeitem(encoded[0:3]),
   554         self.assertEqual(
   445                          (False, None, -2, cborutil.SPECIAL_NONE))
   555             cborutil.decodeitem(encoded[0:2]),
   446         self.assertEqual(cborutil.decodeitem(encoded[0:4]),
   556             (False, None, -3, cborutil.SPECIAL_NONE),
   447                          (False, None, -1, cborutil.SPECIAL_NONE))
   557         )
   448         self.assertEqual(cborutil.decodeitem(encoded[0:5]),
   558         self.assertEqual(
   449                          (True, -1048580, 5, cborutil.SPECIAL_NONE))
   559             cborutil.decodeitem(encoded[0:3]),
       
   560             (False, None, -2, cborutil.SPECIAL_NONE),
       
   561         )
       
   562         self.assertEqual(
       
   563             cborutil.decodeitem(encoded[0:4]),
       
   564             (False, None, -1, cborutil.SPECIAL_NONE),
       
   565         )
       
   566         self.assertEqual(
       
   567             cborutil.decodeitem(encoded[0:5]),
       
   568             (True, -1048580, 5, cborutil.SPECIAL_NONE),
       
   569         )
   450 
   570 
   451     def testdecodepartialulonglong(self):
   571     def testdecodepartialulonglong(self):
   452         encoded = b''.join(cborutil.streamencode(2**32))
   572         encoded = b''.join(cborutil.streamencode(2 ** 32))
   453 
   573 
   454         self.assertEqual(cborutil.decodeitem(encoded[0:1]),
   574         self.assertEqual(
   455                          (False, None, -8, cborutil.SPECIAL_NONE))
   575             cborutil.decodeitem(encoded[0:1]),
   456         self.assertEqual(cborutil.decodeitem(encoded[0:2]),
   576             (False, None, -8, cborutil.SPECIAL_NONE),
   457                          (False, None, -7, cborutil.SPECIAL_NONE))
   577         )
   458         self.assertEqual(cborutil.decodeitem(encoded[0:3]),
   578         self.assertEqual(
   459                          (False, None, -6, cborutil.SPECIAL_NONE))
   579             cborutil.decodeitem(encoded[0:2]),
   460         self.assertEqual(cborutil.decodeitem(encoded[0:4]),
   580             (False, None, -7, cborutil.SPECIAL_NONE),
   461                          (False, None, -5, cborutil.SPECIAL_NONE))
   581         )
   462         self.assertEqual(cborutil.decodeitem(encoded[0:5]),
   582         self.assertEqual(
   463                          (False, None, -4, cborutil.SPECIAL_NONE))
   583             cborutil.decodeitem(encoded[0:3]),
   464         self.assertEqual(cborutil.decodeitem(encoded[0:6]),
   584             (False, None, -6, cborutil.SPECIAL_NONE),
   465                          (False, None, -3, cborutil.SPECIAL_NONE))
   585         )
   466         self.assertEqual(cborutil.decodeitem(encoded[0:7]),
   586         self.assertEqual(
   467                          (False, None, -2, cborutil.SPECIAL_NONE))
   587             cborutil.decodeitem(encoded[0:4]),
   468         self.assertEqual(cborutil.decodeitem(encoded[0:8]),
   588             (False, None, -5, cborutil.SPECIAL_NONE),
   469                          (False, None, -1, cborutil.SPECIAL_NONE))
   589         )
   470         self.assertEqual(cborutil.decodeitem(encoded[0:9]),
   590         self.assertEqual(
   471                          (True, 2**32, 9, cborutil.SPECIAL_NONE))
   591             cborutil.decodeitem(encoded[0:5]),
   472 
   592             (False, None, -4, cborutil.SPECIAL_NONE),
   473         with self.assertRaisesRegex(
   593         )
   474             cborutil.CBORDecodeError, 'input data not fully consumed'):
   594         self.assertEqual(
       
   595             cborutil.decodeitem(encoded[0:6]),
       
   596             (False, None, -3, cborutil.SPECIAL_NONE),
       
   597         )
       
   598         self.assertEqual(
       
   599             cborutil.decodeitem(encoded[0:7]),
       
   600             (False, None, -2, cborutil.SPECIAL_NONE),
       
   601         )
       
   602         self.assertEqual(
       
   603             cborutil.decodeitem(encoded[0:8]),
       
   604             (False, None, -1, cborutil.SPECIAL_NONE),
       
   605         )
       
   606         self.assertEqual(
       
   607             cborutil.decodeitem(encoded[0:9]),
       
   608             (True, 2 ** 32, 9, cborutil.SPECIAL_NONE),
       
   609         )
       
   610 
       
   611         with self.assertRaisesRegex(
       
   612             cborutil.CBORDecodeError, 'input data not fully consumed'
       
   613         ):
   475             cborutil.decodeall(encoded[0:1])
   614             cborutil.decodeall(encoded[0:1])
   476 
   615 
   477         with self.assertRaisesRegex(
   616         with self.assertRaisesRegex(
   478             cborutil.CBORDecodeError, 'input data not fully consumed'):
   617             cborutil.CBORDecodeError, 'input data not fully consumed'
       
   618         ):
   479             cborutil.decodeall(encoded[0:2])
   619             cborutil.decodeall(encoded[0:2])
   480 
   620 
   481     def testdecodepartiallonglong(self):
   621     def testdecodepartiallonglong(self):
   482         encoded = b''.join(cborutil.streamencode(-7000000000))
   622         encoded = b''.join(cborutil.streamencode(-7000000000))
   483 
   623 
   484         self.assertEqual(cborutil.decodeitem(encoded[0:1]),
   624         self.assertEqual(
   485                          (False, None, -8, cborutil.SPECIAL_NONE))
   625             cborutil.decodeitem(encoded[0:1]),
   486         self.assertEqual(cborutil.decodeitem(encoded[0:2]),
   626             (False, None, -8, cborutil.SPECIAL_NONE),
   487                          (False, None, -7, cborutil.SPECIAL_NONE))
   627         )
   488         self.assertEqual(cborutil.decodeitem(encoded[0:3]),
   628         self.assertEqual(
   489                          (False, None, -6, cborutil.SPECIAL_NONE))
   629             cborutil.decodeitem(encoded[0:2]),
   490         self.assertEqual(cborutil.decodeitem(encoded[0:4]),
   630             (False, None, -7, cborutil.SPECIAL_NONE),
   491                          (False, None, -5, cborutil.SPECIAL_NONE))
   631         )
   492         self.assertEqual(cborutil.decodeitem(encoded[0:5]),
   632         self.assertEqual(
   493                          (False, None, -4, cborutil.SPECIAL_NONE))
   633             cborutil.decodeitem(encoded[0:3]),
   494         self.assertEqual(cborutil.decodeitem(encoded[0:6]),
   634             (False, None, -6, cborutil.SPECIAL_NONE),
   495                          (False, None, -3, cborutil.SPECIAL_NONE))
   635         )
   496         self.assertEqual(cborutil.decodeitem(encoded[0:7]),
   636         self.assertEqual(
   497                          (False, None, -2, cborutil.SPECIAL_NONE))
   637             cborutil.decodeitem(encoded[0:4]),
   498         self.assertEqual(cborutil.decodeitem(encoded[0:8]),
   638             (False, None, -5, cborutil.SPECIAL_NONE),
   499                          (False, None, -1, cborutil.SPECIAL_NONE))
   639         )
   500         self.assertEqual(cborutil.decodeitem(encoded[0:9]),
   640         self.assertEqual(
   501                          (True, -7000000000, 9, cborutil.SPECIAL_NONE))
   641             cborutil.decodeitem(encoded[0:5]),
       
   642             (False, None, -4, cborutil.SPECIAL_NONE),
       
   643         )
       
   644         self.assertEqual(
       
   645             cborutil.decodeitem(encoded[0:6]),
       
   646             (False, None, -3, cborutil.SPECIAL_NONE),
       
   647         )
       
   648         self.assertEqual(
       
   649             cborutil.decodeitem(encoded[0:7]),
       
   650             (False, None, -2, cborutil.SPECIAL_NONE),
       
   651         )
       
   652         self.assertEqual(
       
   653             cborutil.decodeitem(encoded[0:8]),
       
   654             (False, None, -1, cborutil.SPECIAL_NONE),
       
   655         )
       
   656         self.assertEqual(
       
   657             cborutil.decodeitem(encoded[0:9]),
       
   658             (True, -7000000000, 9, cborutil.SPECIAL_NONE),
       
   659         )
       
   660 
   502 
   661 
   503 class ArrayTests(TestCase):
   662 class ArrayTests(TestCase):
   504     def testempty(self):
   663     def testempty(self):
   505         self.assertEqual(list(cborutil.streamencode([])), [b'\x80'])
   664         self.assertEqual(list(cborutil.streamencode([])), [b'\x80'])
   506         self.assertEqual(loadit(cborutil.streamencode([])), [])
   665         self.assertEqual(loadit(cborutil.streamencode([])), [])
   508         self.assertEqual(cborutil.decodeall(b'\x80'), [[]])
   667         self.assertEqual(cborutil.decodeall(b'\x80'), [[]])
   509 
   668 
   510     def testbasic(self):
   669     def testbasic(self):
   511         source = [b'foo', b'bar', 1, -10]
   670         source = [b'foo', b'bar', 1, -10]
   512 
   671 
   513         chunks = [
   672         chunks = [b'\x84', b'\x43', b'foo', b'\x43', b'bar', b'\x01', b'\x29']
   514             b'\x84', b'\x43', b'foo', b'\x43', b'bar', b'\x01', b'\x29']
       
   515 
   673 
   516         self.assertEqual(list(cborutil.streamencode(source)), chunks)
   674         self.assertEqual(list(cborutil.streamencode(source)), chunks)
   517 
   675 
   518         self.assertEqual(cborutil.decodeall(b''.join(chunks)), [source])
   676         self.assertEqual(cborutil.decodeall(b''.join(chunks)), [source])
   519 
   677 
   520     def testemptyfromiter(self):
   678     def testemptyfromiter(self):
   521         self.assertEqual(b''.join(cborutil.streamencodearrayfromiter([])),
   679         self.assertEqual(
   522                          b'\x9f\xff')
   680             b''.join(cborutil.streamencodearrayfromiter([])), b'\x9f\xff'
   523 
   681         )
   524         with self.assertRaisesRegex(cborutil.CBORDecodeError,
   682 
   525                                     'indefinite length uint not allowed'):
   683         with self.assertRaisesRegex(
       
   684             cborutil.CBORDecodeError, 'indefinite length uint not allowed'
       
   685         ):
   526             cborutil.decodeall(b'\x9f\xff')
   686             cborutil.decodeall(b'\x9f\xff')
   527 
   687 
   528     def testfromiter1(self):
   688     def testfromiter1(self):
   529         source = [b'foo']
   689         source = [b'foo']
   530 
   690 
   531         self.assertEqual(list(cborutil.streamencodearrayfromiter(source)), [
   691         self.assertEqual(
   532             b'\x9f',
   692             list(cborutil.streamencodearrayfromiter(source)),
   533             b'\x43', b'foo',
   693             [b'\x9f', b'\x43', b'foo', b'\xff',],
   534             b'\xff',
   694         )
   535         ])
       
   536 
   695 
   537         dest = b''.join(cborutil.streamencodearrayfromiter(source))
   696         dest = b''.join(cborutil.streamencodearrayfromiter(source))
   538         self.assertEqual(cbor.loads(dest), source)
   697         self.assertEqual(cbor.loads(dest), source)
   539 
   698 
   540         with self.assertRaisesRegex(cborutil.CBORDecodeError,
   699         with self.assertRaisesRegex(
   541                                     'indefinite length uint not allowed'):
   700             cborutil.CBORDecodeError, 'indefinite length uint not allowed'
       
   701         ):
   542             cborutil.decodeall(dest)
   702             cborutil.decodeall(dest)
   543 
   703 
   544     def testtuple(self):
   704     def testtuple(self):
   545         source = (b'foo', None, 42)
   705         source = (b'foo', None, 42)
   546         encoded = b''.join(cborutil.streamencode(source))
   706         encoded = b''.join(cborutil.streamencode(source))
   550         self.assertEqual(cborutil.decodeall(encoded), [list(source)])
   710         self.assertEqual(cborutil.decodeall(encoded), [list(source)])
   551 
   711 
   552     def testpartialdecode(self):
   712     def testpartialdecode(self):
   553         source = list(range(4))
   713         source = list(range(4))
   554         encoded = b''.join(cborutil.streamencode(source))
   714         encoded = b''.join(cborutil.streamencode(source))
   555         self.assertEqual(cborutil.decodeitem(encoded[0:1]),
   715         self.assertEqual(
   556                          (True, 4, 1, cborutil.SPECIAL_START_ARRAY))
   716             cborutil.decodeitem(encoded[0:1]),
   557         self.assertEqual(cborutil.decodeitem(encoded[0:2]),
   717             (True, 4, 1, cborutil.SPECIAL_START_ARRAY),
   558                          (True, 4, 1, cborutil.SPECIAL_START_ARRAY))
   718         )
       
   719         self.assertEqual(
       
   720             cborutil.decodeitem(encoded[0:2]),
       
   721             (True, 4, 1, cborutil.SPECIAL_START_ARRAY),
       
   722         )
   559 
   723 
   560         source = list(range(23))
   724         source = list(range(23))
   561         encoded = b''.join(cborutil.streamencode(source))
   725         encoded = b''.join(cborutil.streamencode(source))
   562         self.assertEqual(cborutil.decodeitem(encoded[0:1]),
   726         self.assertEqual(
   563                          (True, 23, 1, cborutil.SPECIAL_START_ARRAY))
   727             cborutil.decodeitem(encoded[0:1]),
   564         self.assertEqual(cborutil.decodeitem(encoded[0:2]),
   728             (True, 23, 1, cborutil.SPECIAL_START_ARRAY),
   565                          (True, 23, 1, cborutil.SPECIAL_START_ARRAY))
   729         )
       
   730         self.assertEqual(
       
   731             cborutil.decodeitem(encoded[0:2]),
       
   732             (True, 23, 1, cborutil.SPECIAL_START_ARRAY),
       
   733         )
   566 
   734 
   567         source = list(range(24))
   735         source = list(range(24))
   568         encoded = b''.join(cborutil.streamencode(source))
   736         encoded = b''.join(cborutil.streamencode(source))
   569         self.assertEqual(cborutil.decodeitem(encoded[0:1]),
   737         self.assertEqual(
   570                          (False, None, -1, cborutil.SPECIAL_NONE))
   738             cborutil.decodeitem(encoded[0:1]),
   571         self.assertEqual(cborutil.decodeitem(encoded[0:2]),
   739             (False, None, -1, cborutil.SPECIAL_NONE),
   572                          (True, 24, 2, cborutil.SPECIAL_START_ARRAY))
   740         )
   573         self.assertEqual(cborutil.decodeitem(encoded[0:3]),
   741         self.assertEqual(
   574                          (True, 24, 2, cborutil.SPECIAL_START_ARRAY))
   742             cborutil.decodeitem(encoded[0:2]),
       
   743             (True, 24, 2, cborutil.SPECIAL_START_ARRAY),
       
   744         )
       
   745         self.assertEqual(
       
   746             cborutil.decodeitem(encoded[0:3]),
       
   747             (True, 24, 2, cborutil.SPECIAL_START_ARRAY),
       
   748         )
   575 
   749 
   576         source = list(range(256))
   750         source = list(range(256))
   577         encoded = b''.join(cborutil.streamencode(source))
   751         encoded = b''.join(cborutil.streamencode(source))
   578         self.assertEqual(cborutil.decodeitem(encoded[0:1]),
   752         self.assertEqual(
   579                          (False, None, -2, cborutil.SPECIAL_NONE))
   753             cborutil.decodeitem(encoded[0:1]),
   580         self.assertEqual(cborutil.decodeitem(encoded[0:2]),
   754             (False, None, -2, cborutil.SPECIAL_NONE),
   581                          (False, None, -1, cborutil.SPECIAL_NONE))
   755         )
   582         self.assertEqual(cborutil.decodeitem(encoded[0:3]),
   756         self.assertEqual(
   583                          (True, 256, 3, cborutil.SPECIAL_START_ARRAY))
   757             cborutil.decodeitem(encoded[0:2]),
   584         self.assertEqual(cborutil.decodeitem(encoded[0:4]),
   758             (False, None, -1, cborutil.SPECIAL_NONE),
   585                          (True, 256, 3, cborutil.SPECIAL_START_ARRAY))
   759         )
       
   760         self.assertEqual(
       
   761             cborutil.decodeitem(encoded[0:3]),
       
   762             (True, 256, 3, cborutil.SPECIAL_START_ARRAY),
       
   763         )
       
   764         self.assertEqual(
       
   765             cborutil.decodeitem(encoded[0:4]),
       
   766             (True, 256, 3, cborutil.SPECIAL_START_ARRAY),
       
   767         )
   586 
   768 
   587     def testnested(self):
   769     def testnested(self):
   588         source = [[], [], [[], [], []]]
   770         source = [[], [], [[], [], []]]
   589         encoded = b''.join(cborutil.streamencode(source))
   771         encoded = b''.join(cborutil.streamencode(source))
   590         self.assertEqual(cborutil.decodeall(encoded), [source])
   772         self.assertEqual(cborutil.decodeall(encoded), [source])
   605 
   787 
   606     def testindefinitebytestringvalues(self):
   788     def testindefinitebytestringvalues(self):
   607         # Single value array whose value is an empty indefinite bytestring.
   789         # Single value array whose value is an empty indefinite bytestring.
   608         encoded = b'\x81\x5f\x40\xff'
   790         encoded = b'\x81\x5f\x40\xff'
   609 
   791 
   610         with self.assertRaisesRegex(cborutil.CBORDecodeError,
   792         with self.assertRaisesRegex(
   611                                     'indefinite length bytestrings not '
   793             cborutil.CBORDecodeError,
   612                                     'allowed as array values'):
   794             'indefinite length bytestrings not ' 'allowed as array values',
       
   795         ):
   613             cborutil.decodeall(encoded)
   796             cborutil.decodeall(encoded)
       
   797 
   614 
   798 
   615 class SetTests(TestCase):
   799 class SetTests(TestCase):
   616     def testempty(self):
   800     def testempty(self):
   617         self.assertEqual(list(cborutil.streamencode(set())), [
   801         self.assertEqual(
   618             b'\xd9\x01\x02',
   802             list(cborutil.streamencode(set())), [b'\xd9\x01\x02', b'\x80',]
   619             b'\x80',
   803         )
   620         ])
       
   621 
   804 
   622         self.assertEqual(cborutil.decodeall(b'\xd9\x01\x02\x80'), [set()])
   805         self.assertEqual(cborutil.decodeall(b'\xd9\x01\x02\x80'), [set()])
   623 
   806 
   624     def testset(self):
   807     def testset(self):
   625         source = {b'foo', None, 42}
   808         source = {b'foo', None, 42}
   631 
   814 
   632     def testinvalidtag(self):
   815     def testinvalidtag(self):
   633         # Must use array to encode sets.
   816         # Must use array to encode sets.
   634         encoded = b'\xd9\x01\x02\xa0'
   817         encoded = b'\xd9\x01\x02\xa0'
   635 
   818 
   636         with self.assertRaisesRegex(cborutil.CBORDecodeError,
   819         with self.assertRaisesRegex(
   637                                     'expected array after finite set '
   820             cborutil.CBORDecodeError,
   638                                     'semantic tag'):
   821             'expected array after finite set ' 'semantic tag',
       
   822         ):
   639             cborutil.decodeall(encoded)
   823             cborutil.decodeall(encoded)
   640 
   824 
   641     def testpartialdecode(self):
   825     def testpartialdecode(self):
   642         # Semantic tag item will be 3 bytes. Set header will be variable
   826         # Semantic tag item will be 3 bytes. Set header will be variable
   643         # depending on length.
   827         # depending on length.
   644         encoded = b''.join(cborutil.streamencode({i for i in range(23)}))
   828         encoded = b''.join(cborutil.streamencode({i for i in range(23)}))
   645         self.assertEqual(cborutil.decodeitem(encoded[0:1]),
   829         self.assertEqual(
   646                          (False, None, -2, cborutil.SPECIAL_NONE))
   830             cborutil.decodeitem(encoded[0:1]),
   647         self.assertEqual(cborutil.decodeitem(encoded[0:2]),
   831             (False, None, -2, cborutil.SPECIAL_NONE),
   648                          (False, None, -1, cborutil.SPECIAL_NONE))
   832         )
   649         self.assertEqual(cborutil.decodeitem(encoded[0:3]),
   833         self.assertEqual(
   650                          (False, None, -1, cborutil.SPECIAL_NONE))
   834             cborutil.decodeitem(encoded[0:2]),
   651         self.assertEqual(cborutil.decodeitem(encoded[0:4]),
   835             (False, None, -1, cborutil.SPECIAL_NONE),
   652                          (True, 23, 4, cborutil.SPECIAL_START_SET))
   836         )
   653         self.assertEqual(cborutil.decodeitem(encoded[0:5]),
   837         self.assertEqual(
   654                          (True, 23, 4, cborutil.SPECIAL_START_SET))
   838             cborutil.decodeitem(encoded[0:3]),
       
   839             (False, None, -1, cborutil.SPECIAL_NONE),
       
   840         )
       
   841         self.assertEqual(
       
   842             cborutil.decodeitem(encoded[0:4]),
       
   843             (True, 23, 4, cborutil.SPECIAL_START_SET),
       
   844         )
       
   845         self.assertEqual(
       
   846             cborutil.decodeitem(encoded[0:5]),
       
   847             (True, 23, 4, cborutil.SPECIAL_START_SET),
       
   848         )
   655 
   849 
   656         encoded = b''.join(cborutil.streamencode({i for i in range(24)}))
   850         encoded = b''.join(cborutil.streamencode({i for i in range(24)}))
   657         self.assertEqual(cborutil.decodeitem(encoded[0:1]),
   851         self.assertEqual(
   658                          (False, None, -2, cborutil.SPECIAL_NONE))
   852             cborutil.decodeitem(encoded[0:1]),
   659         self.assertEqual(cborutil.decodeitem(encoded[0:2]),
   853             (False, None, -2, cborutil.SPECIAL_NONE),
   660                          (False, None, -1, cborutil.SPECIAL_NONE))
   854         )
   661         self.assertEqual(cborutil.decodeitem(encoded[0:3]),
   855         self.assertEqual(
   662                          (False, None, -1, cborutil.SPECIAL_NONE))
   856             cborutil.decodeitem(encoded[0:2]),
   663         self.assertEqual(cborutil.decodeitem(encoded[0:4]),
   857             (False, None, -1, cborutil.SPECIAL_NONE),
   664                          (False, None, -1, cborutil.SPECIAL_NONE))
   858         )
   665         self.assertEqual(cborutil.decodeitem(encoded[0:5]),
   859         self.assertEqual(
   666                          (True, 24, 5, cborutil.SPECIAL_START_SET))
   860             cborutil.decodeitem(encoded[0:3]),
   667         self.assertEqual(cborutil.decodeitem(encoded[0:6]),
   861             (False, None, -1, cborutil.SPECIAL_NONE),
   668                          (True, 24, 5, cborutil.SPECIAL_START_SET))
   862         )
       
   863         self.assertEqual(
       
   864             cborutil.decodeitem(encoded[0:4]),
       
   865             (False, None, -1, cborutil.SPECIAL_NONE),
       
   866         )
       
   867         self.assertEqual(
       
   868             cborutil.decodeitem(encoded[0:5]),
       
   869             (True, 24, 5, cborutil.SPECIAL_START_SET),
       
   870         )
       
   871         self.assertEqual(
       
   872             cborutil.decodeitem(encoded[0:6]),
       
   873             (True, 24, 5, cborutil.SPECIAL_START_SET),
       
   874         )
   669 
   875 
   670         encoded = b''.join(cborutil.streamencode({i for i in range(256)}))
   876         encoded = b''.join(cborutil.streamencode({i for i in range(256)}))
   671         self.assertEqual(cborutil.decodeitem(encoded[0:1]),
   877         self.assertEqual(
   672                          (False, None, -2, cborutil.SPECIAL_NONE))
   878             cborutil.decodeitem(encoded[0:1]),
   673         self.assertEqual(cborutil.decodeitem(encoded[0:2]),
   879             (False, None, -2, cborutil.SPECIAL_NONE),
   674                          (False, None, -1, cborutil.SPECIAL_NONE))
   880         )
   675         self.assertEqual(cborutil.decodeitem(encoded[0:3]),
   881         self.assertEqual(
   676                          (False, None, -1, cborutil.SPECIAL_NONE))
   882             cborutil.decodeitem(encoded[0:2]),
   677         self.assertEqual(cborutil.decodeitem(encoded[0:4]),
   883             (False, None, -1, cborutil.SPECIAL_NONE),
   678                          (False, None, -2, cborutil.SPECIAL_NONE))
   884         )
   679         self.assertEqual(cborutil.decodeitem(encoded[0:5]),
   885         self.assertEqual(
   680                          (False, None, -1, cborutil.SPECIAL_NONE))
   886             cborutil.decodeitem(encoded[0:3]),
   681         self.assertEqual(cborutil.decodeitem(encoded[0:6]),
   887             (False, None, -1, cborutil.SPECIAL_NONE),
   682                          (True, 256, 6, cborutil.SPECIAL_START_SET))
   888         )
       
   889         self.assertEqual(
       
   890             cborutil.decodeitem(encoded[0:4]),
       
   891             (False, None, -2, cborutil.SPECIAL_NONE),
       
   892         )
       
   893         self.assertEqual(
       
   894             cborutil.decodeitem(encoded[0:5]),
       
   895             (False, None, -1, cborutil.SPECIAL_NONE),
       
   896         )
       
   897         self.assertEqual(
       
   898             cborutil.decodeitem(encoded[0:6]),
       
   899             (True, 256, 6, cborutil.SPECIAL_START_SET),
       
   900         )
   683 
   901 
   684     def testinvalidvalue(self):
   902     def testinvalidvalue(self):
   685         encoded = b''.join([
   903         encoded = b''.join(
   686             b'\xd9\x01\x02', # semantic tag
   904             [
   687             b'\x81', # array of size 1
   905                 b'\xd9\x01\x02',  # semantic tag
   688             b'\x5f\x43foo\xff', # indefinite length bytestring "foo"
   906                 b'\x81',  # array of size 1
   689         ])
   907                 b'\x5f\x43foo\xff',  # indefinite length bytestring "foo"
   690 
   908             ]
   691         with self.assertRaisesRegex(cborutil.CBORDecodeError,
   909         )
   692                                     'indefinite length bytestrings not '
   910 
   693                                     'allowed as set values'):
   911         with self.assertRaisesRegex(
       
   912             cborutil.CBORDecodeError,
       
   913             'indefinite length bytestrings not ' 'allowed as set values',
       
   914         ):
   694             cborutil.decodeall(encoded)
   915             cborutil.decodeall(encoded)
   695 
   916 
   696         encoded = b''.join([
   917         encoded = b''.join([b'\xd9\x01\x02', b'\x81', b'\x80',])  # empty array
   697             b'\xd9\x01\x02',
   918 
   698             b'\x81',
   919         with self.assertRaisesRegex(
   699             b'\x80', # empty array
   920             cborutil.CBORDecodeError, 'collections not allowed as set values'
   700         ])
   921         ):
   701 
       
   702         with self.assertRaisesRegex(cborutil.CBORDecodeError,
       
   703                                     'collections not allowed as set values'):
       
   704             cborutil.decodeall(encoded)
   922             cborutil.decodeall(encoded)
   705 
   923 
   706         encoded = b''.join([
   924         encoded = b''.join([b'\xd9\x01\x02', b'\x81', b'\xa0',])  # empty map
   707             b'\xd9\x01\x02',
   925 
   708             b'\x81',
   926         with self.assertRaisesRegex(
   709             b'\xa0', # empty map
   927             cborutil.CBORDecodeError, 'collections not allowed as set values'
   710         ])
   928         ):
   711 
       
   712         with self.assertRaisesRegex(cborutil.CBORDecodeError,
       
   713                                     'collections not allowed as set values'):
       
   714             cborutil.decodeall(encoded)
   929             cborutil.decodeall(encoded)
   715 
   930 
   716         encoded = b''.join([
   931         encoded = b''.join(
   717             b'\xd9\x01\x02',
   932             [
   718             b'\x81',
   933                 b'\xd9\x01\x02',
   719             b'\xd9\x01\x02\x81\x01', # set with integer 1
   934                 b'\x81',
   720         ])
   935                 b'\xd9\x01\x02\x81\x01',  # set with integer 1
   721 
   936             ]
   722         with self.assertRaisesRegex(cborutil.CBORDecodeError,
   937         )
   723                                     'collections not allowed as set values'):
   938 
       
   939         with self.assertRaisesRegex(
       
   940             cborutil.CBORDecodeError, 'collections not allowed as set values'
       
   941         ):
   724             cborutil.decodeall(encoded)
   942             cborutil.decodeall(encoded)
       
   943 
   725 
   944 
   726 class BoolTests(TestCase):
   945 class BoolTests(TestCase):
   727     def testbasic(self):
   946     def testbasic(self):
   728         self.assertEqual(list(cborutil.streamencode(True)),  [b'\xf5'])
   947         self.assertEqual(list(cborutil.streamencode(True)), [b'\xf5'])
   729         self.assertEqual(list(cborutil.streamencode(False)), [b'\xf4'])
   948         self.assertEqual(list(cborutil.streamencode(False)), [b'\xf4'])
   730 
   949 
   731         self.assertIs(loadit(cborutil.streamencode(True)), True)
   950         self.assertIs(loadit(cborutil.streamencode(True)), True)
   732         self.assertIs(loadit(cborutil.streamencode(False)), False)
   951         self.assertIs(loadit(cborutil.streamencode(False)), False)
   733 
   952 
   734         self.assertEqual(cborutil.decodeall(b'\xf4'), [False])
   953         self.assertEqual(cborutil.decodeall(b'\xf4'), [False])
   735         self.assertEqual(cborutil.decodeall(b'\xf5'), [True])
   954         self.assertEqual(cborutil.decodeall(b'\xf5'), [True])
   736 
   955 
   737         self.assertEqual(cborutil.decodeall(b'\xf4\xf5\xf5\xf4'),
   956         self.assertEqual(
   738                          [False, True, True, False])
   957             cborutil.decodeall(b'\xf4\xf5\xf5\xf4'), [False, True, True, False]
       
   958         )
       
   959 
   739 
   960 
   740 class NoneTests(TestCase):
   961 class NoneTests(TestCase):
   741     def testbasic(self):
   962     def testbasic(self):
   742         self.assertEqual(list(cborutil.streamencode(None)), [b'\xf6'])
   963         self.assertEqual(list(cborutil.streamencode(None)), [b'\xf6'])
   743 
   964 
   744         self.assertIs(loadit(cborutil.streamencode(None)), None)
   965         self.assertIs(loadit(cborutil.streamencode(None)), None)
   745 
   966 
   746         self.assertEqual(cborutil.decodeall(b'\xf6'), [None])
   967         self.assertEqual(cborutil.decodeall(b'\xf6'), [None])
   747         self.assertEqual(cborutil.decodeall(b'\xf6\xf6'), [None, None])
   968         self.assertEqual(cborutil.decodeall(b'\xf6\xf6'), [None, None])
       
   969 
   748 
   970 
   749 class MapTests(TestCase):
   971 class MapTests(TestCase):
   750     def testempty(self):
   972     def testempty(self):
   751         self.assertEqual(list(cborutil.streamencode({})), [b'\xa0'])
   973         self.assertEqual(list(cborutil.streamencode({})), [b'\xa0'])
   752         self.assertEqual(loadit(cborutil.streamencode({})), {})
   974         self.assertEqual(loadit(cborutil.streamencode({})), {})
   753 
   975 
   754         self.assertEqual(cborutil.decodeall(b'\xa0'), [{}])
   976         self.assertEqual(cborutil.decodeall(b'\xa0'), [{}])
   755 
   977 
   756     def testemptyindefinite(self):
   978     def testemptyindefinite(self):
   757         self.assertEqual(list(cborutil.streamencodemapfromiter([])), [
   979         self.assertEqual(
   758             b'\xbf', b'\xff'])
   980             list(cborutil.streamencodemapfromiter([])), [b'\xbf', b'\xff']
       
   981         )
   759 
   982 
   760         self.assertEqual(loadit(cborutil.streamencodemapfromiter([])), {})
   983         self.assertEqual(loadit(cborutil.streamencodemapfromiter([])), {})
   761 
   984 
   762         with self.assertRaisesRegex(cborutil.CBORDecodeError,
   985         with self.assertRaisesRegex(
   763                                     'indefinite length uint not allowed'):
   986             cborutil.CBORDecodeError, 'indefinite length uint not allowed'
       
   987         ):
   764             cborutil.decodeall(b'\xbf\xff')
   988             cborutil.decodeall(b'\xbf\xff')
   765 
   989 
   766     def testone(self):
   990     def testone(self):
   767         source = {b'foo': b'bar'}
   991         source = {b'foo': b'bar'}
   768         self.assertEqual(list(cborutil.streamencode(source)), [
   992         self.assertEqual(
   769             b'\xa1', b'\x43', b'foo', b'\x43', b'bar'])
   993             list(cborutil.streamencode(source)),
       
   994             [b'\xa1', b'\x43', b'foo', b'\x43', b'bar'],
       
   995         )
   770 
   996 
   771         self.assertEqual(loadit(cborutil.streamencode(source)), source)
   997         self.assertEqual(loadit(cborutil.streamencode(source)), source)
   772 
   998 
   773         self.assertEqual(cborutil.decodeall(b'\xa1\x43foo\x43bar'), [source])
   999         self.assertEqual(cborutil.decodeall(b'\xa1\x43foo\x43bar'), [source])
   774 
  1000 
   779         }
  1005         }
   780 
  1006 
   781         self.assertEqual(loadit(cborutil.streamencode(source)), source)
  1007         self.assertEqual(loadit(cborutil.streamencode(source)), source)
   782 
  1008 
   783         self.assertEqual(
  1009         self.assertEqual(
   784             loadit(cborutil.streamencodemapfromiter(source.items())),
  1010             loadit(cborutil.streamencodemapfromiter(source.items())), source
   785             source)
  1011         )
   786 
  1012 
   787         encoded = b''.join(cborutil.streamencode(source))
  1013         encoded = b''.join(cborutil.streamencode(source))
   788         self.assertEqual(cborutil.decodeall(encoded), [source])
  1014         self.assertEqual(cborutil.decodeall(encoded), [source])
   789 
  1015 
   790     def testcomplex(self):
  1016     def testcomplex(self):
   791         source = {
  1017         source = {
   792             b'key': 1,
  1018             b'key': 1,
   793             2: -10,
  1019             2: -10,
   794         }
  1020         }
   795 
  1021 
   796         self.assertEqual(loadit(cborutil.streamencode(source)),
  1022         self.assertEqual(loadit(cborutil.streamencode(source)), source)
   797                          source)
  1023 
   798 
  1024         self.assertEqual(
   799         self.assertEqual(
  1025             loadit(cborutil.streamencodemapfromiter(source.items())), source
   800             loadit(cborutil.streamencodemapfromiter(source.items())),
  1026         )
   801             source)
       
   802 
  1027 
   803         encoded = b''.join(cborutil.streamencode(source))
  1028         encoded = b''.join(cborutil.streamencode(source))
   804         self.assertEqual(cborutil.decodeall(encoded), [source])
  1029         self.assertEqual(cborutil.decodeall(encoded), [source])
   805 
  1030 
   806     def testnested(self):
  1031     def testnested(self):
   817         }
  1042         }
   818         encoded = b''.join(cborutil.streamencode(source))
  1043         encoded = b''.join(cborutil.streamencode(source))
   819         self.assertEqual(cborutil.decodeall(encoded), [source])
  1044         self.assertEqual(cborutil.decodeall(encoded), [source])
   820 
  1045 
   821     def testillegalkey(self):
  1046     def testillegalkey(self):
   822         encoded = b''.join([
  1047         encoded = b''.join(
   823             # map header + len 1
  1048             [
   824             b'\xa1',
  1049                 # map header + len 1
   825             # indefinite length bytestring "foo" in key position
  1050                 b'\xa1',
   826             b'\x5f\x03foo\xff'
  1051                 # indefinite length bytestring "foo" in key position
   827         ])
  1052                 b'\x5f\x03foo\xff',
   828 
  1053             ]
   829         with self.assertRaisesRegex(cborutil.CBORDecodeError,
  1054         )
   830                                     'indefinite length bytestrings not '
  1055 
   831                                     'allowed as map keys'):
  1056         with self.assertRaisesRegex(
       
  1057             cborutil.CBORDecodeError,
       
  1058             'indefinite length bytestrings not ' 'allowed as map keys',
       
  1059         ):
   832             cborutil.decodeall(encoded)
  1060             cborutil.decodeall(encoded)
   833 
  1061 
   834         encoded = b''.join([
  1062         encoded = b''.join([b'\xa1', b'\x80', b'\x43foo',])  # empty array
   835             b'\xa1',
  1063 
   836             b'\x80', # empty array
  1064         with self.assertRaisesRegex(
   837             b'\x43foo',
  1065             cborutil.CBORDecodeError, 'collections not supported as map keys'
   838         ])
  1066         ):
   839 
       
   840         with self.assertRaisesRegex(cborutil.CBORDecodeError,
       
   841                                     'collections not supported as map keys'):
       
   842             cborutil.decodeall(encoded)
  1067             cborutil.decodeall(encoded)
   843 
  1068 
   844     def testillegalvalue(self):
  1069     def testillegalvalue(self):
   845         encoded = b''.join([
  1070         encoded = b''.join(
   846             b'\xa1', # map headers
  1071             [
   847             b'\x43foo', # key
  1072                 b'\xa1',  # map headers
   848             b'\x5f\x03bar\xff', # indefinite length value
  1073                 b'\x43foo',  # key
   849         ])
  1074                 b'\x5f\x03bar\xff',  # indefinite length value
   850 
  1075             ]
   851         with self.assertRaisesRegex(cborutil.CBORDecodeError,
  1076         )
   852                                     'indefinite length bytestrings not '
  1077 
   853                                     'allowed as map values'):
  1078         with self.assertRaisesRegex(
       
  1079             cborutil.CBORDecodeError,
       
  1080             'indefinite length bytestrings not ' 'allowed as map values',
       
  1081         ):
   854             cborutil.decodeall(encoded)
  1082             cborutil.decodeall(encoded)
   855 
  1083 
   856     def testpartialdecode(self):
  1084     def testpartialdecode(self):
   857         source = {b'key1': b'value1'}
  1085         source = {b'key1': b'value1'}
   858         encoded = b''.join(cborutil.streamencode(source))
  1086         encoded = b''.join(cborutil.streamencode(source))
   859 
  1087 
   860         self.assertEqual(cborutil.decodeitem(encoded[0:1]),
  1088         self.assertEqual(
   861                          (True, 1, 1, cborutil.SPECIAL_START_MAP))
  1089             cborutil.decodeitem(encoded[0:1]),
   862         self.assertEqual(cborutil.decodeitem(encoded[0:2]),
  1090             (True, 1, 1, cborutil.SPECIAL_START_MAP),
   863                          (True, 1, 1, cborutil.SPECIAL_START_MAP))
  1091         )
       
  1092         self.assertEqual(
       
  1093             cborutil.decodeitem(encoded[0:2]),
       
  1094             (True, 1, 1, cborutil.SPECIAL_START_MAP),
       
  1095         )
   864 
  1096 
   865         source = {b'key%d' % i: None for i in range(23)}
  1097         source = {b'key%d' % i: None for i in range(23)}
   866         encoded = b''.join(cborutil.streamencode(source))
  1098         encoded = b''.join(cborutil.streamencode(source))
   867         self.assertEqual(cborutil.decodeitem(encoded[0:1]),
  1099         self.assertEqual(
   868                          (True, 23, 1, cborutil.SPECIAL_START_MAP))
  1100             cborutil.decodeitem(encoded[0:1]),
       
  1101             (True, 23, 1, cborutil.SPECIAL_START_MAP),
       
  1102         )
   869 
  1103 
   870         source = {b'key%d' % i: None for i in range(24)}
  1104         source = {b'key%d' % i: None for i in range(24)}
   871         encoded = b''.join(cborutil.streamencode(source))
  1105         encoded = b''.join(cborutil.streamencode(source))
   872         self.assertEqual(cborutil.decodeitem(encoded[0:1]),
  1106         self.assertEqual(
   873                          (False, None, -1, cborutil.SPECIAL_NONE))
  1107             cborutil.decodeitem(encoded[0:1]),
   874         self.assertEqual(cborutil.decodeitem(encoded[0:2]),
  1108             (False, None, -1, cborutil.SPECIAL_NONE),
   875                          (True, 24, 2, cborutil.SPECIAL_START_MAP))
  1109         )
   876         self.assertEqual(cborutil.decodeitem(encoded[0:3]),
  1110         self.assertEqual(
   877                          (True, 24, 2, cborutil.SPECIAL_START_MAP))
  1111             cborutil.decodeitem(encoded[0:2]),
       
  1112             (True, 24, 2, cborutil.SPECIAL_START_MAP),
       
  1113         )
       
  1114         self.assertEqual(
       
  1115             cborutil.decodeitem(encoded[0:3]),
       
  1116             (True, 24, 2, cborutil.SPECIAL_START_MAP),
       
  1117         )
   878 
  1118 
   879         source = {b'key%d' % i: None for i in range(256)}
  1119         source = {b'key%d' % i: None for i in range(256)}
   880         encoded = b''.join(cborutil.streamencode(source))
  1120         encoded = b''.join(cborutil.streamencode(source))
   881         self.assertEqual(cborutil.decodeitem(encoded[0:1]),
  1121         self.assertEqual(
   882                          (False, None, -2, cborutil.SPECIAL_NONE))
  1122             cborutil.decodeitem(encoded[0:1]),
   883         self.assertEqual(cborutil.decodeitem(encoded[0:2]),
  1123             (False, None, -2, cborutil.SPECIAL_NONE),
   884                          (False, None, -1, cborutil.SPECIAL_NONE))
  1124         )
   885         self.assertEqual(cborutil.decodeitem(encoded[0:3]),
  1125         self.assertEqual(
   886                          (True, 256, 3, cborutil.SPECIAL_START_MAP))
  1126             cborutil.decodeitem(encoded[0:2]),
   887         self.assertEqual(cborutil.decodeitem(encoded[0:4]),
  1127             (False, None, -1, cborutil.SPECIAL_NONE),
   888                          (True, 256, 3, cborutil.SPECIAL_START_MAP))
  1128         )
       
  1129         self.assertEqual(
       
  1130             cborutil.decodeitem(encoded[0:3]),
       
  1131             (True, 256, 3, cborutil.SPECIAL_START_MAP),
       
  1132         )
       
  1133         self.assertEqual(
       
  1134             cborutil.decodeitem(encoded[0:4]),
       
  1135             (True, 256, 3, cborutil.SPECIAL_START_MAP),
       
  1136         )
   889 
  1137 
   890         source = {b'key%d' % i: None for i in range(65536)}
  1138         source = {b'key%d' % i: None for i in range(65536)}
   891         encoded = b''.join(cborutil.streamencode(source))
  1139         encoded = b''.join(cborutil.streamencode(source))
   892         self.assertEqual(cborutil.decodeitem(encoded[0:1]),
  1140         self.assertEqual(
   893                          (False, None, -4, cborutil.SPECIAL_NONE))
  1141             cborutil.decodeitem(encoded[0:1]),
   894         self.assertEqual(cborutil.decodeitem(encoded[0:2]),
  1142             (False, None, -4, cborutil.SPECIAL_NONE),
   895                          (False, None, -3, cborutil.SPECIAL_NONE))
  1143         )
   896         self.assertEqual(cborutil.decodeitem(encoded[0:3]),
  1144         self.assertEqual(
   897                          (False, None, -2, cborutil.SPECIAL_NONE))
  1145             cborutil.decodeitem(encoded[0:2]),
   898         self.assertEqual(cborutil.decodeitem(encoded[0:4]),
  1146             (False, None, -3, cborutil.SPECIAL_NONE),
   899                          (False, None, -1, cborutil.SPECIAL_NONE))
  1147         )
   900         self.assertEqual(cborutil.decodeitem(encoded[0:5]),
  1148         self.assertEqual(
   901                          (True, 65536, 5, cborutil.SPECIAL_START_MAP))
  1149             cborutil.decodeitem(encoded[0:3]),
   902         self.assertEqual(cborutil.decodeitem(encoded[0:6]),
  1150             (False, None, -2, cborutil.SPECIAL_NONE),
   903                          (True, 65536, 5, cborutil.SPECIAL_START_MAP))
  1151         )
       
  1152         self.assertEqual(
       
  1153             cborutil.decodeitem(encoded[0:4]),
       
  1154             (False, None, -1, cborutil.SPECIAL_NONE),
       
  1155         )
       
  1156         self.assertEqual(
       
  1157             cborutil.decodeitem(encoded[0:5]),
       
  1158             (True, 65536, 5, cborutil.SPECIAL_START_MAP),
       
  1159         )
       
  1160         self.assertEqual(
       
  1161             cborutil.decodeitem(encoded[0:6]),
       
  1162             (True, 65536, 5, cborutil.SPECIAL_START_MAP),
       
  1163         )
       
  1164 
   904 
  1165 
   905 class SemanticTagTests(TestCase):
  1166 class SemanticTagTests(TestCase):
   906     def testdecodeforbidden(self):
  1167     def testdecodeforbidden(self):
   907         for i in range(500):
  1168         for i in range(500):
   908             if i == cborutil.SEMANTIC_TAG_FINITE_SET:
  1169             if i == cborutil.SEMANTIC_TAG_FINITE_SET:
   909                 continue
  1170                 continue
   910 
  1171 
   911             tag = cborutil.encodelength(cborutil.MAJOR_TYPE_SEMANTIC,
  1172             tag = cborutil.encodelength(cborutil.MAJOR_TYPE_SEMANTIC, i)
   912                                         i)
       
   913 
  1173 
   914             encoded = tag + cborutil.encodelength(cborutil.MAJOR_TYPE_UINT, 42)
  1174             encoded = tag + cborutil.encodelength(cborutil.MAJOR_TYPE_UINT, 42)
   915 
  1175 
   916             # Partial decode is incomplete.
  1176             # Partial decode is incomplete.
   917             if i < 24:
  1177             if i < 24:
   918                 pass
  1178                 pass
   919             elif i < 256:
  1179             elif i < 256:
   920                 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
  1180                 self.assertEqual(
   921                                  (False, None, -1, cborutil.SPECIAL_NONE))
  1181                     cborutil.decodeitem(encoded[0:1]),
       
  1182                     (False, None, -1, cborutil.SPECIAL_NONE),
       
  1183                 )
   922             elif i < 65536:
  1184             elif i < 65536:
   923                 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
  1185                 self.assertEqual(
   924                                  (False, None, -2, cborutil.SPECIAL_NONE))
  1186                     cborutil.decodeitem(encoded[0:1]),
   925                 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
  1187                     (False, None, -2, cborutil.SPECIAL_NONE),
   926                                  (False, None, -1, cborutil.SPECIAL_NONE))
  1188                 )
   927 
  1189                 self.assertEqual(
   928             with self.assertRaisesRegex(cborutil.CBORDecodeError,
  1190                     cborutil.decodeitem(encoded[0:2]),
   929                                         r'semantic tag \d+ not allowed'):
  1191                     (False, None, -1, cborutil.SPECIAL_NONE),
       
  1192                 )
       
  1193 
       
  1194             with self.assertRaisesRegex(
       
  1195                 cborutil.CBORDecodeError, r'semantic tag \d+ not allowed'
       
  1196             ):
   930                 cborutil.decodeitem(encoded)
  1197                 cborutil.decodeitem(encoded)
       
  1198 
   931 
  1199 
   932 class SpecialTypesTests(TestCase):
  1200 class SpecialTypesTests(TestCase):
   933     def testforbiddentypes(self):
  1201     def testforbiddentypes(self):
   934         for i in range(256):
  1202         for i in range(256):
   935             if i == cborutil.SUBTYPE_FALSE:
  1203             if i == cborutil.SUBTYPE_FALSE:
   939             elif i == cborutil.SUBTYPE_NULL:
  1207             elif i == cborutil.SUBTYPE_NULL:
   940                 continue
  1208                 continue
   941 
  1209 
   942             encoded = cborutil.encodelength(cborutil.MAJOR_TYPE_SPECIAL, i)
  1210             encoded = cborutil.encodelength(cborutil.MAJOR_TYPE_SPECIAL, i)
   943 
  1211 
   944             with self.assertRaisesRegex(cborutil.CBORDecodeError,
  1212             with self.assertRaisesRegex(
   945                                         r'special type \d+ not allowed'):
  1213                 cborutil.CBORDecodeError, r'special type \d+ not allowed'
       
  1214             ):
   946                 cborutil.decodeitem(encoded)
  1215                 cborutil.decodeitem(encoded)
       
  1216 
   947 
  1217 
   948 class SansIODecoderTests(TestCase):
  1218 class SansIODecoderTests(TestCase):
   949     def testemptyinput(self):
  1219     def testemptyinput(self):
   950         decoder = cborutil.sansiodecoder()
  1220         decoder = cborutil.sansiodecoder()
   951         self.assertEqual(decoder.decode(b''), (False, 0, 0))
  1221         self.assertEqual(decoder.decode(b''), (False, 0, 0))
       
  1222 
   952 
  1223 
   953 class BufferingDecoderTests(TestCase):
  1224 class BufferingDecoderTests(TestCase):
   954     def testsimple(self):
  1225     def testsimple(self):
   955         source = [
  1226         source = [
   956             b'foobar',
  1227             b'foobar',
   967         for step in range(1, 32):
  1238         for step in range(1, 32):
   968             decoder = cborutil.bufferingdecoder()
  1239             decoder = cborutil.bufferingdecoder()
   969             start = 0
  1240             start = 0
   970 
  1241 
   971             while start < len(encoded):
  1242             while start < len(encoded):
   972                 decoder.decode(encoded[start:start + step])
  1243                 decoder.decode(encoded[start : start + step])
   973                 start += step
  1244                 start += step
   974 
  1245 
   975             self.assertEqual(decoder.getavailable(), [source])
  1246             self.assertEqual(decoder.getavailable(), [source])
   976 
  1247 
   977     def testbytearray(self):
  1248     def testbytearray(self):
   979 
  1250 
   980         decoder = cborutil.bufferingdecoder()
  1251         decoder = cborutil.bufferingdecoder()
   981         decoder.decode(bytearray(source))
  1252         decoder.decode(bytearray(source))
   982 
  1253 
   983         self.assertEqual(decoder.getavailable(), [b'foobar'])
  1254         self.assertEqual(decoder.getavailable(), [b'foobar'])
       
  1255 
   984 
  1256 
   985 class DecodeallTests(TestCase):
  1257 class DecodeallTests(TestCase):
   986     def testemptyinput(self):
  1258     def testemptyinput(self):
   987         self.assertEqual(cborutil.decodeall(b''), [])
  1259         self.assertEqual(cborutil.decodeall(b''), [])
   988 
  1260 
   989     def testpartialinput(self):
  1261     def testpartialinput(self):
   990         encoded = b''.join([
  1262         encoded = b''.join(
   991             b'\x82', # array of 2 elements
  1263             [b'\x82', b'\x01',]  # array of 2 elements  # integer 1
   992             b'\x01', # integer 1
  1264         )
   993         ])
  1265 
   994 
  1266         with self.assertRaisesRegex(
   995         with self.assertRaisesRegex(cborutil.CBORDecodeError,
  1267             cborutil.CBORDecodeError, 'input data not complete'
   996                                     'input data not complete'):
  1268         ):
   997             cborutil.decodeall(encoded)
  1269             cborutil.decodeall(encoded)
       
  1270 
   998 
  1271 
   999 if __name__ == '__main__':
  1272 if __name__ == '__main__':
  1000     import silenttestrunner
  1273     import silenttestrunner
       
  1274 
  1001     silenttestrunner.main(__name__)
  1275     silenttestrunner.main(__name__)