|
1 # Copyright (c) 2016-present, Gregory Szorc |
|
2 # All rights reserved. |
|
3 # |
|
4 # This software may be modified and distributed under the terms |
|
5 # of the BSD license. See the LICENSE file for details. |
|
6 |
|
7 """Python interface to the Zstandard (zstd) compression library.""" |
|
8 |
|
9 from __future__ import absolute_import, unicode_literals |
|
10 |
|
11 import io |
|
12 |
|
13 from _zstd_cffi import ( |
|
14 ffi, |
|
15 lib, |
|
16 ) |
|
17 |
|
18 |
|
# Chunk size used when reading input to feed the streaming compressor
# (see ZstdCompressor.copy_stream); value reported by ZSTD_CStreamInSize().
_CSTREAM_IN_SIZE = lib.ZSTD_CStreamInSize()
# Size of every output buffer handed to the streaming compression calls in
# this module; value reported by ZSTD_CStreamOutSize().
_CSTREAM_OUT_SIZE = lib.ZSTD_CStreamOutSize()
|
21 |
|
22 |
|
class _ZstdCompressionWriter(object):
    """Context manager that compresses data and forwards it to ``writer``.

    Data passed to :meth:`write` is fed through the zstd compression stream
    and any compressed output produced is handed to ``writer.write()``.
    Leaving the ``with`` block without an exception ends the stream and
    writes the remaining output.
    """

    def __init__(self, cstream, writer):
        """Store the compression stream and the destination writer.

        ``cstream`` is an initialized ``ZSTD_CStream *`` and ``writer`` is
        any object exposing ``write(data)``.
        """
        self._cstream = cstream
        self._writer = writer

    @staticmethod
    def _error_name(res):
        """Return the zstd error name for result code ``res`` as text."""
        # ZSTD_getErrorName() yields a C ``char *``; ffi.string() copies it
        # into a Python byte string which is then decoded for formatting.
        return ffi.string(lib.ZSTD_getErrorName(res)).decode('utf-8')

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        """End the compression stream unless an exception is propagating.

        Always returns ``False`` so exceptions raised inside the ``with``
        block are never suppressed.

        Raises:
            Exception: if zstd reports an error while ending the stream.
        """
        if not exc_type and not exc_value and not exc_tb:
            # Keep the owning cdata object in a local: memory obtained from
            # ffi.new() is released as soon as its owner is collected, and
            # storing the pointer in a struct field does not keep it alive.
            dst = ffi.new('char[]', _CSTREAM_OUT_SIZE)
            out_buffer = ffi.new('ZSTD_outBuffer *')
            out_buffer.dst = dst
            out_buffer.size = _CSTREAM_OUT_SIZE
            out_buffer.pos = 0

            while True:
                res = lib.ZSTD_endStream(self._cstream, out_buffer)
                if lib.ZSTD_isError(res):
                    raise Exception('error ending compression stream: %s' %
                                    self._error_name(res))

                if out_buffer.pos:
                    self._writer.write(
                        ffi.buffer(out_buffer.dst, out_buffer.pos))
                    out_buffer.pos = 0

                # A zero result means nothing is left to flush.
                if res == 0:
                    break

        return False

    def write(self, data):
        """Compress ``data`` and write produced output to the writer.

        ``data`` must be acceptable as an initializer for a cffi ``char[]``
        and support ``len()`` (presumably a byte string).

        Raises:
            Exception: if zstd reports a compression error.
        """
        # Owning references are held in locals for the duration of the
        # call; see the note in __exit__ about ffi.new() ownership.
        dst = ffi.new('char[]', _CSTREAM_OUT_SIZE)
        out_buffer = ffi.new('ZSTD_outBuffer *')
        out_buffer.dst = dst
        out_buffer.size = _CSTREAM_OUT_SIZE
        out_buffer.pos = 0

        # TODO can we reuse existing memory?
        src = ffi.new('char[]', data)
        in_buffer = ffi.new('ZSTD_inBuffer *')
        in_buffer.src = src
        in_buffer.size = len(data)
        in_buffer.pos = 0

        # Loop until zstd has consumed the whole input buffer.
        while in_buffer.pos < in_buffer.size:
            res = lib.ZSTD_compressStream(self._cstream, out_buffer,
                                          in_buffer)
            if lib.ZSTD_isError(res):
                raise Exception('zstd compress error: %s' %
                                self._error_name(res))

            if out_buffer.pos:
                self._writer.write(
                    ffi.buffer(out_buffer.dst, out_buffer.pos))
                out_buffer.pos = 0
|
71 |
|
72 |
|
class ZstdCompressor(object):
    """Compress data with zstd at a fixed compression level."""

    def __init__(self, level=3, dict_data=None, compression_params=None):
        """Create a compressor.

        Args:
            level: compression level passed to ``ZSTD_initCStream``.
            dict_data: not yet supported; any truthy value raises.
            compression_params: not yet supported; any truthy value raises.

        Raises:
            Exception: if ``dict_data`` or ``compression_params`` is given.
        """
        if dict_data:
            raise Exception('dict_data not yet supported')
        if compression_params:
            raise Exception('compression_params not yet supported')

        self._compression_level = level

    @staticmethod
    def _error_name(res):
        """Return the zstd error name for result code ``res`` as text."""
        # ZSTD_getErrorName() yields a C ``char *``; ffi.string() copies it
        # into a Python byte string which is then decoded for formatting.
        return ffi.string(lib.ZSTD_getErrorName(res)).decode('utf-8')

    def compress(self, data):
        """Compress ``data`` in one shot and return the compressed bytes."""
        # Just use the stream API for now.
        output = io.BytesIO()
        with self.write_to(output) as compressor:
            compressor.write(data)
        return output.getvalue()

    def copy_stream(self, ifh, ofh):
        """Compress everything readable from ``ifh`` into ``ofh``.

        ``ifh`` must provide ``read(size)`` and ``ofh`` must provide
        ``write(data)``.

        Returns:
            A ``(total_read, total_write)`` tuple: the number of bytes read
            from ``ifh`` and the number of compressed bytes written to
            ``ofh``.

        Raises:
            Exception: if zstd reports an error while compressing or while
                ending the stream.
        """
        cstream = self._get_cstream()

        in_buffer = ffi.new('ZSTD_inBuffer *')
        out_buffer = ffi.new('ZSTD_outBuffer *')

        # Keep the owning cdata object in a local: memory obtained from
        # ffi.new() is released as soon as its owner is collected, and
        # storing the pointer in a struct field does not keep it alive.
        dst = ffi.new('char[]', _CSTREAM_OUT_SIZE)
        out_buffer.dst = dst
        out_buffer.size = _CSTREAM_OUT_SIZE
        out_buffer.pos = 0

        total_read, total_write = 0, 0

        while True:
            data = ifh.read(_CSTREAM_IN_SIZE)
            if not data:
                break

            total_read += len(data)

            # ``src`` stays referenced until the chunk is fully consumed.
            src = ffi.new('char[]', data)
            in_buffer.src = src
            in_buffer.size = len(data)
            in_buffer.pos = 0

            while in_buffer.pos < in_buffer.size:
                res = lib.ZSTD_compressStream(cstream, out_buffer, in_buffer)
                if lib.ZSTD_isError(res):
                    raise Exception('zstd compress error: %s' %
                                    self._error_name(res))

                if out_buffer.pos:
                    ofh.write(ffi.buffer(out_buffer.dst, out_buffer.pos))
                    # Accumulate (not overwrite) the written-byte count.
                    total_write += out_buffer.pos
                    out_buffer.pos = 0

        # We've finished reading. Flush the compressor.
        while True:
            res = lib.ZSTD_endStream(cstream, out_buffer)
            if lib.ZSTD_isError(res):
                raise Exception('error ending compression stream: %s' %
                                self._error_name(res))

            if out_buffer.pos:
                ofh.write(ffi.buffer(out_buffer.dst, out_buffer.pos))
                total_write += out_buffer.pos
                out_buffer.pos = 0

            # A zero result means nothing is left to flush.
            if res == 0:
                break

        return total_read, total_write

    def write_to(self, writer):
        """Return a context manager that compresses writes into ``writer``."""
        return _ZstdCompressionWriter(self._get_cstream(), writer)

    def _get_cstream(self):
        """Create and initialize a new zstd compression stream.

        The stream is wrapped with ``ffi.gc`` so ``ZSTD_freeCStream`` runs
        when the returned object is garbage collected.

        Raises:
            MemoryError: if zstd could not allocate the stream.
            Exception: if zstd fails to initialize the stream.
        """
        cstream = lib.ZSTD_createCStream()
        if cstream == ffi.NULL:
            raise MemoryError()
        cstream = ffi.gc(cstream, lib.ZSTD_freeCStream)

        res = lib.ZSTD_initCStream(cstream, self._compression_level)
        if lib.ZSTD_isError(res):
            raise Exception('cannot init CStream: %s' %
                            self._error_name(res))

        return cstream