--- a/contrib/python-zstandard/tests/test_train_dictionary.py Sun Apr 08 01:08:43 2018 +0200
+++ b/contrib/python-zstandard/tests/test_train_dictionary.py Mon Apr 09 10:13:29 2018 -0700
@@ -1,13 +1,11 @@
+import struct
import sys
+import unittest
-try:
- import unittest2 as unittest
-except ImportError:
- import unittest
-
-import zstd
+import zstandard as zstd
from . common import (
+ generate_samples,
make_cffi,
)
@@ -30,55 +28,18 @@
with self.assertRaises(ValueError):
zstd.train_dictionary(8192, [u'foo'])
- def test_basic(self):
- samples = []
- for i in range(128):
- samples.append(b'foo' * 64)
- samples.append(b'bar' * 64)
- samples.append(b'foobar' * 64)
- samples.append(b'baz' * 64)
- samples.append(b'foobaz' * 64)
- samples.append(b'bazfoo' * 64)
+ def test_no_params(self):
+ d = zstd.train_dictionary(8192, generate_samples())
+ self.assertIsInstance(d.dict_id(), int_type)
- d = zstd.train_dictionary(8192, samples)
- self.assertLessEqual(len(d), 8192)
-
- dict_id = d.dict_id()
- self.assertIsInstance(dict_id, int_type)
+ # The dictionary ID may differ across platforms, so build the expected header from the ID actually returned.
+ expected = b'\x37\xa4\x30\xec' + struct.pack('<I', d.dict_id())
data = d.as_bytes()
- self.assertEqual(data[0:4], b'\x37\xa4\x30\xec')
-
- def test_set_dict_id(self):
- samples = []
- for i in range(128):
- samples.append(b'foo' * 64)
- samples.append(b'foobar' * 64)
-
- d = zstd.train_dictionary(8192, samples, dict_id=42)
- self.assertEqual(d.dict_id(), 42)
-
-
-@make_cffi
-class TestTrainCoverDictionary(unittest.TestCase):
- def test_no_args(self):
- with self.assertRaises(TypeError):
- zstd.train_cover_dictionary()
-
- def test_bad_args(self):
- with self.assertRaises(TypeError):
- zstd.train_cover_dictionary(8192, u'foo')
-
- with self.assertRaises(ValueError):
- zstd.train_cover_dictionary(8192, [u'foo'])
+ self.assertEqual(data[0:8], expected)
def test_basic(self):
- samples = []
- for i in range(128):
- samples.append(b'foo' * 64)
- samples.append(b'foobar' * 64)
-
- d = zstd.train_cover_dictionary(8192, samples, k=64, d=16)
+ d = zstd.train_dictionary(8192, generate_samples(), k=64, d=16)
self.assertIsInstance(d.dict_id(), int_type)
data = d.as_bytes()
@@ -88,23 +49,39 @@
self.assertEqual(d.d, 16)
def test_set_dict_id(self):
- samples = []
- for i in range(128):
- samples.append(b'foo' * 64)
- samples.append(b'foobar' * 64)
-
- d = zstd.train_cover_dictionary(8192, samples, k=64, d=16,
- dict_id=42)
+ d = zstd.train_dictionary(8192, generate_samples(), k=64, d=16,
+ dict_id=42)
self.assertEqual(d.dict_id(), 42)
def test_optimize(self):
- samples = []
- for i in range(128):
- samples.append(b'foo' * 64)
- samples.append(b'foobar' * 64)
+ d = zstd.train_dictionary(8192, generate_samples(), threads=-1, steps=1,
+ d=16)
+
+ self.assertEqual(d.k, 50)
+ self.assertEqual(d.d, 16)
+
+@make_cffi
+class TestCompressionDict(unittest.TestCase):
+ def test_bad_mode(self):
+ with self.assertRaisesRegexp(ValueError, 'invalid dictionary load mode'):
+ zstd.ZstdCompressionDict(b'foo', dict_type=42)
+
+ def test_bad_precompute_compress(self):
+ d = zstd.train_dictionary(8192, generate_samples(), k=64, d=16)
- d = zstd.train_cover_dictionary(8192, samples, optimize=True,
- threads=-1, steps=1, d=16)
+ with self.assertRaisesRegexp(ValueError, 'must specify one of level or '):
+ d.precompute_compress()
+
+ with self.assertRaisesRegexp(ValueError, 'must only specify one of level or '):
+ d.precompute_compress(level=3,
+ compression_params=zstd.CompressionParameters())
- self.assertEqual(d.k, 16)
- self.assertEqual(d.d, 16)
+ def test_precompute_compress_rawcontent(self):
+ d = zstd.ZstdCompressionDict(b'dictcontent' * 64,
+ dict_type=zstd.DICT_TYPE_RAWCONTENT)
+ d.precompute_compress(level=1)
+
+ d = zstd.ZstdCompressionDict(b'dictcontent' * 64,
+ dict_type=zstd.DICT_TYPE_FULLDICT)
+ with self.assertRaisesRegexp(zstd.ZstdError, 'unable to precompute dictionary'):
+ d.precompute_compress(level=1)