contrib/python-zstandard/tests/test_compressor.py
changeset 30435 b86a448a2965
child 30822 b54a2984cdd4
equal deleted inserted replaced
30434:2e484bdea8c4 30435:b86a448a2965
       
     1 import hashlib
       
     2 import io
       
     3 import struct
       
     4 import sys
       
     5 
       
     6 try:
       
     7     import unittest2 as unittest
       
     8 except ImportError:
       
     9     import unittest
       
    10 
       
    11 import zstd
       
    12 
       
    13 from .common import OpCountingBytesIO
       
    14 
       
    15 
       
    16 if sys.version_info[0] >= 3:
       
    17     next = lambda it: it.__next__()
       
    18 else:
       
    19     next = lambda it: it.next()
       
    20 
       
    21 
       
    22 class TestCompressor(unittest.TestCase):
       
    23     def test_level_bounds(self):
       
    24         with self.assertRaises(ValueError):
       
    25             zstd.ZstdCompressor(level=0)
       
    26 
       
    27         with self.assertRaises(ValueError):
       
    28             zstd.ZstdCompressor(level=23)
       
    29 
       
    30 
       
    31 class TestCompressor_compress(unittest.TestCase):
       
    32     def test_compress_empty(self):
       
    33         cctx = zstd.ZstdCompressor(level=1)
       
    34         cctx.compress(b'')
       
    35 
       
    36         cctx = zstd.ZstdCompressor(level=22)
       
    37         cctx.compress(b'')
       
    38 
       
    39     def test_compress_empty(self):
       
    40         cctx = zstd.ZstdCompressor(level=1)
       
    41         self.assertEqual(cctx.compress(b''),
       
    42                          b'\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00')
       
    43 
       
    44     def test_compress_large(self):
       
    45         chunks = []
       
    46         for i in range(255):
       
    47             chunks.append(struct.Struct('>B').pack(i) * 16384)
       
    48 
       
    49         cctx = zstd.ZstdCompressor(level=3)
       
    50         result = cctx.compress(b''.join(chunks))
       
    51         self.assertEqual(len(result), 999)
       
    52         self.assertEqual(result[0:4], b'\x28\xb5\x2f\xfd')
       
    53 
       
    54     def test_write_checksum(self):
       
    55         cctx = zstd.ZstdCompressor(level=1)
       
    56         no_checksum = cctx.compress(b'foobar')
       
    57         cctx = zstd.ZstdCompressor(level=1, write_checksum=True)
       
    58         with_checksum = cctx.compress(b'foobar')
       
    59 
       
    60         self.assertEqual(len(with_checksum), len(no_checksum) + 4)
       
    61 
       
    62     def test_write_content_size(self):
       
    63         cctx = zstd.ZstdCompressor(level=1)
       
    64         no_size = cctx.compress(b'foobar' * 256)
       
    65         cctx = zstd.ZstdCompressor(level=1, write_content_size=True)
       
    66         with_size = cctx.compress(b'foobar' * 256)
       
    67 
       
    68         self.assertEqual(len(with_size), len(no_size) + 1)
       
    69 
       
    70     def test_no_dict_id(self):
       
    71         samples = []
       
    72         for i in range(128):
       
    73             samples.append(b'foo' * 64)
       
    74             samples.append(b'bar' * 64)
       
    75             samples.append(b'foobar' * 64)
       
    76 
       
    77         d = zstd.train_dictionary(1024, samples)
       
    78 
       
    79         cctx = zstd.ZstdCompressor(level=1, dict_data=d)
       
    80         with_dict_id = cctx.compress(b'foobarfoobar')
       
    81 
       
    82         cctx = zstd.ZstdCompressor(level=1, dict_data=d, write_dict_id=False)
       
    83         no_dict_id = cctx.compress(b'foobarfoobar')
       
    84 
       
    85         self.assertEqual(len(with_dict_id), len(no_dict_id) + 4)
       
    86 
       
    87     def test_compress_dict_multiple(self):
       
    88         samples = []
       
    89         for i in range(128):
       
    90             samples.append(b'foo' * 64)
       
    91             samples.append(b'bar' * 64)
       
    92             samples.append(b'foobar' * 64)
       
    93 
       
    94         d = zstd.train_dictionary(8192, samples)
       
    95 
       
    96         cctx = zstd.ZstdCompressor(level=1, dict_data=d)
       
    97 
       
    98         for i in range(32):
       
    99             cctx.compress(b'foo bar foobar foo bar foobar')
       
   100 
       
   101 
       
   102 class TestCompressor_compressobj(unittest.TestCase):
       
   103     def test_compressobj_empty(self):
       
   104         cctx = zstd.ZstdCompressor(level=1)
       
   105         cobj = cctx.compressobj()
       
   106         self.assertEqual(cobj.compress(b''), b'')
       
   107         self.assertEqual(cobj.flush(),
       
   108                          b'\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00')
       
   109 
       
   110     def test_compressobj_large(self):
       
   111         chunks = []
       
   112         for i in range(255):
       
   113             chunks.append(struct.Struct('>B').pack(i) * 16384)
       
   114 
       
   115         cctx = zstd.ZstdCompressor(level=3)
       
   116         cobj = cctx.compressobj()
       
   117 
       
   118         result = cobj.compress(b''.join(chunks)) + cobj.flush()
       
   119         self.assertEqual(len(result), 999)
       
   120         self.assertEqual(result[0:4], b'\x28\xb5\x2f\xfd')
       
   121 
       
   122     def test_write_checksum(self):
       
   123         cctx = zstd.ZstdCompressor(level=1)
       
   124         cobj = cctx.compressobj()
       
   125         no_checksum = cobj.compress(b'foobar') + cobj.flush()
       
   126         cctx = zstd.ZstdCompressor(level=1, write_checksum=True)
       
   127         cobj = cctx.compressobj()
       
   128         with_checksum = cobj.compress(b'foobar') + cobj.flush()
       
   129 
       
   130         self.assertEqual(len(with_checksum), len(no_checksum) + 4)
       
   131 
       
   132     def test_write_content_size(self):
       
   133         cctx = zstd.ZstdCompressor(level=1)
       
   134         cobj = cctx.compressobj(size=len(b'foobar' * 256))
       
   135         no_size = cobj.compress(b'foobar' * 256) + cobj.flush()
       
   136         cctx = zstd.ZstdCompressor(level=1, write_content_size=True)
       
   137         cobj = cctx.compressobj(size=len(b'foobar' * 256))
       
   138         with_size = cobj.compress(b'foobar' * 256) + cobj.flush()
       
   139 
       
   140         self.assertEqual(len(with_size), len(no_size) + 1)
       
   141 
       
   142     def test_compress_after_flush(self):
       
   143         cctx = zstd.ZstdCompressor()
       
   144         cobj = cctx.compressobj()
       
   145 
       
   146         cobj.compress(b'foo')
       
   147         cobj.flush()
       
   148 
       
   149         with self.assertRaisesRegexp(zstd.ZstdError, 'cannot call compress\(\) after flush'):
       
   150             cobj.compress(b'foo')
       
   151 
       
   152         with self.assertRaisesRegexp(zstd.ZstdError, 'flush\(\) already called'):
       
   153             cobj.flush()
       
   154 
       
   155 
       
   156 class TestCompressor_copy_stream(unittest.TestCase):
       
   157     def test_no_read(self):
       
   158         source = object()
       
   159         dest = io.BytesIO()
       
   160 
       
   161         cctx = zstd.ZstdCompressor()
       
   162         with self.assertRaises(ValueError):
       
   163             cctx.copy_stream(source, dest)
       
   164 
       
   165     def test_no_write(self):
       
   166         source = io.BytesIO()
       
   167         dest = object()
       
   168 
       
   169         cctx = zstd.ZstdCompressor()
       
   170         with self.assertRaises(ValueError):
       
   171             cctx.copy_stream(source, dest)
       
   172 
       
   173     def test_empty(self):
       
   174         source = io.BytesIO()
       
   175         dest = io.BytesIO()
       
   176 
       
   177         cctx = zstd.ZstdCompressor(level=1)
       
   178         r, w = cctx.copy_stream(source, dest)
       
   179         self.assertEqual(int(r), 0)
       
   180         self.assertEqual(w, 9)
       
   181 
       
   182         self.assertEqual(dest.getvalue(),
       
   183                          b'\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00')
       
   184 
       
   185     def test_large_data(self):
       
   186         source = io.BytesIO()
       
   187         for i in range(255):
       
   188             source.write(struct.Struct('>B').pack(i) * 16384)
       
   189         source.seek(0)
       
   190 
       
   191         dest = io.BytesIO()
       
   192         cctx = zstd.ZstdCompressor()
       
   193         r, w = cctx.copy_stream(source, dest)
       
   194 
       
   195         self.assertEqual(r, 255 * 16384)
       
   196         self.assertEqual(w, 999)
       
   197 
       
   198     def test_write_checksum(self):
       
   199         source = io.BytesIO(b'foobar')
       
   200         no_checksum = io.BytesIO()
       
   201 
       
   202         cctx = zstd.ZstdCompressor(level=1)
       
   203         cctx.copy_stream(source, no_checksum)
       
   204 
       
   205         source.seek(0)
       
   206         with_checksum = io.BytesIO()
       
   207         cctx = zstd.ZstdCompressor(level=1, write_checksum=True)
       
   208         cctx.copy_stream(source, with_checksum)
       
   209 
       
   210         self.assertEqual(len(with_checksum.getvalue()),
       
   211                          len(no_checksum.getvalue()) + 4)
       
   212 
       
   213     def test_write_content_size(self):
       
   214         source = io.BytesIO(b'foobar' * 256)
       
   215         no_size = io.BytesIO()
       
   216 
       
   217         cctx = zstd.ZstdCompressor(level=1)
       
   218         cctx.copy_stream(source, no_size)
       
   219 
       
   220         source.seek(0)
       
   221         with_size = io.BytesIO()
       
   222         cctx = zstd.ZstdCompressor(level=1, write_content_size=True)
       
   223         cctx.copy_stream(source, with_size)
       
   224 
       
   225         # Source content size is unknown, so no content size written.
       
   226         self.assertEqual(len(with_size.getvalue()),
       
   227                          len(no_size.getvalue()))
       
   228 
       
   229         source.seek(0)
       
   230         with_size = io.BytesIO()
       
   231         cctx.copy_stream(source, with_size, size=len(source.getvalue()))
       
   232 
       
   233         # We specified source size, so content size header is present.
       
   234         self.assertEqual(len(with_size.getvalue()),
       
   235                          len(no_size.getvalue()) + 1)
       
   236 
       
   237     def test_read_write_size(self):
       
   238         source = OpCountingBytesIO(b'foobarfoobar')
       
   239         dest = OpCountingBytesIO()
       
   240         cctx = zstd.ZstdCompressor()
       
   241         r, w = cctx.copy_stream(source, dest, read_size=1, write_size=1)
       
   242 
       
   243         self.assertEqual(r, len(source.getvalue()))
       
   244         self.assertEqual(w, 21)
       
   245         self.assertEqual(source._read_count, len(source.getvalue()) + 1)
       
   246         self.assertEqual(dest._write_count, len(dest.getvalue()))
       
   247 
       
   248 
       
   249 def compress(data, level):
       
   250     buffer = io.BytesIO()
       
   251     cctx = zstd.ZstdCompressor(level=level)
       
   252     with cctx.write_to(buffer) as compressor:
       
   253         compressor.write(data)
       
   254     return buffer.getvalue()
       
   255 
       
   256 
       
   257 class TestCompressor_write_to(unittest.TestCase):
       
   258     def test_empty(self):
       
   259         self.assertEqual(compress(b'', 1),
       
   260                          b'\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00')
       
   261 
       
   262     def test_multiple_compress(self):
       
   263         buffer = io.BytesIO()
       
   264         cctx = zstd.ZstdCompressor(level=5)
       
   265         with cctx.write_to(buffer) as compressor:
       
   266             compressor.write(b'foo')
       
   267             compressor.write(b'bar')
       
   268             compressor.write(b'x' * 8192)
       
   269 
       
   270         result = buffer.getvalue()
       
   271         self.assertEqual(result,
       
   272                          b'\x28\xb5\x2f\xfd\x00\x50\x75\x00\x00\x38\x66\x6f'
       
   273                          b'\x6f\x62\x61\x72\x78\x01\x00\xfc\xdf\x03\x23')
       
   274 
       
   275     def test_dictionary(self):
       
   276         samples = []
       
   277         for i in range(128):
       
   278             samples.append(b'foo' * 64)
       
   279             samples.append(b'bar' * 64)
       
   280             samples.append(b'foobar' * 64)
       
   281 
       
   282         d = zstd.train_dictionary(8192, samples)
       
   283 
       
   284         buffer = io.BytesIO()
       
   285         cctx = zstd.ZstdCompressor(level=9, dict_data=d)
       
   286         with cctx.write_to(buffer) as compressor:
       
   287             compressor.write(b'foo')
       
   288             compressor.write(b'bar')
       
   289             compressor.write(b'foo' * 16384)
       
   290 
       
   291         compressed = buffer.getvalue()
       
   292         h = hashlib.sha1(compressed).hexdigest()
       
   293         self.assertEqual(h, '1c5bcd25181bcd8c1a73ea8773323e0056129f92')
       
   294 
       
   295     def test_compression_params(self):
       
   296         params = zstd.CompressionParameters(20, 6, 12, 5, 4, 10, zstd.STRATEGY_FAST)
       
   297 
       
   298         buffer = io.BytesIO()
       
   299         cctx = zstd.ZstdCompressor(compression_params=params)
       
   300         with cctx.write_to(buffer) as compressor:
       
   301             compressor.write(b'foo')
       
   302             compressor.write(b'bar')
       
   303             compressor.write(b'foobar' * 16384)
       
   304 
       
   305         compressed = buffer.getvalue()
       
   306         h = hashlib.sha1(compressed).hexdigest()
       
   307         self.assertEqual(h, '1ae31f270ed7de14235221a604b31ecd517ebd99')
       
   308 
       
   309     def test_write_checksum(self):
       
   310         no_checksum = io.BytesIO()
       
   311         cctx = zstd.ZstdCompressor(level=1)
       
   312         with cctx.write_to(no_checksum) as compressor:
       
   313             compressor.write(b'foobar')
       
   314 
       
   315         with_checksum = io.BytesIO()
       
   316         cctx = zstd.ZstdCompressor(level=1, write_checksum=True)
       
   317         with cctx.write_to(with_checksum) as compressor:
       
   318             compressor.write(b'foobar')
       
   319 
       
   320         self.assertEqual(len(with_checksum.getvalue()),
       
   321                          len(no_checksum.getvalue()) + 4)
       
   322 
       
   323     def test_write_content_size(self):
       
   324         no_size = io.BytesIO()
       
   325         cctx = zstd.ZstdCompressor(level=1)
       
   326         with cctx.write_to(no_size) as compressor:
       
   327             compressor.write(b'foobar' * 256)
       
   328 
       
   329         with_size = io.BytesIO()
       
   330         cctx = zstd.ZstdCompressor(level=1, write_content_size=True)
       
   331         with cctx.write_to(with_size) as compressor:
       
   332             compressor.write(b'foobar' * 256)
       
   333 
       
   334         # Source size is not known in streaming mode, so header not
       
   335         # written.
       
   336         self.assertEqual(len(with_size.getvalue()),
       
   337                          len(no_size.getvalue()))
       
   338 
       
   339         # Declaring size will write the header.
       
   340         with_size = io.BytesIO()
       
   341         with cctx.write_to(with_size, size=len(b'foobar' * 256)) as compressor:
       
   342             compressor.write(b'foobar' * 256)
       
   343 
       
   344         self.assertEqual(len(with_size.getvalue()),
       
   345                          len(no_size.getvalue()) + 1)
       
   346 
       
   347     def test_no_dict_id(self):
       
   348         samples = []
       
   349         for i in range(128):
       
   350             samples.append(b'foo' * 64)
       
   351             samples.append(b'bar' * 64)
       
   352             samples.append(b'foobar' * 64)
       
   353 
       
   354         d = zstd.train_dictionary(1024, samples)
       
   355 
       
   356         with_dict_id = io.BytesIO()
       
   357         cctx = zstd.ZstdCompressor(level=1, dict_data=d)
       
   358         with cctx.write_to(with_dict_id) as compressor:
       
   359             compressor.write(b'foobarfoobar')
       
   360 
       
   361         cctx = zstd.ZstdCompressor(level=1, dict_data=d, write_dict_id=False)
       
   362         no_dict_id = io.BytesIO()
       
   363         with cctx.write_to(no_dict_id) as compressor:
       
   364             compressor.write(b'foobarfoobar')
       
   365 
       
   366         self.assertEqual(len(with_dict_id.getvalue()),
       
   367                          len(no_dict_id.getvalue()) + 4)
       
   368 
       
   369     def test_memory_size(self):
       
   370         cctx = zstd.ZstdCompressor(level=3)
       
   371         buffer = io.BytesIO()
       
   372         with cctx.write_to(buffer) as compressor:
       
   373             size = compressor.memory_size()
       
   374 
       
   375         self.assertGreater(size, 100000)
       
   376 
       
   377     def test_write_size(self):
       
   378         cctx = zstd.ZstdCompressor(level=3)
       
   379         dest = OpCountingBytesIO()
       
   380         with cctx.write_to(dest, write_size=1) as compressor:
       
   381             compressor.write(b'foo')
       
   382             compressor.write(b'bar')
       
   383             compressor.write(b'foobar')
       
   384 
       
   385         self.assertEqual(len(dest.getvalue()), dest._write_count)
       
   386 
       
   387 
       
   388 class TestCompressor_read_from(unittest.TestCase):
       
   389     def test_type_validation(self):
       
   390         cctx = zstd.ZstdCompressor()
       
   391 
       
   392         # Object with read() works.
       
   393         cctx.read_from(io.BytesIO())
       
   394 
       
   395         # Buffer protocol works.
       
   396         cctx.read_from(b'foobar')
       
   397 
       
   398         with self.assertRaisesRegexp(ValueError, 'must pass an object with a read'):
       
   399             cctx.read_from(True)
       
   400 
       
   401     def test_read_empty(self):
       
   402         cctx = zstd.ZstdCompressor(level=1)
       
   403 
       
   404         source = io.BytesIO()
       
   405         it = cctx.read_from(source)
       
   406         chunks = list(it)
       
   407         self.assertEqual(len(chunks), 1)
       
   408         compressed = b''.join(chunks)
       
   409         self.assertEqual(compressed, b'\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00')
       
   410 
       
   411         # And again with the buffer protocol.
       
   412         it = cctx.read_from(b'')
       
   413         chunks = list(it)
       
   414         self.assertEqual(len(chunks), 1)
       
   415         compressed2 = b''.join(chunks)
       
   416         self.assertEqual(compressed2, compressed)
       
   417 
       
   418     def test_read_large(self):
       
   419         cctx = zstd.ZstdCompressor(level=1)
       
   420 
       
   421         source = io.BytesIO()
       
   422         source.write(b'f' * zstd.COMPRESSION_RECOMMENDED_INPUT_SIZE)
       
   423         source.write(b'o')
       
   424         source.seek(0)
       
   425 
       
   426         # Creating an iterator should not perform any compression until
       
   427         # first read.
       
   428         it = cctx.read_from(source, size=len(source.getvalue()))
       
   429         self.assertEqual(source.tell(), 0)
       
   430 
       
   431         # We should have exactly 2 output chunks.
       
   432         chunks = []
       
   433         chunk = next(it)
       
   434         self.assertIsNotNone(chunk)
       
   435         self.assertEqual(source.tell(), zstd.COMPRESSION_RECOMMENDED_INPUT_SIZE)
       
   436         chunks.append(chunk)
       
   437         chunk = next(it)
       
   438         self.assertIsNotNone(chunk)
       
   439         chunks.append(chunk)
       
   440 
       
   441         self.assertEqual(source.tell(), len(source.getvalue()))
       
   442 
       
   443         with self.assertRaises(StopIteration):
       
   444             next(it)
       
   445 
       
   446         # And again for good measure.
       
   447         with self.assertRaises(StopIteration):
       
   448             next(it)
       
   449 
       
   450         # We should get the same output as the one-shot compression mechanism.
       
   451         self.assertEqual(b''.join(chunks), cctx.compress(source.getvalue()))
       
   452 
       
   453         # Now check the buffer protocol.
       
   454         it = cctx.read_from(source.getvalue())
       
   455         chunks = list(it)
       
   456         self.assertEqual(len(chunks), 2)
       
   457         self.assertEqual(b''.join(chunks), cctx.compress(source.getvalue()))
       
   458 
       
   459     def test_read_write_size(self):
       
   460         source = OpCountingBytesIO(b'foobarfoobar')
       
   461         cctx = zstd.ZstdCompressor(level=3)
       
   462         for chunk in cctx.read_from(source, read_size=1, write_size=1):
       
   463             self.assertEqual(len(chunk), 1)
       
   464 
       
   465         self.assertEqual(source._read_count, len(source.getvalue()) + 1)