contrib/python-zstandard/tests/test_train_dictionary.py
changeset 37495 b1fb341d8a61
parent 31796 e0dc40530c5a
child 40121 73fef626dae3
--- a/contrib/python-zstandard/tests/test_train_dictionary.py	Sun Apr 08 01:08:43 2018 +0200
+++ b/contrib/python-zstandard/tests/test_train_dictionary.py	Mon Apr 09 10:13:29 2018 -0700
@@ -1,13 +1,11 @@
+import struct
 import sys
+import unittest
 
-try:
-    import unittest2 as unittest
-except ImportError:
-    import unittest
-
-import zstd
+import zstandard as zstd
 
 from . common import (
+    generate_samples,
     make_cffi,
 )
 
@@ -30,55 +28,18 @@
         with self.assertRaises(ValueError):
             zstd.train_dictionary(8192, [u'foo'])
 
-    def test_basic(self):
-        samples = []
-        for i in range(128):
-            samples.append(b'foo' * 64)
-            samples.append(b'bar' * 64)
-            samples.append(b'foobar' * 64)
-            samples.append(b'baz' * 64)
-            samples.append(b'foobaz' * 64)
-            samples.append(b'bazfoo' * 64)
+    def test_no_params(self):
+        d = zstd.train_dictionary(8192, generate_samples())
+        self.assertIsInstance(d.dict_id(), int_type)
 
-        d = zstd.train_dictionary(8192, samples)
-        self.assertLessEqual(len(d), 8192)
-
-        dict_id = d.dict_id()
-        self.assertIsInstance(dict_id, int_type)
+        # The dictionary ID may be different across platforms.
+        expected = b'\x37\xa4\x30\xec' + struct.pack('<I', d.dict_id())
 
         data = d.as_bytes()
-        self.assertEqual(data[0:4], b'\x37\xa4\x30\xec')
-
-    def test_set_dict_id(self):
-        samples = []
-        for i in range(128):
-            samples.append(b'foo' * 64)
-            samples.append(b'foobar' * 64)
-
-        d = zstd.train_dictionary(8192, samples, dict_id=42)
-        self.assertEqual(d.dict_id(), 42)
-
-
-@make_cffi
-class TestTrainCoverDictionary(unittest.TestCase):
-    def test_no_args(self):
-        with self.assertRaises(TypeError):
-            zstd.train_cover_dictionary()
-
-    def test_bad_args(self):
-        with self.assertRaises(TypeError):
-            zstd.train_cover_dictionary(8192, u'foo')
-
-        with self.assertRaises(ValueError):
-            zstd.train_cover_dictionary(8192, [u'foo'])
+        self.assertEqual(data[0:8], expected)
 
     def test_basic(self):
-        samples = []
-        for i in range(128):
-            samples.append(b'foo' * 64)
-            samples.append(b'foobar' * 64)
-
-        d = zstd.train_cover_dictionary(8192, samples, k=64, d=16)
+        d = zstd.train_dictionary(8192, generate_samples(), k=64, d=16)
         self.assertIsInstance(d.dict_id(), int_type)
 
         data = d.as_bytes()
@@ -88,23 +49,39 @@
         self.assertEqual(d.d, 16)
 
     def test_set_dict_id(self):
-        samples = []
-        for i in range(128):
-            samples.append(b'foo' * 64)
-            samples.append(b'foobar' * 64)
-
-        d = zstd.train_cover_dictionary(8192, samples, k=64, d=16,
-                                        dict_id=42)
+        d = zstd.train_dictionary(8192, generate_samples(), k=64, d=16,
+                                  dict_id=42)
         self.assertEqual(d.dict_id(), 42)
 
     def test_optimize(self):
-        samples = []
-        for i in range(128):
-            samples.append(b'foo' * 64)
-            samples.append(b'foobar' * 64)
+        d = zstd.train_dictionary(8192, generate_samples(), threads=-1, steps=1,
+                                  d=16)
+
+        self.assertEqual(d.k, 50)
+        self.assertEqual(d.d, 16)
+
+@make_cffi
+class TestCompressionDict(unittest.TestCase):
+    def test_bad_mode(self):
+        with self.assertRaisesRegexp(ValueError, 'invalid dictionary load mode'):
+            zstd.ZstdCompressionDict(b'foo', dict_type=42)
+
+    def test_bad_precompute_compress(self):
+        d = zstd.train_dictionary(8192, generate_samples(), k=64, d=16)
 
-        d = zstd.train_cover_dictionary(8192, samples, optimize=True,
-                                        threads=-1, steps=1, d=16)
+        with self.assertRaisesRegexp(ValueError, 'must specify one of level or '):
+            d.precompute_compress()
+
+        with self.assertRaisesRegexp(ValueError, 'must only specify one of level or '):
+            d.precompute_compress(level=3,
+                                  compression_params=zstd.CompressionParameters())
 
-        self.assertEqual(d.k, 16)
-        self.assertEqual(d.d, 16)
+    def test_precompute_compress_rawcontent(self):
+        d = zstd.ZstdCompressionDict(b'dictcontent' * 64,
+                                     dict_type=zstd.DICT_TYPE_RAWCONTENT)
+        d.precompute_compress(level=1)
+
+        d = zstd.ZstdCompressionDict(b'dictcontent' * 64,
+                                     dict_type=zstd.DICT_TYPE_FULLDICT)
+        with self.assertRaisesRegexp(zstd.ZstdError, 'unable to precompute dictionary'):
+            d.precompute_compress(level=1)