contrib/python-zstandard/zstd_cffi.py
changeset 31796 e0dc40530c5a
parent 30895 c32454d69b85
child 37495 b1fb341d8a61
--- a/contrib/python-zstandard/zstd_cffi.py	Sat Apr 01 13:43:52 2017 -0700
+++ b/contrib/python-zstandard/zstd_cffi.py	Sat Apr 01 15:24:03 2017 -0700
@@ -8,6 +8,7 @@
 
 from __future__ import absolute_import, unicode_literals
 
+import os
 import sys
 
 from _zstd_cffi import (
@@ -62,6 +63,26 @@
 COMPRESSOBJ_FLUSH_BLOCK = 1
 
 
+def _cpu_count():
+    # os.cpu_count() was introduced in Python 3.4.
+    try:
+        return os.cpu_count() or 0
+    except AttributeError:
+        pass
+
+    # Linux.
+    try:
+        if sys.version_info[0] == 2:
+            return os.sysconf(b'SC_NPROCESSORS_ONLN')
+        else:
+            return os.sysconf(u'SC_NPROCESSORS_ONLN')
+    except (AttributeError, ValueError):
+        pass
+
+    # TODO implement on other platforms.
+    return 0
+
+
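A minimal sketch of how the helper above is consumed (illustrative only; the threads=-1 convention and the treatment of 0 as "detection failed, stay single-threaded" follow the rest of this patch; _resolve_threads is a hypothetical name):

def _resolve_threads(threads):
    # Hypothetical wrapper, not part of the patch: negative means
    # "use every detected core"; 0 keeps compression single-threaded.
    if threads < 0:
        threads = _cpu_count()
    return threads
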
 class ZstdError(Exception):
     pass
 
@@ -98,6 +119,14 @@
         self.target_length = target_length
         self.strategy = strategy
 
+        zresult = lib.ZSTD_checkCParams(self.as_compression_parameters())
+        if lib.ZSTD_isError(zresult):
+            raise ValueError('invalid compression parameters: %s' %
+                             ffi.string(lib.ZSTD_getErrorName(zresult)))
+
+    def estimated_compression_context_size(self):
+        return lib.ZSTD_estimateCCtxSize(self.as_compression_parameters())
+
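A usage sketch for the validation and sizing hooks added above (illustrative; keyword names are assumed to match the attributes assigned in __init__, and the values are arbitrary but legal):

params = CompressionParameters(window_log=20, chain_log=17, hash_log=17,
                               search_log=1, search_length=5,
                               target_length=16, strategy=STRATEGY_DFAST)
# ZSTD_checkCParams() now rejects bad combinations in the constructor,
# so reaching this point means the parameters are usable.
print(params.estimated_compression_context_size())
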
     def as_compression_parameters(self):
         p = ffi.new('ZSTD_compressionParameters *')[0]
         p.windowLog = self.window_log
@@ -140,12 +169,16 @@
         self._source_size = source_size
         self._write_size = write_size
         self._entered = False
+        self._mtcctx = compressor._cctx if compressor._multithreaded else None
 
     def __enter__(self):
         if self._entered:
             raise ZstdError('cannot __enter__ multiple times')
 
-        self._cstream = self._compressor._get_cstream(self._source_size)
+        if self._mtcctx:
+            self._compressor._init_mtcstream(self._source_size)
+        else:
+            self._compressor._ensure_cstream(self._source_size)
         self._entered = True
         return self
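A short sketch of how this writer is typically driven (illustrative; it assumes the compressor exposes write_to() returning a ZstdCompressionWriter, as in the rest of this module, and that the file name is a placeholder):

cctx = ZstdCompressor(threads=2)
with open('output.zst', 'wb') as fh:
    with cctx.write_to(fh) as compressor:
        compressor.write(b'first chunk')
        compressor.write(b'second chunk')
    # __exit__ ends the (possibly multi-threaded) stream and drains it to fh.
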
 
@@ -160,7 +193,10 @@
             out_buffer.pos = 0
 
             while True:
-                zresult = lib.ZSTD_endStream(self._cstream, out_buffer)
+                if self._mtcctx:
+                    zresult = lib.ZSTDMT_endStream(self._mtcctx, out_buffer)
+                else:
+                    zresult = lib.ZSTD_endStream(self._compressor._cstream, out_buffer)
                 if lib.ZSTD_isError(zresult):
                     raise ZstdError('error ending compression stream: %s' %
                                     ffi.string(lib.ZSTD_getErrorName(zresult)))
@@ -172,7 +208,6 @@
                 if zresult == 0:
                     break
 
-        self._cstream = None
         self._compressor = None
 
         return False
@@ -182,7 +217,7 @@
             raise ZstdError('cannot determine size of an inactive compressor; '
                             'call when a context manager is active')
 
-        return lib.ZSTD_sizeof_CStream(self._cstream)
+        return lib.ZSTD_sizeof_CStream(self._compressor._cstream)
 
     def write(self, data):
         if not self._entered:
@@ -205,7 +240,12 @@
         out_buffer.pos = 0
 
         while in_buffer.pos < in_buffer.size:
-            zresult = lib.ZSTD_compressStream(self._cstream, out_buffer, in_buffer)
+            if self._mtcctx:
+                zresult = lib.ZSTDMT_compressStream(self._mtcctx, out_buffer,
+                                                    in_buffer)
+            else:
+                zresult = lib.ZSTD_compressStream(self._compressor._cstream, out_buffer,
+                                                  in_buffer)
             if lib.ZSTD_isError(zresult):
                 raise ZstdError('zstd compress error: %s' %
                                 ffi.string(lib.ZSTD_getErrorName(zresult)))
@@ -230,7 +270,10 @@
         out_buffer.pos = 0
 
         while True:
-            zresult = lib.ZSTD_flushStream(self._cstream, out_buffer)
+            if self._mtcctx:
+                zresult = lib.ZSTDMT_flushStream(self._mtcctx, out_buffer)
+            else:
+                zresult = lib.ZSTD_flushStream(self._compressor._cstream, out_buffer)
             if lib.ZSTD_isError(zresult):
                 raise ZstdError('zstd compress error: %s' %
                                 ffi.string(lib.ZSTD_getErrorName(zresult)))
@@ -259,7 +302,12 @@
         chunks = []
 
         while source.pos < len(data):
-            zresult = lib.ZSTD_compressStream(self._cstream, self._out, source)
+            if self._mtcctx:
+                zresult = lib.ZSTDMT_compressStream(self._mtcctx,
+                                                    self._out, source)
+            else:
+                zresult = lib.ZSTD_compressStream(self._compressor._cstream, self._out,
+                                                  source)
             if lib.ZSTD_isError(zresult):
                 raise ZstdError('zstd compress error: %s' %
                                 ffi.string(lib.ZSTD_getErrorName(zresult)))
@@ -280,7 +328,10 @@
         assert self._out.pos == 0
 
         if flush_mode == COMPRESSOBJ_FLUSH_BLOCK:
-            zresult = lib.ZSTD_flushStream(self._cstream, self._out)
+            if self._mtcctx:
+                zresult = lib.ZSTDMT_flushStream(self._mtcctx, self._out)
+            else:
+                zresult = lib.ZSTD_flushStream(self._compressor._cstream, self._out)
             if lib.ZSTD_isError(zresult):
                 raise ZstdError('zstd compress error: %s' %
                                 ffi.string(lib.ZSTD_getErrorName(zresult)))
@@ -301,7 +352,10 @@
         chunks = []
 
         while True:
-            zresult = lib.ZSTD_endStream(self._cstream, self._out)
+            if self._mtcctx:
+                zresult = lib.ZSTDMT_endStream(self._mtcctx, self._out)
+            else:
+                zresult = lib.ZSTD_endStream(self._compressor._cstream, self._out)
             if lib.ZSTD_isError(zresult):
                 raise ZstdError('error ending compression stream: %s' %
                                ffi.string(lib.ZSTD_getErrorName(zresult)))
@@ -313,21 +367,21 @@
             if not zresult:
                 break
 
-        # GC compression stream immediately.
-        self._cstream = None
-
         return b''.join(chunks)
 
 
 class ZstdCompressor(object):
     def __init__(self, level=3, dict_data=None, compression_params=None,
                  write_checksum=False, write_content_size=False,
-                 write_dict_id=True):
+                 write_dict_id=True, threads=0):
         if level < 1:
             raise ValueError('level must be greater than 0')
         elif level > lib.ZSTD_maxCLevel():
             raise ValueError('level must be less than %d' % lib.ZSTD_maxCLevel())
 
+        if threads < 0:
+            threads = _cpu_count()
+
         self._compression_level = level
         self._dict_data = dict_data
         self._cparams = compression_params
@@ -336,16 +390,33 @@
         self._fparams.contentSizeFlag = write_content_size
         self._fparams.noDictIDFlag = not write_dict_id
 
-        cctx = lib.ZSTD_createCCtx()
-        if cctx == ffi.NULL:
-            raise MemoryError()
+        if threads:
+            cctx = lib.ZSTDMT_createCCtx(threads)
+            if cctx == ffi.NULL:
+                raise MemoryError()
 
-        self._cctx = ffi.gc(cctx, lib.ZSTD_freeCCtx)
+            self._cctx = ffi.gc(cctx, lib.ZSTDMT_freeCCtx)
+            self._multithreaded = True
+        else:
+            cctx = lib.ZSTD_createCCtx()
+            if cctx == ffi.NULL:
+                raise MemoryError()
+
+            self._cctx = ffi.gc(cctx, lib.ZSTD_freeCCtx)
+            self._multithreaded = False
+
+        self._cstream = None
 
     def compress(self, data, allow_empty=False):
         if len(data) == 0 and self._fparams.contentSizeFlag and not allow_empty:
             raise ValueError('cannot write empty inputs when writing content sizes')
 
+        if self._multithreaded and self._dict_data:
+            raise ZstdError('compress() cannot be used with both dictionaries and multi-threaded compression')
+
+        if self._multithreaded and self._cparams:
+            raise ZstdError('compress() cannot be used with both compression parameters and multi-threaded compression')
+
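A minimal usage sketch of the new threads argument (illustrative; sizes are arbitrary). A negative value resolves to _cpu_count(), and the multi-threaded one-shot path goes through ZSTDMT_compressCCtx(), which is why dictionaries and explicit compression parameters are rejected by the guards above:

cctx = ZstdCompressor(level=5, threads=-1)   # -1: use all detected cores
frame = cctx.compress(b'data to compress' * 4096)
# Combining threads with dict_data or compression_params makes
# compress() raise ZstdError, per the guards above.
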
         # TODO use a CDict for performance.
         dict_data = ffi.NULL
         dict_size = 0
@@ -365,11 +436,17 @@
         dest_size = lib.ZSTD_compressBound(len(data))
         out = new_nonzero('char[]', dest_size)
 
-        zresult = lib.ZSTD_compress_advanced(self._cctx,
-                                             ffi.addressof(out), dest_size,
-                                             data, len(data),
-                                             dict_data, dict_size,
-                                             params)
+        if self._multithreaded:
+            zresult = lib.ZSTDMT_compressCCtx(self._cctx,
+                                              ffi.addressof(out), dest_size,
+                                              data, len(data),
+                                              self._compression_level)
+        else:
+            zresult = lib.ZSTD_compress_advanced(self._cctx,
+                                                 ffi.addressof(out), dest_size,
+                                                 data, len(data),
+                                                 dict_data, dict_size,
+                                                 params)
 
         if lib.ZSTD_isError(zresult):
             raise ZstdError('cannot compress: %s' %
@@ -378,9 +455,12 @@
         return ffi.buffer(out, zresult)[:]
 
     def compressobj(self, size=0):
-        cstream = self._get_cstream(size)
+        if self._multithreaded:
+            self._init_mtcstream(size)
+        else:
+            self._ensure_cstream(size)
+
         cobj = ZstdCompressionObj()
-        cobj._cstream = cstream
         cobj._out = ffi.new('ZSTD_outBuffer *')
         cobj._dst_buffer = ffi.new('char[]', COMPRESSION_RECOMMENDED_OUTPUT_SIZE)
         cobj._out.dst = cobj._dst_buffer
@@ -389,6 +469,11 @@
         cobj._compressor = self
         cobj._finished = False
 
+        if self._multithreaded:
+            cobj._mtcctx = self._cctx
+        else:
+            cobj._mtcctx = None
+
         return cobj
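A brief sketch of the compressobj() flow with the multi-threaded context wired in above (illustrative; flush() with no argument is assumed to finish the frame, i.e. COMPRESSOBJ_FLUSH_FINISH):

cctx = ZstdCompressor(threads=2)
cobj = cctx.compressobj()
compressed = cobj.compress(b'chunk one')
compressed += cobj.compress(b'chunk two')
compressed += cobj.flush()   # finishes the zstd frame
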
 
     def copy_stream(self, ifh, ofh, size=0,
@@ -400,7 +485,11 @@
         if not hasattr(ofh, 'write'):
             raise ValueError('second argument must have a write() method')
 
-        cstream = self._get_cstream(size)
+        mt = self._multithreaded
+        if mt:
+            self._init_mtcstream(size)
+        else:
+            self._ensure_cstream(size)
 
         in_buffer = ffi.new('ZSTD_inBuffer *')
         out_buffer = ffi.new('ZSTD_outBuffer *')
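A usage sketch for copy_stream() with the multi-threaded initialisation above (illustrative; file names are placeholders, and the (read, written) byte-count return value is assumed to match the existing single-threaded behaviour):

cctx = ZstdCompressor(threads=4)
with open('input.bin', 'rb') as ifh, open('output.zst', 'wb') as ofh:
    read_count, written_count = cctx.copy_stream(ifh, ofh)
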
@@ -424,7 +513,11 @@
             in_buffer.pos = 0
 
             while in_buffer.pos < in_buffer.size:
-                zresult = lib.ZSTD_compressStream(cstream, out_buffer, in_buffer)
+                if mt:
+                    zresult = lib.ZSTDMT_compressStream(self._cctx, out_buffer, in_buffer)
+                else:
+                    zresult = lib.ZSTD_compressStream(self._cstream,
+                                                      out_buffer, in_buffer)
                 if lib.ZSTD_isError(zresult):
                     raise ZstdError('zstd compress error: %s' %
                                     ffi.string(lib.ZSTD_getErrorName(zresult)))
@@ -436,7 +529,10 @@
 
         # We've finished reading. Flush the compressor.
         while True:
-            zresult = lib.ZSTD_endStream(cstream, out_buffer)
+            if mt:
+                zresult = lib.ZSTDMT_endStream(self._cctx, out_buffer)
+            else:
+                zresult = lib.ZSTD_endStream(self._cstream, out_buffer)
             if lib.ZSTD_isError(zresult):
                 raise ZstdError('error ending compression stream: %s' %
                                 ffi.string(lib.ZSTD_getErrorName(zresult)))
@@ -472,7 +568,10 @@
             raise ValueError('must pass an object with a read() method or '
                              'conforms to buffer protocol')
 
-        cstream = self._get_cstream(size)
+        if self._multithreaded:
+            self._init_mtcstream(size)
+        else:
+            self._ensure_cstream(size)
 
         in_buffer = ffi.new('ZSTD_inBuffer *')
         out_buffer = ffi.new('ZSTD_outBuffer *')
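A sketch of how this streaming reader is consumed (illustrative; it assumes the enclosing method is the compressor's read_from() generator, which yields compressed chunks, and that the file name is a placeholder):

cctx = ZstdCompressor(threads=2)
with open('input.bin', 'rb') as fh:
    for chunk in cctx.read_from(fh):
        pass   # each chunk is a bytes object of compressed data
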
@@ -512,7 +611,10 @@
             in_buffer.pos = 0
 
             while in_buffer.pos < in_buffer.size:
-                zresult = lib.ZSTD_compressStream(cstream, out_buffer, in_buffer)
+                if self._multithreaded:
+                    zresult = lib.ZSTDMT_compressStream(self._cctx, out_buffer, in_buffer)
+                else:
+                    zresult = lib.ZSTD_compressStream(self._cstream, out_buffer, in_buffer)
                 if lib.ZSTD_isError(zresult):
                     raise ZstdError('zstd compress error: %s' %
                                     ffi.string(lib.ZSTD_getErrorName(zresult)))
@@ -531,7 +633,10 @@
         # remains.
         while True:
             assert out_buffer.pos == 0
-            zresult = lib.ZSTD_endStream(cstream, out_buffer)
+            if self._multithreaded:
+                zresult = lib.ZSTDMT_endStream(self._cctx, out_buffer)
+            else:
+                zresult = lib.ZSTD_endStream(self._cstream, out_buffer)
             if lib.ZSTD_isError(zresult):
                 raise ZstdError('error ending compression stream: %s' %
                                 ffi.string(lib.ZSTD_getErrorName(zresult)))
@@ -544,7 +649,15 @@
             if zresult == 0:
                 break
 
-    def _get_cstream(self, size):
+    def _ensure_cstream(self, size):
+        if self._cstream:
+            zresult = lib.ZSTD_resetCStream(self._cstream, size)
+            if lib.ZSTD_isError(zresult):
+                raise ZstdError('could not reset CStream: %s' %
+                                ffi.string(lib.ZSTD_getErrorName(zresult)))
+
+            return
+
         cstream = lib.ZSTD_createCStream()
         if cstream == ffi.NULL:
             raise MemoryError()
@@ -571,7 +684,32 @@
             raise Exception('cannot init CStream: %s' %
                             ffi.string(lib.ZSTD_getErrorName(zresult)))
 
-        return cstream
+        self._cstream = cstream
+
+    def _init_mtcstream(self, size):
+        assert self._multithreaded
+
+        dict_data = ffi.NULL
+        dict_size = 0
+        if self._dict_data:
+            dict_data = self._dict_data.as_bytes()
+            dict_size = len(self._dict_data)
+
+        zparams = ffi.new('ZSTD_parameters *')[0]
+        if self._cparams:
+            zparams.cParams = self._cparams.as_compression_parameters()
+        else:
+            zparams.cParams = lib.ZSTD_getCParams(self._compression_level,
+                                                  size, dict_size)
+
+        zparams.fParams = self._fparams
+
+        zresult = lib.ZSTDMT_initCStream_advanced(self._cctx, dict_data, dict_size,
+                                                  zparams, size)
+
+        if lib.ZSTD_isError(zresult):
+            raise ZstdError('cannot init CStream: %s' %
+                            ffi.string(lib.ZSTD_getErrorName(zresult)))
 
 
 class FrameParameters(object):
@@ -601,9 +739,11 @@
 
 
 class ZstdCompressionDict(object):
-    def __init__(self, data):
+    def __init__(self, data, k=0, d=0):
         assert isinstance(data, bytes_type)
         self._data = data
+        self.k = k
+        self.d = d
 
     def __len__(self):
         return len(self._data)
@@ -615,7 +755,8 @@
         return self._data
 
 
-def train_dictionary(dict_size, samples, parameters=None):
+def train_dictionary(dict_size, samples, selectivity=0, level=0,
+                     notifications=0, dict_id=0):
     if not isinstance(samples, list):
         raise TypeError('samples must be a list')
 
@@ -636,10 +777,18 @@
 
     dict_data = new_nonzero('char[]', dict_size)
 
-    zresult = lib.ZDICT_trainFromBuffer(ffi.addressof(dict_data), dict_size,
-                                        ffi.addressof(samples_buffer),
-                                        ffi.addressof(sample_sizes, 0),
-                                        len(samples))
+    dparams = ffi.new('ZDICT_params_t *')[0]
+    dparams.selectivityLevel = selectivity
+    dparams.compressionLevel = level
+    dparams.notificationLevel = notifications
+    dparams.dictID = dict_id
+
+    zresult = lib.ZDICT_trainFromBuffer_advanced(
+        ffi.addressof(dict_data), dict_size,
+        ffi.addressof(samples_buffer),
+        ffi.addressof(sample_sizes, 0), len(samples),
+        dparams)
+
     if lib.ZDICT_isError(zresult):
         raise ZstdError('Cannot train dict: %s' %
                         ffi.string(lib.ZDICT_getErrorName(zresult)))
@@ -647,16 +796,73 @@
     return ZstdCompressionDict(ffi.buffer(dict_data, zresult)[:])
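A usage sketch for the expanded train_dictionary() signature above (illustrative; the samples are toy data in the style of this package's tests):

samples = []
for i in range(128):
    samples.append(b'foo' * 64)
    samples.append(b'bar' * 64)
    samples.append(b'foobar' * 64)

dict_data = train_dictionary(8192, samples, level=3, dict_id=42)
cctx = ZstdCompressor(dict_data=dict_data)
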
 
 
+def train_cover_dictionary(dict_size, samples, k=0, d=0,
+                           notifications=0, dict_id=0, level=0, optimize=False,
+                           steps=0, threads=0):
+    if not isinstance(samples, list):
+        raise TypeError('samples must be a list')
+
+    if threads < 0:
+        threads = _cpu_count()
+
+    total_size = sum(map(len, samples))
+
+    samples_buffer = new_nonzero('char[]', total_size)
+    sample_sizes = new_nonzero('size_t[]', len(samples))
+
+    offset = 0
+    for i, sample in enumerate(samples):
+        if not isinstance(sample, bytes_type):
+            raise ValueError('samples must be bytes')
+
+        l = len(sample)
+        ffi.memmove(samples_buffer + offset, sample, l)
+        offset += l
+        sample_sizes[i] = l
+
+    dict_data = new_nonzero('char[]', dict_size)
+
+    dparams = ffi.new('COVER_params_t *')[0]
+    dparams.k = k
+    dparams.d = d
+    dparams.steps = steps
+    dparams.nbThreads = threads
+    dparams.notificationLevel = notifications
+    dparams.dictID = dict_id
+    dparams.compressionLevel = level
+
+    if optimize:
+        zresult = lib.COVER_optimizeTrainFromBuffer(
+            ffi.addressof(dict_data), dict_size,
+            ffi.addressof(samples_buffer),
+            ffi.addressof(sample_sizes, 0), len(samples),
+            ffi.addressof(dparams))
+    else:
+        zresult = lib.COVER_trainFromBuffer(
+            ffi.addressof(dict_data), dict_size,
+            ffi.addressof(samples_buffer),
+            ffi.addressof(sample_sizes, 0), len(samples),
+            dparams)
+
+    if lib.ZDICT_isError(zresult):
+        raise ZstdError('cannot train dict: %s' %
+                        ffi.string(lib.ZDICT_getErrorName(zresult)))
+
+    return ZstdCompressionDict(ffi.buffer(dict_data, zresult)[:],
+                               k=dparams.k, d=dparams.d)
+
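A sketch of the new COVER trainer above (illustrative; with optimize=True the k/d search is delegated to COVER_optimizeTrainFromBuffer() and the winning values are recorded on the returned dictionary):

samples = []
for i in range(128):
    samples.append(b'foo' * 64)
    samples.append(b'bar' * 64)
    samples.append(b'foobar' * 64)

d = train_cover_dictionary(8192, samples, optimize=True, steps=4, threads=-1)
print(d.k, d.d)   # parameters chosen by the optimizing trainer
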
+
 class ZstdDecompressionObj(object):
     def __init__(self, decompressor):
         self._decompressor = decompressor
-        self._dstream = self._decompressor._get_dstream()
         self._finished = False
 
     def decompress(self, data):
         if self._finished:
             raise ZstdError('cannot use a decompressobj multiple times')
 
+        assert(self._decompressor._dstream)
+
         in_buffer = ffi.new('ZSTD_inBuffer *')
         out_buffer = ffi.new('ZSTD_outBuffer *')
 
@@ -673,14 +879,14 @@
         chunks = []
 
         while in_buffer.pos < in_buffer.size:
-            zresult = lib.ZSTD_decompressStream(self._dstream, out_buffer, in_buffer)
+            zresult = lib.ZSTD_decompressStream(self._decompressor._dstream,
+                                                out_buffer, in_buffer)
             if lib.ZSTD_isError(zresult):
                 raise ZstdError('zstd decompressor error: %s' %
                                 ffi.string(lib.ZSTD_getErrorName(zresult)))
 
             if zresult == 0:
                 self._finished = True
-                self._dstream = None
                 self._decompressor = None
 
             if out_buffer.pos:
@@ -695,28 +901,26 @@
         self._decompressor = decompressor
         self._writer = writer
         self._write_size = write_size
-        self._dstream = None
         self._entered = False
 
     def __enter__(self):
         if self._entered:
             raise ZstdError('cannot __enter__ multiple times')
 
-        self._dstream = self._decompressor._get_dstream()
+        self._decompressor._ensure_dstream()
         self._entered = True
 
         return self
 
     def __exit__(self, exc_type, exc_value, exc_tb):
         self._entered = False
-        self._dstream = None
 
     def memory_size(self):
-        if not self._dstream:
+        if not self._decompressor._dstream:
            raise ZstdError('cannot determine size of inactive decompressor; '
                             'call when context manager is active')
 
-        return lib.ZSTD_sizeof_DStream(self._dstream)
+        return lib.ZSTD_sizeof_DStream(self._decompressor._dstream)
 
     def write(self, data):
         if not self._entered:
@@ -737,8 +941,10 @@
         out_buffer.size = len(dst_buffer)
         out_buffer.pos = 0
 
+        dstream = self._decompressor._dstream
+
         while in_buffer.pos < in_buffer.size:
-            zresult = lib.ZSTD_decompressStream(self._dstream, out_buffer, in_buffer)
+            zresult = lib.ZSTD_decompressStream(dstream, out_buffer, in_buffer)
             if lib.ZSTD_isError(zresult):
                 raise ZstdError('zstd decompress error: %s' %
                                 ffi.string(lib.ZSTD_getErrorName(zresult)))
@@ -760,6 +966,7 @@
             raise MemoryError()
 
         self._refdctx = ffi.gc(dctx, lib.ZSTD_freeDCtx)
+        self._dstream = None
 
     @property
     def _ddict(self):
@@ -816,6 +1023,7 @@
         return ffi.buffer(result_buffer, zresult)[:]
 
     def decompressobj(self):
+        self._ensure_dstream()
         return ZstdDecompressionObj(self)
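A short sketch of the decompressobj() flow now that the DStream lives on the ZstdDecompressor (illustrative; frame stands for any complete zstd frame produced elsewhere):

dctx = ZstdDecompressor()
dobj = dctx.decompressobj()
data = dobj.decompress(frame)   # the frame may also be fed in smaller chunks
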
 
     def read_from(self, reader, read_size=DECOMPRESSION_RECOMMENDED_INPUT_SIZE,
@@ -843,7 +1051,7 @@
 
                 buffer_offset = skip_bytes
 
-        dstream = self._get_dstream()
+        self._ensure_dstream()
 
         in_buffer = ffi.new('ZSTD_inBuffer *')
         out_buffer = ffi.new('ZSTD_outBuffer *')
@@ -878,7 +1086,7 @@
             while in_buffer.pos < in_buffer.size:
                 assert out_buffer.pos == 0
 
-                zresult = lib.ZSTD_decompressStream(dstream, out_buffer, in_buffer)
+                zresult = lib.ZSTD_decompressStream(self._dstream, out_buffer, in_buffer)
                 if lib.ZSTD_isError(zresult):
                     raise ZstdError('zstd decompress error: %s' %
                                     ffi.string(lib.ZSTD_getErrorName(zresult)))
@@ -910,7 +1118,7 @@
         if not hasattr(ofh, 'write'):
             raise ValueError('second argument must have a write() method')
 
-        dstream = self._get_dstream()
+        self._ensure_dstream()
 
         in_buffer = ffi.new('ZSTD_inBuffer *')
         out_buffer = ffi.new('ZSTD_outBuffer *')
@@ -936,7 +1144,7 @@
 
             # Flush all read data to output.
             while in_buffer.pos < in_buffer.size:
-                zresult = lib.ZSTD_decompressStream(dstream, out_buffer, in_buffer)
+                zresult = lib.ZSTD_decompressStream(self._dstream, out_buffer, in_buffer)
                 if lib.ZSTD_isError(zresult):
                     raise ZstdError('zstd decompressor error: %s' %
                                     ffi.string(lib.ZSTD_getErrorName(zresult)))
@@ -1021,22 +1229,29 @@
 
         return ffi.buffer(last_buffer, len(last_buffer))[:]
 
-    def _get_dstream(self):
-        dstream = lib.ZSTD_createDStream()
-        if dstream == ffi.NULL:
+    def _ensure_dstream(self):
+        if self._dstream:
+            zresult = lib.ZSTD_resetDStream(self._dstream)
+            if lib.ZSTD_isError(zresult):
+                raise ZstdError('could not reset DStream: %s' %
+                                ffi.string(lib.ZSTD_getErrorName(zresult)))
+
+            return
+
+        self._dstream = lib.ZSTD_createDStream()
+        if self._dstream == ffi.NULL:
             raise MemoryError()
 
-        dstream = ffi.gc(dstream, lib.ZSTD_freeDStream)
+        self._dstream = ffi.gc(self._dstream, lib.ZSTD_freeDStream)
 
         if self._dict_data:
-            zresult = lib.ZSTD_initDStream_usingDict(dstream,
+            zresult = lib.ZSTD_initDStream_usingDict(self._dstream,
                                                      self._dict_data.as_bytes(),
                                                      len(self._dict_data))
         else:
-            zresult = lib.ZSTD_initDStream(dstream)
+            zresult = lib.ZSTD_initDStream(self._dstream)
 
         if lib.ZSTD_isError(zresult):
+            self._dstream = None
             raise ZstdError('could not initialize DStream: %s' %
                             ffi.string(lib.ZSTD_getErrorName(zresult)))
-
-        return dstream