contrib/python-zstandard/tests/test_compressor.py
changeset 30895 c32454d69b85
parent 30822 b54a2984cdd4
child 31796 e0dc40530c5a
equal deleted inserted replaced
30894:5b60464efbde 30895:c32454d69b85
     8 except ImportError:
     8 except ImportError:
     9     import unittest
     9     import unittest
    10 
    10 
    11 import zstd
    11 import zstd
    12 
    12 
    13 from .common import OpCountingBytesIO
    13 from .common import (
       
    14     make_cffi,
       
    15     OpCountingBytesIO,
       
    16 )
    14 
    17 
    15 
    18 
    16 if sys.version_info[0] >= 3:
    19 if sys.version_info[0] >= 3:
    17     next = lambda it: it.__next__()
    20     next = lambda it: it.__next__()
    18 else:
    21 else:
    19     next = lambda it: it.next()
    22     next = lambda it: it.next()
    20 
    23 
    21 
    24 
       
    25 @make_cffi
    22 class TestCompressor(unittest.TestCase):
    26 class TestCompressor(unittest.TestCase):
    23     def test_level_bounds(self):
    27     def test_level_bounds(self):
    24         with self.assertRaises(ValueError):
    28         with self.assertRaises(ValueError):
    25             zstd.ZstdCompressor(level=0)
    29             zstd.ZstdCompressor(level=0)
    26 
    30 
    27         with self.assertRaises(ValueError):
    31         with self.assertRaises(ValueError):
    28             zstd.ZstdCompressor(level=23)
    32             zstd.ZstdCompressor(level=23)
    29 
    33 
    30 
    34 
       
    35 @make_cffi
    31 class TestCompressor_compress(unittest.TestCase):
    36 class TestCompressor_compress(unittest.TestCase):
    32     def test_compress_empty(self):
    37     def test_compress_empty(self):
    33         cctx = zstd.ZstdCompressor(level=1)
    38         cctx = zstd.ZstdCompressor(level=1)
    34         cctx.compress(b'')
    39         result = cctx.compress(b'')
    35 
    40         self.assertEqual(result, b'\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00')
    36         cctx = zstd.ZstdCompressor(level=22)
    41         params = zstd.get_frame_parameters(result)
    37         cctx.compress(b'')
    42         self.assertEqual(params.content_size, 0)
    38 
    43         self.assertEqual(params.window_size, 524288)
    39     def test_compress_empty(self):
    44         self.assertEqual(params.dict_id, 0)
    40         cctx = zstd.ZstdCompressor(level=1)
    45         self.assertFalse(params.has_checksum, 0)
    41         self.assertEqual(cctx.compress(b''),
       
    42                          b'\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00')
       
    43 
    46 
    44         # TODO should be temporary until https://github.com/facebook/zstd/issues/506
    47         # TODO should be temporary until https://github.com/facebook/zstd/issues/506
    45         # is fixed.
    48         # is fixed.
    46         cctx = zstd.ZstdCompressor(write_content_size=True)
    49         cctx = zstd.ZstdCompressor(write_content_size=True)
    47         with self.assertRaises(ValueError):
    50         with self.assertRaises(ValueError):
    57         cctx = zstd.ZstdCompressor(level=3)
    60         cctx = zstd.ZstdCompressor(level=3)
    58         result = cctx.compress(b''.join(chunks))
    61         result = cctx.compress(b''.join(chunks))
    59         self.assertEqual(len(result), 999)
    62         self.assertEqual(len(result), 999)
    60         self.assertEqual(result[0:4], b'\x28\xb5\x2f\xfd')
    63         self.assertEqual(result[0:4], b'\x28\xb5\x2f\xfd')
    61 
    64 
       
    65         # This matches the test for read_from() below.
       
    66         cctx = zstd.ZstdCompressor(level=1)
       
    67         result = cctx.compress(b'f' * zstd.COMPRESSION_RECOMMENDED_INPUT_SIZE + b'o')
       
    68         self.assertEqual(result, b'\x28\xb5\x2f\xfd\x00\x40\x54\x00\x00'
       
    69                                  b'\x10\x66\x66\x01\x00\xfb\xff\x39\xc0'
       
    70                                  b'\x02\x09\x00\x00\x6f')
       
    71 
    62     def test_write_checksum(self):
    72     def test_write_checksum(self):
    63         cctx = zstd.ZstdCompressor(level=1)
    73         cctx = zstd.ZstdCompressor(level=1)
    64         no_checksum = cctx.compress(b'foobar')
    74         no_checksum = cctx.compress(b'foobar')
    65         cctx = zstd.ZstdCompressor(level=1, write_checksum=True)
    75         cctx = zstd.ZstdCompressor(level=1, write_checksum=True)
    66         with_checksum = cctx.compress(b'foobar')
    76         with_checksum = cctx.compress(b'foobar')
    67 
    77 
    68         self.assertEqual(len(with_checksum), len(no_checksum) + 4)
    78         self.assertEqual(len(with_checksum), len(no_checksum) + 4)
    69 
    79 
       
    80         no_params = zstd.get_frame_parameters(no_checksum)
       
    81         with_params = zstd.get_frame_parameters(with_checksum)
       
    82 
       
    83         self.assertFalse(no_params.has_checksum)
       
    84         self.assertTrue(with_params.has_checksum)
       
    85 
    70     def test_write_content_size(self):
    86     def test_write_content_size(self):
    71         cctx = zstd.ZstdCompressor(level=1)
    87         cctx = zstd.ZstdCompressor(level=1)
    72         no_size = cctx.compress(b'foobar' * 256)
    88         no_size = cctx.compress(b'foobar' * 256)
    73         cctx = zstd.ZstdCompressor(level=1, write_content_size=True)
    89         cctx = zstd.ZstdCompressor(level=1, write_content_size=True)
    74         with_size = cctx.compress(b'foobar' * 256)
    90         with_size = cctx.compress(b'foobar' * 256)
    75 
    91 
    76         self.assertEqual(len(with_size), len(no_size) + 1)
    92         self.assertEqual(len(with_size), len(no_size) + 1)
       
    93 
       
    94         no_params = zstd.get_frame_parameters(no_size)
       
    95         with_params = zstd.get_frame_parameters(with_size)
       
    96         self.assertEqual(no_params.content_size, 0)
       
    97         self.assertEqual(with_params.content_size, 1536)
    77 
    98 
    78     def test_no_dict_id(self):
    99     def test_no_dict_id(self):
    79         samples = []
   100         samples = []
    80         for i in range(128):
   101         for i in range(128):
    81             samples.append(b'foo' * 64)
   102             samples.append(b'foo' * 64)
    90         cctx = zstd.ZstdCompressor(level=1, dict_data=d, write_dict_id=False)
   111         cctx = zstd.ZstdCompressor(level=1, dict_data=d, write_dict_id=False)
    91         no_dict_id = cctx.compress(b'foobarfoobar')
   112         no_dict_id = cctx.compress(b'foobarfoobar')
    92 
   113 
    93         self.assertEqual(len(with_dict_id), len(no_dict_id) + 4)
   114         self.assertEqual(len(with_dict_id), len(no_dict_id) + 4)
    94 
   115 
       
   116         no_params = zstd.get_frame_parameters(no_dict_id)
       
   117         with_params = zstd.get_frame_parameters(with_dict_id)
       
   118         self.assertEqual(no_params.dict_id, 0)
       
   119         self.assertEqual(with_params.dict_id, 1584102229)
       
   120 
    95     def test_compress_dict_multiple(self):
   121     def test_compress_dict_multiple(self):
    96         samples = []
   122         samples = []
    97         for i in range(128):
   123         for i in range(128):
    98             samples.append(b'foo' * 64)
   124             samples.append(b'foo' * 64)
    99             samples.append(b'bar' * 64)
   125             samples.append(b'bar' * 64)
   105 
   131 
   106         for i in range(32):
   132         for i in range(32):
   107             cctx.compress(b'foo bar foobar foo bar foobar')
   133             cctx.compress(b'foo bar foobar foo bar foobar')
   108 
   134 
   109 
   135 
       
   136 @make_cffi
   110 class TestCompressor_compressobj(unittest.TestCase):
   137 class TestCompressor_compressobj(unittest.TestCase):
   111     def test_compressobj_empty(self):
   138     def test_compressobj_empty(self):
   112         cctx = zstd.ZstdCompressor(level=1)
   139         cctx = zstd.ZstdCompressor(level=1)
   113         cobj = cctx.compressobj()
   140         cobj = cctx.compressobj()
   114         self.assertEqual(cobj.compress(b''), b'')
   141         self.assertEqual(cobj.compress(b''), b'')
   124         cobj = cctx.compressobj()
   151         cobj = cctx.compressobj()
   125 
   152 
   126         result = cobj.compress(b''.join(chunks)) + cobj.flush()
   153         result = cobj.compress(b''.join(chunks)) + cobj.flush()
   127         self.assertEqual(len(result), 999)
   154         self.assertEqual(len(result), 999)
   128         self.assertEqual(result[0:4], b'\x28\xb5\x2f\xfd')
   155         self.assertEqual(result[0:4], b'\x28\xb5\x2f\xfd')
       
   156 
       
   157         params = zstd.get_frame_parameters(result)
       
   158         self.assertEqual(params.content_size, 0)
       
   159         self.assertEqual(params.window_size, 1048576)
       
   160         self.assertEqual(params.dict_id, 0)
       
   161         self.assertFalse(params.has_checksum)
   129 
   162 
   130     def test_write_checksum(self):
   163     def test_write_checksum(self):
   131         cctx = zstd.ZstdCompressor(level=1)
   164         cctx = zstd.ZstdCompressor(level=1)
   132         cobj = cctx.compressobj()
   165         cobj = cctx.compressobj()
   133         no_checksum = cobj.compress(b'foobar') + cobj.flush()
   166         no_checksum = cobj.compress(b'foobar') + cobj.flush()
   134         cctx = zstd.ZstdCompressor(level=1, write_checksum=True)
   167         cctx = zstd.ZstdCompressor(level=1, write_checksum=True)
   135         cobj = cctx.compressobj()
   168         cobj = cctx.compressobj()
   136         with_checksum = cobj.compress(b'foobar') + cobj.flush()
   169         with_checksum = cobj.compress(b'foobar') + cobj.flush()
   137 
   170 
       
   171         no_params = zstd.get_frame_parameters(no_checksum)
       
   172         with_params = zstd.get_frame_parameters(with_checksum)
       
   173         self.assertEqual(no_params.content_size, 0)
       
   174         self.assertEqual(with_params.content_size, 0)
       
   175         self.assertEqual(no_params.dict_id, 0)
       
   176         self.assertEqual(with_params.dict_id, 0)
       
   177         self.assertFalse(no_params.has_checksum)
       
   178         self.assertTrue(with_params.has_checksum)
       
   179 
   138         self.assertEqual(len(with_checksum), len(no_checksum) + 4)
   180         self.assertEqual(len(with_checksum), len(no_checksum) + 4)
   139 
   181 
   140     def test_write_content_size(self):
   182     def test_write_content_size(self):
   141         cctx = zstd.ZstdCompressor(level=1)
   183         cctx = zstd.ZstdCompressor(level=1)
   142         cobj = cctx.compressobj(size=len(b'foobar' * 256))
   184         cobj = cctx.compressobj(size=len(b'foobar' * 256))
   143         no_size = cobj.compress(b'foobar' * 256) + cobj.flush()
   185         no_size = cobj.compress(b'foobar' * 256) + cobj.flush()
   144         cctx = zstd.ZstdCompressor(level=1, write_content_size=True)
   186         cctx = zstd.ZstdCompressor(level=1, write_content_size=True)
   145         cobj = cctx.compressobj(size=len(b'foobar' * 256))
   187         cobj = cctx.compressobj(size=len(b'foobar' * 256))
   146         with_size = cobj.compress(b'foobar' * 256) + cobj.flush()
   188         with_size = cobj.compress(b'foobar' * 256) + cobj.flush()
       
   189 
       
   190         no_params = zstd.get_frame_parameters(no_size)
       
   191         with_params = zstd.get_frame_parameters(with_size)
       
   192         self.assertEqual(no_params.content_size, 0)
       
   193         self.assertEqual(with_params.content_size, 1536)
       
   194         self.assertEqual(no_params.dict_id, 0)
       
   195         self.assertEqual(with_params.dict_id, 0)
       
   196         self.assertFalse(no_params.has_checksum)
       
   197         self.assertFalse(with_params.has_checksum)
   147 
   198 
   148         self.assertEqual(len(with_size), len(no_size) + 1)
   199         self.assertEqual(len(with_size), len(no_size) + 1)
   149 
   200 
   150     def test_compress_after_finished(self):
   201     def test_compress_after_finished(self):
   151         cctx = zstd.ZstdCompressor()
   202         cctx = zstd.ZstdCompressor()
   185         self.assertEqual(len(trailing), 7)
   236         self.assertEqual(len(trailing), 7)
   186         header = trailing[0:3]
   237         header = trailing[0:3]
   187         self.assertEqual(header, b'\x01\x00\x00')
   238         self.assertEqual(header, b'\x01\x00\x00')
   188 
   239 
   189 
   240 
       
   241 @make_cffi
   190 class TestCompressor_copy_stream(unittest.TestCase):
   242 class TestCompressor_copy_stream(unittest.TestCase):
   191     def test_no_read(self):
   243     def test_no_read(self):
   192         source = object()
   244         source = object()
   193         dest = io.BytesIO()
   245         dest = io.BytesIO()
   194 
   246 
   227         r, w = cctx.copy_stream(source, dest)
   279         r, w = cctx.copy_stream(source, dest)
   228 
   280 
   229         self.assertEqual(r, 255 * 16384)
   281         self.assertEqual(r, 255 * 16384)
   230         self.assertEqual(w, 999)
   282         self.assertEqual(w, 999)
   231 
   283 
       
   284         params = zstd.get_frame_parameters(dest.getvalue())
       
   285         self.assertEqual(params.content_size, 0)
       
   286         self.assertEqual(params.window_size, 1048576)
       
   287         self.assertEqual(params.dict_id, 0)
       
   288         self.assertFalse(params.has_checksum)
       
   289 
   232     def test_write_checksum(self):
   290     def test_write_checksum(self):
   233         source = io.BytesIO(b'foobar')
   291         source = io.BytesIO(b'foobar')
   234         no_checksum = io.BytesIO()
   292         no_checksum = io.BytesIO()
   235 
   293 
   236         cctx = zstd.ZstdCompressor(level=1)
   294         cctx = zstd.ZstdCompressor(level=1)
   242         cctx.copy_stream(source, with_checksum)
   300         cctx.copy_stream(source, with_checksum)
   243 
   301 
   244         self.assertEqual(len(with_checksum.getvalue()),
   302         self.assertEqual(len(with_checksum.getvalue()),
   245                          len(no_checksum.getvalue()) + 4)
   303                          len(no_checksum.getvalue()) + 4)
   246 
   304 
       
   305         no_params = zstd.get_frame_parameters(no_checksum.getvalue())
       
   306         with_params = zstd.get_frame_parameters(with_checksum.getvalue())
       
   307         self.assertEqual(no_params.content_size, 0)
       
   308         self.assertEqual(with_params.content_size, 0)
       
   309         self.assertEqual(no_params.dict_id, 0)
       
   310         self.assertEqual(with_params.dict_id, 0)
       
   311         self.assertFalse(no_params.has_checksum)
       
   312         self.assertTrue(with_params.has_checksum)
       
   313 
   247     def test_write_content_size(self):
   314     def test_write_content_size(self):
   248         source = io.BytesIO(b'foobar' * 256)
   315         source = io.BytesIO(b'foobar' * 256)
   249         no_size = io.BytesIO()
   316         no_size = io.BytesIO()
   250 
   317 
   251         cctx = zstd.ZstdCompressor(level=1)
   318         cctx = zstd.ZstdCompressor(level=1)
   265         cctx.copy_stream(source, with_size, size=len(source.getvalue()))
   332         cctx.copy_stream(source, with_size, size=len(source.getvalue()))
   266 
   333 
   267         # We specified source size, so content size header is present.
   334         # We specified source size, so content size header is present.
   268         self.assertEqual(len(with_size.getvalue()),
   335         self.assertEqual(len(with_size.getvalue()),
   269                          len(no_size.getvalue()) + 1)
   336                          len(no_size.getvalue()) + 1)
       
   337 
       
   338         no_params = zstd.get_frame_parameters(no_size.getvalue())
       
   339         with_params = zstd.get_frame_parameters(with_size.getvalue())
       
   340         self.assertEqual(no_params.content_size, 0)
       
   341         self.assertEqual(with_params.content_size, 1536)
       
   342         self.assertEqual(no_params.dict_id, 0)
       
   343         self.assertEqual(with_params.dict_id, 0)
       
   344         self.assertFalse(no_params.has_checksum)
       
   345         self.assertFalse(with_params.has_checksum)
   270 
   346 
   271     def test_read_write_size(self):
   347     def test_read_write_size(self):
   272         source = OpCountingBytesIO(b'foobarfoobar')
   348         source = OpCountingBytesIO(b'foobarfoobar')
   273         dest = OpCountingBytesIO()
   349         dest = OpCountingBytesIO()
   274         cctx = zstd.ZstdCompressor()
   350         cctx = zstd.ZstdCompressor()
   286     with cctx.write_to(buffer) as compressor:
   362     with cctx.write_to(buffer) as compressor:
   287         compressor.write(data)
   363         compressor.write(data)
   288     return buffer.getvalue()
   364     return buffer.getvalue()
   289 
   365 
   290 
   366 
       
   367 @make_cffi
   291 class TestCompressor_write_to(unittest.TestCase):
   368 class TestCompressor_write_to(unittest.TestCase):
   292     def test_empty(self):
   369     def test_empty(self):
   293         self.assertEqual(compress(b'', 1),
   370         result = compress(b'', 1)
   294                          b'\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00')
   371         self.assertEqual(result, b'\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00')
       
   372 
       
   373         params = zstd.get_frame_parameters(result)
       
   374         self.assertEqual(params.content_size, 0)
       
   375         self.assertEqual(params.window_size, 524288)
       
   376         self.assertEqual(params.dict_id, 0)
       
   377         self.assertFalse(params.has_checksum)
   295 
   378 
   296     def test_multiple_compress(self):
   379     def test_multiple_compress(self):
   297         buffer = io.BytesIO()
   380         buffer = io.BytesIO()
   298         cctx = zstd.ZstdCompressor(level=5)
   381         cctx = zstd.ZstdCompressor(level=5)
   299         with cctx.write_to(buffer) as compressor:
   382         with cctx.write_to(buffer) as compressor:
   300             compressor.write(b'foo')
   383             self.assertEqual(compressor.write(b'foo'), 0)
   301             compressor.write(b'bar')
   384             self.assertEqual(compressor.write(b'bar'), 0)
   302             compressor.write(b'x' * 8192)
   385             self.assertEqual(compressor.write(b'x' * 8192), 0)
   303 
   386 
   304         result = buffer.getvalue()
   387         result = buffer.getvalue()
   305         self.assertEqual(result,
   388         self.assertEqual(result,
   306                          b'\x28\xb5\x2f\xfd\x00\x50\x75\x00\x00\x38\x66\x6f'
   389                          b'\x28\xb5\x2f\xfd\x00\x50\x75\x00\x00\x38\x66\x6f'
   307                          b'\x6f\x62\x61\x72\x78\x01\x00\xfc\xdf\x03\x23')
   390                          b'\x6f\x62\x61\x72\x78\x01\x00\xfc\xdf\x03\x23')
   316         d = zstd.train_dictionary(8192, samples)
   399         d = zstd.train_dictionary(8192, samples)
   317 
   400 
   318         buffer = io.BytesIO()
   401         buffer = io.BytesIO()
   319         cctx = zstd.ZstdCompressor(level=9, dict_data=d)
   402         cctx = zstd.ZstdCompressor(level=9, dict_data=d)
   320         with cctx.write_to(buffer) as compressor:
   403         with cctx.write_to(buffer) as compressor:
   321             compressor.write(b'foo')
   404             self.assertEqual(compressor.write(b'foo'), 0)
   322             compressor.write(b'bar')
   405             self.assertEqual(compressor.write(b'bar'), 0)
   323             compressor.write(b'foo' * 16384)
   406             self.assertEqual(compressor.write(b'foo' * 16384), 634)
   324 
   407 
   325         compressed = buffer.getvalue()
   408         compressed = buffer.getvalue()
       
   409 
       
   410         params = zstd.get_frame_parameters(compressed)
       
   411         self.assertEqual(params.content_size, 0)
       
   412         self.assertEqual(params.window_size, 1024)
       
   413         self.assertEqual(params.dict_id, d.dict_id())
       
   414         self.assertFalse(params.has_checksum)
       
   415 
       
   416         self.assertEqual(compressed[0:32],
       
   417                          b'\x28\xb5\x2f\xfd\x03\x00\x55\x7b\x6b\x5e\x54\x00'
       
   418                          b'\x00\x00\x02\xfc\xf4\xa5\xba\x23\x3f\x85\xb3\x54'
       
   419                          b'\x00\x00\x18\x6f\x6f\x66\x01\x00')
       
   420 
   326         h = hashlib.sha1(compressed).hexdigest()
   421         h = hashlib.sha1(compressed).hexdigest()
   327         self.assertEqual(h, '1c5bcd25181bcd8c1a73ea8773323e0056129f92')
   422         self.assertEqual(h, '1c5bcd25181bcd8c1a73ea8773323e0056129f92')
   328 
   423 
   329     def test_compression_params(self):
   424     def test_compression_params(self):
   330         params = zstd.CompressionParameters(20, 6, 12, 5, 4, 10, zstd.STRATEGY_FAST)
   425         params = zstd.CompressionParameters(20, 6, 12, 5, 4, 10, zstd.STRATEGY_FAST)
   331 
   426 
   332         buffer = io.BytesIO()
   427         buffer = io.BytesIO()
   333         cctx = zstd.ZstdCompressor(compression_params=params)
   428         cctx = zstd.ZstdCompressor(compression_params=params)
   334         with cctx.write_to(buffer) as compressor:
   429         with cctx.write_to(buffer) as compressor:
   335             compressor.write(b'foo')
   430             self.assertEqual(compressor.write(b'foo'), 0)
   336             compressor.write(b'bar')
   431             self.assertEqual(compressor.write(b'bar'), 0)
   337             compressor.write(b'foobar' * 16384)
   432             self.assertEqual(compressor.write(b'foobar' * 16384), 0)
   338 
   433 
   339         compressed = buffer.getvalue()
   434         compressed = buffer.getvalue()
       
   435 
       
   436         params = zstd.get_frame_parameters(compressed)
       
   437         self.assertEqual(params.content_size, 0)
       
   438         self.assertEqual(params.window_size, 1048576)
       
   439         self.assertEqual(params.dict_id, 0)
       
   440         self.assertFalse(params.has_checksum)
       
   441 
   340         h = hashlib.sha1(compressed).hexdigest()
   442         h = hashlib.sha1(compressed).hexdigest()
   341         self.assertEqual(h, '1ae31f270ed7de14235221a604b31ecd517ebd99')
   443         self.assertEqual(h, '1ae31f270ed7de14235221a604b31ecd517ebd99')
   342 
   444 
   343     def test_write_checksum(self):
   445     def test_write_checksum(self):
   344         no_checksum = io.BytesIO()
   446         no_checksum = io.BytesIO()
   345         cctx = zstd.ZstdCompressor(level=1)
   447         cctx = zstd.ZstdCompressor(level=1)
   346         with cctx.write_to(no_checksum) as compressor:
   448         with cctx.write_to(no_checksum) as compressor:
   347             compressor.write(b'foobar')
   449             self.assertEqual(compressor.write(b'foobar'), 0)
   348 
   450 
   349         with_checksum = io.BytesIO()
   451         with_checksum = io.BytesIO()
   350         cctx = zstd.ZstdCompressor(level=1, write_checksum=True)
   452         cctx = zstd.ZstdCompressor(level=1, write_checksum=True)
   351         with cctx.write_to(with_checksum) as compressor:
   453         with cctx.write_to(with_checksum) as compressor:
   352             compressor.write(b'foobar')
   454             self.assertEqual(compressor.write(b'foobar'), 0)
       
   455 
       
   456         no_params = zstd.get_frame_parameters(no_checksum.getvalue())
       
   457         with_params = zstd.get_frame_parameters(with_checksum.getvalue())
       
   458         self.assertEqual(no_params.content_size, 0)
       
   459         self.assertEqual(with_params.content_size, 0)
       
   460         self.assertEqual(no_params.dict_id, 0)
       
   461         self.assertEqual(with_params.dict_id, 0)
       
   462         self.assertFalse(no_params.has_checksum)
       
   463         self.assertTrue(with_params.has_checksum)
   353 
   464 
   354         self.assertEqual(len(with_checksum.getvalue()),
   465         self.assertEqual(len(with_checksum.getvalue()),
   355                          len(no_checksum.getvalue()) + 4)
   466                          len(no_checksum.getvalue()) + 4)
   356 
   467 
   357     def test_write_content_size(self):
   468     def test_write_content_size(self):
   358         no_size = io.BytesIO()
   469         no_size = io.BytesIO()
   359         cctx = zstd.ZstdCompressor(level=1)
   470         cctx = zstd.ZstdCompressor(level=1)
   360         with cctx.write_to(no_size) as compressor:
   471         with cctx.write_to(no_size) as compressor:
   361             compressor.write(b'foobar' * 256)
   472             self.assertEqual(compressor.write(b'foobar' * 256), 0)
   362 
   473 
   363         with_size = io.BytesIO()
   474         with_size = io.BytesIO()
   364         cctx = zstd.ZstdCompressor(level=1, write_content_size=True)
   475         cctx = zstd.ZstdCompressor(level=1, write_content_size=True)
   365         with cctx.write_to(with_size) as compressor:
   476         with cctx.write_to(with_size) as compressor:
   366             compressor.write(b'foobar' * 256)
   477             self.assertEqual(compressor.write(b'foobar' * 256), 0)
   367 
   478 
   368         # Source size is not known in streaming mode, so header not
   479         # Source size is not known in streaming mode, so header not
   369         # written.
   480         # written.
   370         self.assertEqual(len(with_size.getvalue()),
   481         self.assertEqual(len(with_size.getvalue()),
   371                          len(no_size.getvalue()))
   482                          len(no_size.getvalue()))
   372 
   483 
   373         # Declaring size will write the header.
   484         # Declaring size will write the header.
   374         with_size = io.BytesIO()
   485         with_size = io.BytesIO()
   375         with cctx.write_to(with_size, size=len(b'foobar' * 256)) as compressor:
   486         with cctx.write_to(with_size, size=len(b'foobar' * 256)) as compressor:
   376             compressor.write(b'foobar' * 256)
   487             self.assertEqual(compressor.write(b'foobar' * 256), 0)
       
   488 
       
   489         no_params = zstd.get_frame_parameters(no_size.getvalue())
       
   490         with_params = zstd.get_frame_parameters(with_size.getvalue())
       
   491         self.assertEqual(no_params.content_size, 0)
       
   492         self.assertEqual(with_params.content_size, 1536)
       
   493         self.assertEqual(no_params.dict_id, 0)
       
   494         self.assertEqual(with_params.dict_id, 0)
       
   495         self.assertFalse(no_params.has_checksum)
       
   496         self.assertFalse(with_params.has_checksum)
   377 
   497 
   378         self.assertEqual(len(with_size.getvalue()),
   498         self.assertEqual(len(with_size.getvalue()),
   379                          len(no_size.getvalue()) + 1)
   499                          len(no_size.getvalue()) + 1)
   380 
   500 
   381     def test_no_dict_id(self):
   501     def test_no_dict_id(self):
   388         d = zstd.train_dictionary(1024, samples)
   508         d = zstd.train_dictionary(1024, samples)
   389 
   509 
   390         with_dict_id = io.BytesIO()
   510         with_dict_id = io.BytesIO()
   391         cctx = zstd.ZstdCompressor(level=1, dict_data=d)
   511         cctx = zstd.ZstdCompressor(level=1, dict_data=d)
   392         with cctx.write_to(with_dict_id) as compressor:
   512         with cctx.write_to(with_dict_id) as compressor:
   393             compressor.write(b'foobarfoobar')
   513             self.assertEqual(compressor.write(b'foobarfoobar'), 0)
   394 
   514 
   395         cctx = zstd.ZstdCompressor(level=1, dict_data=d, write_dict_id=False)
   515         cctx = zstd.ZstdCompressor(level=1, dict_data=d, write_dict_id=False)
   396         no_dict_id = io.BytesIO()
   516         no_dict_id = io.BytesIO()
   397         with cctx.write_to(no_dict_id) as compressor:
   517         with cctx.write_to(no_dict_id) as compressor:
   398             compressor.write(b'foobarfoobar')
   518             self.assertEqual(compressor.write(b'foobarfoobar'), 0)
       
   519 
       
   520         no_params = zstd.get_frame_parameters(no_dict_id.getvalue())
       
   521         with_params = zstd.get_frame_parameters(with_dict_id.getvalue())
       
   522         self.assertEqual(no_params.content_size, 0)
       
   523         self.assertEqual(with_params.content_size, 0)
       
   524         self.assertEqual(no_params.dict_id, 0)
       
   525         self.assertEqual(with_params.dict_id, d.dict_id())
       
   526         self.assertFalse(no_params.has_checksum)
       
   527         self.assertFalse(with_params.has_checksum)
   399 
   528 
   400         self.assertEqual(len(with_dict_id.getvalue()),
   529         self.assertEqual(len(with_dict_id.getvalue()),
   401                          len(no_dict_id.getvalue()) + 4)
   530                          len(no_dict_id.getvalue()) + 4)
   402 
   531 
   403     def test_memory_size(self):
   532     def test_memory_size(self):
   410 
   539 
   411     def test_write_size(self):
   540     def test_write_size(self):
   412         cctx = zstd.ZstdCompressor(level=3)
   541         cctx = zstd.ZstdCompressor(level=3)
   413         dest = OpCountingBytesIO()
   542         dest = OpCountingBytesIO()
   414         with cctx.write_to(dest, write_size=1) as compressor:
   543         with cctx.write_to(dest, write_size=1) as compressor:
   415             compressor.write(b'foo')
   544             self.assertEqual(compressor.write(b'foo'), 0)
   416             compressor.write(b'bar')
   545             self.assertEqual(compressor.write(b'bar'), 0)
   417             compressor.write(b'foobar')
   546             self.assertEqual(compressor.write(b'foobar'), 0)
   418 
   547 
   419         self.assertEqual(len(dest.getvalue()), dest._write_count)
   548         self.assertEqual(len(dest.getvalue()), dest._write_count)
   420 
   549 
   421     def test_flush_repeated(self):
   550     def test_flush_repeated(self):
   422         cctx = zstd.ZstdCompressor(level=3)
   551         cctx = zstd.ZstdCompressor(level=3)
   423         dest = OpCountingBytesIO()
   552         dest = OpCountingBytesIO()
   424         with cctx.write_to(dest) as compressor:
   553         with cctx.write_to(dest) as compressor:
   425             compressor.write(b'foo')
   554             self.assertEqual(compressor.write(b'foo'), 0)
   426             self.assertEqual(dest._write_count, 0)
   555             self.assertEqual(dest._write_count, 0)
   427             compressor.flush()
   556             self.assertEqual(compressor.flush(), 12)
   428             self.assertEqual(dest._write_count, 1)
   557             self.assertEqual(dest._write_count, 1)
   429             compressor.write(b'bar')
   558             self.assertEqual(compressor.write(b'bar'), 0)
   430             self.assertEqual(dest._write_count, 1)
   559             self.assertEqual(dest._write_count, 1)
   431             compressor.flush()
   560             self.assertEqual(compressor.flush(), 6)
   432             self.assertEqual(dest._write_count, 2)
   561             self.assertEqual(dest._write_count, 2)
   433             compressor.write(b'baz')
   562             self.assertEqual(compressor.write(b'baz'), 0)
   434 
   563 
   435         self.assertEqual(dest._write_count, 3)
   564         self.assertEqual(dest._write_count, 3)
   436 
   565 
   437     def test_flush_empty_block(self):
   566     def test_flush_empty_block(self):
   438         cctx = zstd.ZstdCompressor(level=3, write_checksum=True)
   567         cctx = zstd.ZstdCompressor(level=3, write_checksum=True)
   439         dest = OpCountingBytesIO()
   568         dest = OpCountingBytesIO()
   440         with cctx.write_to(dest) as compressor:
   569         with cctx.write_to(dest) as compressor:
   441             compressor.write(b'foobar' * 8192)
   570             self.assertEqual(compressor.write(b'foobar' * 8192), 0)
   442             count = dest._write_count
   571             count = dest._write_count
   443             offset = dest.tell()
   572             offset = dest.tell()
   444             compressor.flush()
   573             self.assertEqual(compressor.flush(), 23)
   445             self.assertGreater(dest._write_count, count)
   574             self.assertGreater(dest._write_count, count)
   446             self.assertGreater(dest.tell(), offset)
   575             self.assertGreater(dest.tell(), offset)
   447             offset = dest.tell()
   576             offset = dest.tell()
   448             # Ending the write here should cause an empty block to be written
   577             # Ending the write here should cause an empty block to be written
   449             # to denote end of frame.
   578             # to denote end of frame.
   454 
   583 
   455         header = trailing[0:3]
   584         header = trailing[0:3]
   456         self.assertEqual(header, b'\x01\x00\x00')
   585         self.assertEqual(header, b'\x01\x00\x00')
   457 
   586 
   458 
   587 
       
   588 @make_cffi
   459 class TestCompressor_read_from(unittest.TestCase):
   589 class TestCompressor_read_from(unittest.TestCase):
   460     def test_type_validation(self):
   590     def test_type_validation(self):
   461         cctx = zstd.ZstdCompressor()
   591         cctx = zstd.ZstdCompressor()
   462 
   592 
   463         # Object with read() works.
   593         # Object with read() works.
   464         cctx.read_from(io.BytesIO())
   594         for chunk in cctx.read_from(io.BytesIO()):
       
   595             pass
   465 
   596 
   466         # Buffer protocol works.
   597         # Buffer protocol works.
   467         cctx.read_from(b'foobar')
   598         for chunk in cctx.read_from(b'foobar'):
       
   599             pass
   468 
   600 
   469         with self.assertRaisesRegexp(ValueError, 'must pass an object with a read'):
   601         with self.assertRaisesRegexp(ValueError, 'must pass an object with a read'):
   470             cctx.read_from(True)
   602             for chunk in cctx.read_from(True):
       
   603                 pass
   471 
   604 
   472     def test_read_empty(self):
   605     def test_read_empty(self):
   473         cctx = zstd.ZstdCompressor(level=1)
   606         cctx = zstd.ZstdCompressor(level=1)
   474 
   607 
   475         source = io.BytesIO()
   608         source = io.BytesIO()
   519             next(it)
   652             next(it)
   520 
   653 
   521         # We should get the same output as the one-shot compression mechanism.
   654         # We should get the same output as the one-shot compression mechanism.
   522         self.assertEqual(b''.join(chunks), cctx.compress(source.getvalue()))
   655         self.assertEqual(b''.join(chunks), cctx.compress(source.getvalue()))
   523 
   656 
       
   657         params = zstd.get_frame_parameters(b''.join(chunks))
       
   658         self.assertEqual(params.content_size, 0)
       
   659         self.assertEqual(params.window_size, 262144)
       
   660         self.assertEqual(params.dict_id, 0)
       
   661         self.assertFalse(params.has_checksum)
       
   662 
   524         # Now check the buffer protocol.
   663         # Now check the buffer protocol.
   525         it = cctx.read_from(source.getvalue())
   664         it = cctx.read_from(source.getvalue())
   526         chunks = list(it)
   665         chunks = list(it)
   527         self.assertEqual(len(chunks), 2)
   666         self.assertEqual(len(chunks), 2)
   528         self.assertEqual(b''.join(chunks), cctx.compress(source.getvalue()))
   667         self.assertEqual(b''.join(chunks), cctx.compress(source.getvalue()))