|
1 import hashlib |
|
2 import io |
|
3 import struct |
|
4 import sys |
|
5 |
|
6 try: |
|
7 import unittest2 as unittest |
|
8 except ImportError: |
|
9 import unittest |
|
10 |
|
11 import zstd |
|
12 |
|
13 from .common import OpCountingBytesIO |
|
14 |
|
15 |
|
# Iterator-advance shim: Python 2 iterators expose ``.next()`` while Python 3
# iterators expose ``__next__()``.  The tests below call ``next(it)`` and get
# the right spelling for the running interpreter.
if sys.version_info[0] < 3:
    def next(it):
        """Advance *it* using the Python 2 ``.next()`` method."""
        return it.next()
else:
    def next(it):
        """Advance *it* using the Python 3 ``__next__()`` method."""
        return it.__next__()
|
20 |
|
21 |
|
class TestCompressor(unittest.TestCase):
    """Constructor-level validation of ``zstd.ZstdCompressor``."""

    def test_level_bounds(self):
        # Levels just outside the accepted range (0 below, 23 above) must
        # be rejected with ValueError.
        for bad_level in (0, 23):
            with self.assertRaises(ValueError):
                zstd.ZstdCompressor(level=bad_level)
|
29 |
|
30 |
|
class TestCompressor_compress(unittest.TestCase):
    """Tests for the one-shot ``ZstdCompressor.compress()`` method."""

    @staticmethod
    def _dictionary_samples():
        """Return the sample list used to train dictionaries in these tests.

        128 rounds of ``foo``/``bar``/``foobar`` samples (each repeated 64
        times), always in the same order.
        """
        samples = []
        for _ in range(128):
            samples.append(b'foo' * 64)
            samples.append(b'bar' * 64)
            samples.append(b'foobar' * 64)
        return samples

    def test_compress_empty_min_max_level(self):
        # Empty input must compress without error at level 1 and level 22.
        # This test used to share the name ``test_compress_empty`` with the
        # test below; the later definition replaced it in the class body, so
        # it was never run.
        for level in (1, 22):
            cctx = zstd.ZstdCompressor(level=level)
            cctx.compress(b'')

    def test_compress_empty(self):
        # Empty input at level 1 produces this exact 9-byte frame.
        cctx = zstd.ZstdCompressor(level=1)
        self.assertEqual(cctx.compress(b''),
                         b'\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00')

    def test_compress_large(self):
        # 255 runs of 16384 bytes, each run a single repeated byte value.
        pack_byte = struct.Struct('>B').pack
        chunks = [pack_byte(i) * 16384 for i in range(255)]

        cctx = zstd.ZstdCompressor(level=3)
        result = cctx.compress(b''.join(chunks))
        self.assertEqual(len(result), 999)
        # Output begins with the zstd frame magic number.
        self.assertEqual(result[0:4], b'\x28\xb5\x2f\xfd')

    def test_write_checksum(self):
        cctx = zstd.ZstdCompressor(level=1)
        no_checksum = cctx.compress(b'foobar')
        cctx = zstd.ZstdCompressor(level=1, write_checksum=True)
        with_checksum = cctx.compress(b'foobar')

        # Enabling the checksum makes the frame 4 bytes longer.
        self.assertEqual(len(with_checksum), len(no_checksum) + 4)

    def test_write_content_size(self):
        data = b'foobar' * 256
        cctx = zstd.ZstdCompressor(level=1)
        no_size = cctx.compress(data)
        cctx = zstd.ZstdCompressor(level=1, write_content_size=True)
        with_size = cctx.compress(data)

        # Recording the content size costs one extra byte for this input.
        self.assertEqual(len(with_size), len(no_size) + 1)

    def test_no_dict_id(self):
        d = zstd.train_dictionary(1024, self._dictionary_samples())

        cctx = zstd.ZstdCompressor(level=1, dict_data=d)
        with_dict_id = cctx.compress(b'foobarfoobar')

        cctx = zstd.ZstdCompressor(level=1, dict_data=d, write_dict_id=False)
        no_dict_id = cctx.compress(b'foobarfoobar')

        # Omitting the dictionary ID makes the frame 4 bytes shorter.
        self.assertEqual(len(with_dict_id), len(no_dict_id) + 4)

    def test_compress_dict_multiple(self):
        d = zstd.train_dictionary(8192, self._dictionary_samples())

        # Reusing one dictionary-backed compressor for many calls must not
        # raise.
        cctx = zstd.ZstdCompressor(level=1, dict_data=d)
        for _ in range(32):
            cctx.compress(b'foo bar foobar foo bar foobar')
|
100 |
|
101 |
|
class TestCompressor_compressobj(unittest.TestCase):
    """Tests for the incremental ``ZstdCompressor.compressobj()`` API."""

    def _assert_raises_regex(self, exc_type, pattern):
        """Return an assertRaisesRegex-style context manager on any Python.

        ``assertRaisesRegex`` is the Python 3 spelling.  Python 2.7's stdlib
        unittest only provides ``assertRaisesRegexp``, and that alias was
        removed in Python 3.12, so neither name works everywhere by itself.
        """
        method = getattr(self, 'assertRaisesRegex', None)
        if method is None:
            method = self.assertRaisesRegexp
        return method(exc_type, pattern)

    def test_compressobj_empty(self):
        cctx = zstd.ZstdCompressor(level=1)
        cobj = cctx.compressobj()
        # With no input, compress() emits nothing; flush() emits the whole
        # (9-byte) frame.
        self.assertEqual(cobj.compress(b''), b'')
        self.assertEqual(cobj.flush(),
                         b'\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00')

    def test_compressobj_large(self):
        # 255 runs of 16384 bytes, each run a single repeated byte value.
        pack_byte = struct.Struct('>B').pack
        chunks = [pack_byte(i) * 16384 for i in range(255)]

        cctx = zstd.ZstdCompressor(level=3)
        cobj = cctx.compressobj()

        result = cobj.compress(b''.join(chunks)) + cobj.flush()
        self.assertEqual(len(result), 999)
        self.assertEqual(result[0:4], b'\x28\xb5\x2f\xfd')

    def test_write_checksum(self):
        cctx = zstd.ZstdCompressor(level=1)
        cobj = cctx.compressobj()
        no_checksum = cobj.compress(b'foobar') + cobj.flush()
        cctx = zstd.ZstdCompressor(level=1, write_checksum=True)
        cobj = cctx.compressobj()
        with_checksum = cobj.compress(b'foobar') + cobj.flush()

        # Enabling the checksum makes the frame 4 bytes longer.
        self.assertEqual(len(with_checksum), len(no_checksum) + 4)

    def test_write_content_size(self):
        data = b'foobar' * 256
        cctx = zstd.ZstdCompressor(level=1)
        cobj = cctx.compressobj(size=len(data))
        no_size = cobj.compress(data) + cobj.flush()
        cctx = zstd.ZstdCompressor(level=1, write_content_size=True)
        cobj = cctx.compressobj(size=len(data))
        with_size = cobj.compress(data) + cobj.flush()

        # With a declared size, write_content_size=True adds one byte.
        self.assertEqual(len(with_size), len(no_size) + 1)

    def test_compress_after_flush(self):
        cctx = zstd.ZstdCompressor()
        cobj = cctx.compressobj()

        cobj.compress(b'foo')
        cobj.flush()

        # Raw strings: the regexes need literal backslashes before the
        # parentheses, and '\(' in a normal string is an invalid escape.
        with self._assert_raises_regex(zstd.ZstdError,
                                       r'cannot call compress\(\) after flush'):
            cobj.compress(b'foo')

        with self._assert_raises_regex(zstd.ZstdError,
                                       r'flush\(\) already called'):
            cobj.flush()
|
154 |
|
155 |
|
class TestCompressor_copy_stream(unittest.TestCase):
    """Tests for ``ZstdCompressor.copy_stream(source, dest)``."""

    def test_no_read(self):
        # A source object without read() is rejected with ValueError.
        cctx = zstd.ZstdCompressor()
        with self.assertRaises(ValueError):
            cctx.copy_stream(object(), io.BytesIO())

    def test_no_write(self):
        # A destination object without write() is rejected with ValueError.
        cctx = zstd.ZstdCompressor()
        with self.assertRaises(ValueError):
            cctx.copy_stream(io.BytesIO(), object())

    def test_empty(self):
        src = io.BytesIO()
        dst = io.BytesIO()

        cctx = zstd.ZstdCompressor(level=1)
        n_read, n_written = cctx.copy_stream(src, dst)
        self.assertEqual(int(n_read), 0)
        self.assertEqual(n_written, 9)

        # Even empty input yields a complete 9-byte frame.
        self.assertEqual(dst.getvalue(),
                         b'\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00')

    def test_large_data(self):
        # 255 runs of 16384 bytes, each run a single repeated byte value.
        pack_byte = struct.Struct('>B').pack
        src = io.BytesIO(b''.join(pack_byte(i) * 16384 for i in range(255)))

        dst = io.BytesIO()
        cctx = zstd.ZstdCompressor()
        n_read, n_written = cctx.copy_stream(src, dst)

        self.assertEqual(n_read, 255 * 16384)
        self.assertEqual(n_written, 999)

    def test_write_checksum(self):
        src = io.BytesIO(b'foobar')

        plain = io.BytesIO()
        plain_cctx = zstd.ZstdCompressor(level=1)
        plain_cctx.copy_stream(src, plain)

        # Rewind so the second compressor sees the same input.
        src.seek(0)
        checked = io.BytesIO()
        checked_cctx = zstd.ZstdCompressor(level=1, write_checksum=True)
        checked_cctx.copy_stream(src, checked)

        self.assertEqual(len(checked.getvalue()),
                         len(plain.getvalue()) + 4)

    def test_write_content_size(self):
        src = io.BytesIO(b'foobar' * 256)

        baseline = io.BytesIO()
        zstd.ZstdCompressor(level=1).copy_stream(src, baseline)
        baseline_len = len(baseline.getvalue())

        sized_cctx = zstd.ZstdCompressor(level=1, write_content_size=True)

        # The source size is not declared here, so no content size is
        # written even with write_content_size=True.
        src.seek(0)
        out = io.BytesIO()
        sized_cctx.copy_stream(src, out)
        self.assertEqual(len(out.getvalue()), baseline_len)

        # Declaring size= lets the content size header be written.
        src.seek(0)
        out = io.BytesIO()
        sized_cctx.copy_stream(src, out, size=len(src.getvalue()))
        self.assertEqual(len(out.getvalue()), baseline_len + 1)

    def test_read_write_size(self):
        src = OpCountingBytesIO(b'foobarfoobar')
        dst = OpCountingBytesIO()
        cctx = zstd.ZstdCompressor()
        n_read, n_written = cctx.copy_stream(src, dst,
                                             read_size=1, write_size=1)

        self.assertEqual(n_read, len(src.getvalue()))
        self.assertEqual(n_written, 21)
        # One read() per input byte plus one more (presumably the read that
        # returns EOF).
        self.assertEqual(src._read_count, len(src.getvalue()) + 1)
        # write_size=1: each output byte arrives in its own write() call.
        self.assertEqual(dst._write_count, len(dst.getvalue()))
|
247 |
|
248 |
|
def compress(data, level):
    """Compress *data* at *level* via the streaming ``write_to()`` API.

    The bytes are collected in an in-memory buffer, which is read only
    after the writer's context manager has exited.
    """
    sink = io.BytesIO()
    compressor_ctx = zstd.ZstdCompressor(level=level)
    with compressor_ctx.write_to(sink) as writer:
        writer.write(data)
    return sink.getvalue()
|
255 |
|
256 |
|
class TestCompressor_write_to(unittest.TestCase):
    """Tests for the streaming ``ZstdCompressor.write_to()`` writer."""

    @staticmethod
    def _dictionary_samples():
        """Build the fixed foo/bar/foobar sample list for dictionary training."""
        return [sample
                for _ in range(128)
                for sample in (b'foo' * 64, b'bar' * 64, b'foobar' * 64)]

    @staticmethod
    def _stream(cctx, pieces, **kwargs):
        """Write each item of *pieces* through ``cctx.write_to()``.

        Extra keyword arguments are forwarded to ``write_to()``.  Returns the
        bytes collected in an in-memory buffer after the writer has exited.
        """
        sink = io.BytesIO()
        with cctx.write_to(sink, **kwargs) as compressor:
            for piece in pieces:
                compressor.write(piece)
        return sink.getvalue()

    def test_empty(self):
        self.assertEqual(compress(b'', 1),
                         b'\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00')

    def test_multiple_compress(self):
        cctx = zstd.ZstdCompressor(level=5)
        result = self._stream(cctx, [b'foo', b'bar', b'x' * 8192])
        self.assertEqual(result,
                         b'\x28\xb5\x2f\xfd\x00\x50\x75\x00\x00\x38\x66\x6f'
                         b'\x6f\x62\x61\x72\x78\x01\x00\xfc\xdf\x03\x23')

    def test_dictionary(self):
        d = zstd.train_dictionary(8192, self._dictionary_samples())
        cctx = zstd.ZstdCompressor(level=9, dict_data=d)
        compressed = self._stream(cctx, [b'foo', b'bar', b'foo' * 16384])

        digest = hashlib.sha1(compressed).hexdigest()
        self.assertEqual(digest, '1c5bcd25181bcd8c1a73ea8773323e0056129f92')

    def test_compression_params(self):
        params = zstd.CompressionParameters(20, 6, 12, 5, 4, 10, zstd.STRATEGY_FAST)
        cctx = zstd.ZstdCompressor(compression_params=params)
        compressed = self._stream(cctx, [b'foo', b'bar', b'foobar' * 16384])

        digest = hashlib.sha1(compressed).hexdigest()
        self.assertEqual(digest, '1ae31f270ed7de14235221a604b31ecd517ebd99')

    def test_write_checksum(self):
        plain_cctx = zstd.ZstdCompressor(level=1)
        plain = self._stream(plain_cctx, [b'foobar'])

        checked_cctx = zstd.ZstdCompressor(level=1, write_checksum=True)
        checked = self._stream(checked_cctx, [b'foobar'])

        # Enabling the checksum makes the output 4 bytes longer.
        self.assertEqual(len(checked), len(plain) + 4)

    def test_write_content_size(self):
        data = b'foobar' * 256
        no_size = self._stream(zstd.ZstdCompressor(level=1), [data])

        cctx = zstd.ZstdCompressor(level=1, write_content_size=True)

        # The source size is not declared here, so no content size header is
        # written.
        with_size = self._stream(cctx, [data])
        self.assertEqual(len(with_size), len(no_size))

        # Declaring the size lets the header be written (+1 byte).
        with_size = self._stream(cctx, [data], size=len(data))
        self.assertEqual(len(with_size), len(no_size) + 1)

    def test_no_dict_id(self):
        d = zstd.train_dictionary(1024, self._dictionary_samples())

        with_id_cctx = zstd.ZstdCompressor(level=1, dict_data=d)
        with_dict_id = self._stream(with_id_cctx, [b'foobarfoobar'])

        no_id_cctx = zstd.ZstdCompressor(level=1, dict_data=d, write_dict_id=False)
        no_dict_id = self._stream(no_id_cctx, [b'foobarfoobar'])

        # Omitting the dictionary ID makes the output 4 bytes shorter.
        self.assertEqual(len(with_dict_id), len(no_dict_id) + 4)

    def test_memory_size(self):
        cctx = zstd.ZstdCompressor(level=3)
        sink = io.BytesIO()
        with cctx.write_to(sink) as compressor:
            size = compressor.memory_size()

        self.assertGreater(size, 100000)

    def test_write_size(self):
        cctx = zstd.ZstdCompressor(level=3)
        dest = OpCountingBytesIO()
        with cctx.write_to(dest, write_size=1) as compressor:
            for piece in (b'foo', b'bar', b'foobar'):
                compressor.write(piece)

        # write_size=1: each output byte arrives in its own write() call.
        self.assertEqual(len(dest.getvalue()), dest._write_count)
|
386 |
|
387 |
|
class TestCompressor_read_from(unittest.TestCase):
    """Tests for the iterator-returning ``ZstdCompressor.read_from()`` API."""

    def _assert_raises_regex(self, exc_type, pattern):
        """Return an assertRaisesRegex-style context manager on any Python.

        ``assertRaisesRegex`` is the Python 3 spelling.  Python 2.7's stdlib
        unittest only provides ``assertRaisesRegexp``, and that alias was
        removed in Python 3.12, so neither name works everywhere by itself.
        """
        method = getattr(self, 'assertRaisesRegex', None)
        if method is None:
            method = self.assertRaisesRegexp
        return method(exc_type, pattern)

    def test_type_validation(self):
        cctx = zstd.ZstdCompressor()

        # Object with read() works.
        cctx.read_from(io.BytesIO())

        # Buffer protocol works.
        cctx.read_from(b'foobar')

        # Anything else is rejected.
        with self._assert_raises_regex(ValueError,
                                       'must pass an object with a read'):
            cctx.read_from(True)

    def test_read_empty(self):
        cctx = zstd.ZstdCompressor(level=1)

        # Empty input yields exactly one chunk: the 9-byte empty frame.
        source = io.BytesIO()
        it = cctx.read_from(source)
        chunks = list(it)
        self.assertEqual(len(chunks), 1)
        compressed = b''.join(chunks)
        self.assertEqual(compressed, b'\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00')

        # And again with the buffer protocol.
        it = cctx.read_from(b'')
        chunks = list(it)
        self.assertEqual(len(chunks), 1)
        compressed2 = b''.join(chunks)
        self.assertEqual(compressed2, compressed)

    def test_read_large(self):
        cctx = zstd.ZstdCompressor(level=1)

        # One byte more than the recommended input size, so the input spans
        # two reads.
        source = io.BytesIO()
        source.write(b'f' * zstd.COMPRESSION_RECOMMENDED_INPUT_SIZE)
        source.write(b'o')
        source.seek(0)

        # Creating an iterator should not perform any compression until
        # first read.
        it = cctx.read_from(source, size=len(source.getvalue()))
        self.assertEqual(source.tell(), 0)

        # We should have exactly 2 output chunks.
        chunks = []
        chunk = next(it)
        self.assertIsNotNone(chunk)
        self.assertEqual(source.tell(), zstd.COMPRESSION_RECOMMENDED_INPUT_SIZE)
        chunks.append(chunk)
        chunk = next(it)
        self.assertIsNotNone(chunk)
        chunks.append(chunk)

        self.assertEqual(source.tell(), len(source.getvalue()))

        with self.assertRaises(StopIteration):
            next(it)

        # And again for good measure.
        with self.assertRaises(StopIteration):
            next(it)

        # We should get the same output as the one-shot compression mechanism.
        self.assertEqual(b''.join(chunks), cctx.compress(source.getvalue()))

        # Now check the buffer protocol.
        it = cctx.read_from(source.getvalue())
        chunks = list(it)
        self.assertEqual(len(chunks), 2)
        self.assertEqual(b''.join(chunks), cctx.compress(source.getvalue()))

    def test_read_write_size(self):
        source = OpCountingBytesIO(b'foobarfoobar')
        cctx = zstd.ZstdCompressor(level=3)
        # write_size=1: every yielded chunk is a single byte.
        for chunk in cctx.read_from(source, read_size=1, write_size=1):
            self.assertEqual(len(chunk), 1)

        # One read() per input byte plus one more (presumably the read that
        # returns EOF).
        self.assertEqual(source._read_count, len(source.getvalue()) + 1)