contrib/python-zstandard/zstd_cffi.py
branchstable
changeset 42146 4a8d9ed86475
parent 41984 d1c33b2442a7
parent 42143 29569f2db929
child 42147 807a6ca6d096
equal deleted inserted replaced
41984:d1c33b2442a7 42146:4a8d9ed86475
     1 # Copyright (c) 2016-present, Gregory Szorc
       
     2 # All rights reserved.
       
     3 #
       
     4 # This software may be modified and distributed under the terms
       
     5 # of the BSD license. See the LICENSE file for details.
       
     6 
       
     7 """Python interface to the Zstandard (zstd) compression library."""
       
     8 
       
     9 from __future__ import absolute_import, unicode_literals
       
    10 
       
    11 # This should match what the C extension exports.
       
    12 __all__ = [
       
    13     #'BufferSegment',
       
    14     #'BufferSegments',
       
    15     #'BufferWithSegments',
       
    16     #'BufferWithSegmentsCollection',
       
    17     'CompressionParameters',
       
    18     'ZstdCompressionDict',
       
    19     'ZstdCompressionParameters',
       
    20     'ZstdCompressor',
       
    21     'ZstdError',
       
    22     'ZstdDecompressor',
       
    23     'FrameParameters',
       
    24     'estimate_decompression_context_size',
       
    25     'frame_content_size',
       
    26     'frame_header_size',
       
    27     'get_frame_parameters',
       
    28     'train_dictionary',
       
    29 
       
    30     # Constants.
       
    31     'COMPRESSOBJ_FLUSH_FINISH',
       
    32     'COMPRESSOBJ_FLUSH_BLOCK',
       
    33     'ZSTD_VERSION',
       
    34     'FRAME_HEADER',
       
    35     'CONTENTSIZE_UNKNOWN',
       
    36     'CONTENTSIZE_ERROR',
       
    37     'MAX_COMPRESSION_LEVEL',
       
    38     'COMPRESSION_RECOMMENDED_INPUT_SIZE',
       
    39     'COMPRESSION_RECOMMENDED_OUTPUT_SIZE',
       
    40     'DECOMPRESSION_RECOMMENDED_INPUT_SIZE',
       
    41     'DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE',
       
    42     'MAGIC_NUMBER',
       
    43     'BLOCKSIZELOG_MAX',
       
    44     'BLOCKSIZE_MAX',
       
    45     'WINDOWLOG_MIN',
       
    46     'WINDOWLOG_MAX',
       
    47     'CHAINLOG_MIN',
       
    48     'CHAINLOG_MAX',
       
    49     'HASHLOG_MIN',
       
    50     'HASHLOG_MAX',
       
    51     'HASHLOG3_MAX',
       
    52     'SEARCHLOG_MIN',
       
    53     'SEARCHLOG_MAX',
       
    54     'SEARCHLENGTH_MIN',
       
    55     'SEARCHLENGTH_MAX',
       
    56     'TARGETLENGTH_MIN',
       
    57     'TARGETLENGTH_MAX',
       
    58     'LDM_MINMATCH_MIN',
       
    59     'LDM_MINMATCH_MAX',
       
    60     'LDM_BUCKETSIZELOG_MAX',
       
    61     'STRATEGY_FAST',
       
    62     'STRATEGY_DFAST',
       
    63     'STRATEGY_GREEDY',
       
    64     'STRATEGY_LAZY',
       
    65     'STRATEGY_LAZY2',
       
    66     'STRATEGY_BTLAZY2',
       
    67     'STRATEGY_BTOPT',
       
    68     'STRATEGY_BTULTRA',
       
    69     'DICT_TYPE_AUTO',
       
    70     'DICT_TYPE_RAWCONTENT',
       
    71     'DICT_TYPE_FULLDICT',
       
    72     'FORMAT_ZSTD1',
       
    73     'FORMAT_ZSTD1_MAGICLESS',
       
    74 ]
       
    75 
       
    76 import io
       
    77 import os
       
    78 import sys
       
    79 
       
    80 from _zstd_cffi import (
       
    81     ffi,
       
    82     lib,
       
    83 )
       
    84 
       
    85 if sys.version_info[0] == 2:
       
    86     bytes_type = str
       
    87     int_type = long
       
    88 else:
       
    89     bytes_type = bytes
       
    90     int_type = int
       
    91 
       
    92 
       
# Streaming buffer sizes recommended by the linked zstd library, queried via
# ZSTD_CStreamInSize/OutSize and ZSTD_DStreamInSize/OutSize.
COMPRESSION_RECOMMENDED_INPUT_SIZE = lib.ZSTD_CStreamInSize()
COMPRESSION_RECOMMENDED_OUTPUT_SIZE = lib.ZSTD_CStreamOutSize()
DECOMPRESSION_RECOMMENDED_INPUT_SIZE = lib.ZSTD_DStreamInSize()
DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE = lib.ZSTD_DStreamOutSize()

# cffi allocator that does not zero-fill memory after allocation
# (should_clear_after_alloc=False); only appropriate for buffers that are
# written before they are read. Not referenced in the part of the file shown
# here — presumably used further down.
new_nonzero = ffi.new_allocator(should_clear_after_alloc=False)


# Limits and identifiers taken from the zstd library.
MAX_COMPRESSION_LEVEL = lib.ZSTD_maxCLevel()
MAGIC_NUMBER = lib.ZSTD_MAGICNUMBER
# Hard-coded frame magic bytes. NOTE(review): these should be the
# little-endian encoding of MAGIC_NUMBER (0xFD2FB528 in the zstd format
# specification) — keep the two in sync.
FRAME_HEADER = b'\x28\xb5\x2f\xfd'
CONTENTSIZE_UNKNOWN = lib.ZSTD_CONTENTSIZE_UNKNOWN
CONTENTSIZE_ERROR = lib.ZSTD_CONTENTSIZE_ERROR
# (major, minor, release) from the ZSTD_VERSION_* macros of the zstd header.
ZSTD_VERSION = (lib.ZSTD_VERSION_MAJOR, lib.ZSTD_VERSION_MINOR, lib.ZSTD_VERSION_RELEASE)

# Bounds for the individual compression parameters.
BLOCKSIZELOG_MAX = lib.ZSTD_BLOCKSIZELOG_MAX
BLOCKSIZE_MAX = lib.ZSTD_BLOCKSIZE_MAX
WINDOWLOG_MIN = lib.ZSTD_WINDOWLOG_MIN
WINDOWLOG_MAX = lib.ZSTD_WINDOWLOG_MAX
CHAINLOG_MIN = lib.ZSTD_CHAINLOG_MIN
CHAINLOG_MAX = lib.ZSTD_CHAINLOG_MAX
HASHLOG_MIN = lib.ZSTD_HASHLOG_MIN
HASHLOG_MAX = lib.ZSTD_HASHLOG_MAX
HASHLOG3_MAX = lib.ZSTD_HASHLOG3_MAX
SEARCHLOG_MIN = lib.ZSTD_SEARCHLOG_MIN
SEARCHLOG_MAX = lib.ZSTD_SEARCHLOG_MAX
SEARCHLENGTH_MIN = lib.ZSTD_SEARCHLENGTH_MIN
SEARCHLENGTH_MAX = lib.ZSTD_SEARCHLENGTH_MAX
TARGETLENGTH_MIN = lib.ZSTD_TARGETLENGTH_MIN
TARGETLENGTH_MAX = lib.ZSTD_TARGETLENGTH_MAX
LDM_MINMATCH_MIN = lib.ZSTD_LDM_MINMATCH_MIN
LDM_MINMATCH_MAX = lib.ZSTD_LDM_MINMATCH_MAX
LDM_BUCKETSIZELOG_MAX = lib.ZSTD_LDM_BUCKETSIZELOG_MAX

# Compression strategies, for the compression_strategy parameter.
STRATEGY_FAST = lib.ZSTD_fast
STRATEGY_DFAST = lib.ZSTD_dfast
STRATEGY_GREEDY = lib.ZSTD_greedy
STRATEGY_LAZY = lib.ZSTD_lazy
STRATEGY_LAZY2 = lib.ZSTD_lazy2
STRATEGY_BTLAZY2 = lib.ZSTD_btlazy2
STRATEGY_BTOPT = lib.ZSTD_btopt
STRATEGY_BTULTRA = lib.ZSTD_btultra

# Dictionary content types (ZSTD_dct_* values).
DICT_TYPE_AUTO = lib.ZSTD_dct_auto
DICT_TYPE_RAWCONTENT = lib.ZSTD_dct_rawContent
DICT_TYPE_FULLDICT = lib.ZSTD_dct_fullDict

# Frame formats, for the format parameter: standard zstd frames, or the
# "magicless" variant (ZSTD_f_zstd1_magicless).
FORMAT_ZSTD1 = lib.ZSTD_f_zstd1
FORMAT_ZSTD1_MAGICLESS = lib.ZSTD_f_zstd1_magicless

# flush_mode values for ZstdCompressionObj.flush(): FINISH ends the frame
# (ZSTD_e_end) and BLOCK flushes buffered data (ZSTD_e_flush).
COMPRESSOBJ_FLUSH_FINISH = 0
COMPRESSOBJ_FLUSH_BLOCK = 1
       
   145 
       
   146 
       
   147 def _cpu_count():
       
   148     # os.cpu_count() was introducd in Python 3.4.
       
   149     try:
       
   150         return os.cpu_count() or 0
       
   151     except AttributeError:
       
   152         pass
       
   153 
       
   154     # Linux.
       
   155     try:
       
   156         if sys.version_info[0] == 2:
       
   157             return os.sysconf(b'SC_NPROCESSORS_ONLN')
       
   158         else:
       
   159             return os.sysconf(u'SC_NPROCESSORS_ONLN')
       
   160     except (AttributeError, ValueError):
       
   161         pass
       
   162 
       
   163     # TODO implement on other platforms.
       
   164     return 0
       
   165 
       
   166 
       
   167 class ZstdError(Exception):
       
   168     pass
       
   169 
       
   170 
       
   171 def _zstd_error(zresult):
       
   172     # Resolves to bytes on Python 2 and 3. We use the string for formatting
       
   173     # into error messages, which will be literal unicode. So convert it to
       
   174     # unicode.
       
   175     return ffi.string(lib.ZSTD_getErrorName(zresult)).decode('utf-8')
       
   176 
       
def _make_cctx_params(params):
    """Build a native ZSTD_CCtx_params object from ``params``.

    ``params`` must provide the attributes read below (as set by
    ZstdCompressionParameters.__init__). Each value is applied with
    _set_compression_parameter(), in list order.

    Returns a cffi pointer that is released with ZSTD_freeCCtxParams when it
    is garbage collected. Raises MemoryError if allocation fails and
    ZstdError if zstd rejects any value.
    """
    res = lib.ZSTD_createCCtxParams()
    if res == ffi.NULL:
        raise MemoryError()

    # Tie the native object's lifetime to the returned Python object.
    res = ffi.gc(res, lib.ZSTD_freeCCtxParams)

    # (zstd parameter id, value) pairs, applied in this exact order.
    # NOTE(review): threads (nbWorkers) is set before job_size and
    # overlap_size_log; zstd may depend on that ordering — confirm before
    # reordering.
    attrs = [
        (lib.ZSTD_p_format, params.format),
        (lib.ZSTD_p_compressionLevel, params.compression_level),
        (lib.ZSTD_p_windowLog, params.window_log),
        (lib.ZSTD_p_hashLog, params.hash_log),
        (lib.ZSTD_p_chainLog, params.chain_log),
        (lib.ZSTD_p_searchLog, params.search_log),
        (lib.ZSTD_p_minMatch, params.min_match),
        (lib.ZSTD_p_targetLength, params.target_length),
        (lib.ZSTD_p_compressionStrategy, params.compression_strategy),
        (lib.ZSTD_p_contentSizeFlag, params.write_content_size),
        (lib.ZSTD_p_checksumFlag, params.write_checksum),
        (lib.ZSTD_p_dictIDFlag, params.write_dict_id),
        (lib.ZSTD_p_nbWorkers, params.threads),
        (lib.ZSTD_p_jobSize, params.job_size),
        (lib.ZSTD_p_overlapSizeLog, params.overlap_size_log),
        (lib.ZSTD_p_forceMaxWindow, params.force_max_window),
        (lib.ZSTD_p_enableLongDistanceMatching, params.enable_ldm),
        (lib.ZSTD_p_ldmHashLog, params.ldm_hash_log),
        (lib.ZSTD_p_ldmMinMatch, params.ldm_min_match),
        (lib.ZSTD_p_ldmBucketSizeLog, params.ldm_bucket_size_log),
        (lib.ZSTD_p_ldmHashEveryLog, params.ldm_hash_every_log),
    ]

    for param, value in attrs:
        _set_compression_parameter(res, param, value)

    return res
       
   212 
       
   213 class ZstdCompressionParameters(object):
       
   214     @staticmethod
       
   215     def from_level(level, source_size=0, dict_size=0, **kwargs):
       
   216         params = lib.ZSTD_getCParams(level, source_size, dict_size)
       
   217 
       
   218         args = {
       
   219             'window_log': 'windowLog',
       
   220             'chain_log': 'chainLog',
       
   221             'hash_log': 'hashLog',
       
   222             'search_log': 'searchLog',
       
   223             'min_match': 'searchLength',
       
   224             'target_length': 'targetLength',
       
   225             'compression_strategy': 'strategy',
       
   226         }
       
   227 
       
   228         for arg, attr in args.items():
       
   229             if arg not in kwargs:
       
   230                 kwargs[arg] = getattr(params, attr)
       
   231 
       
   232         return ZstdCompressionParameters(**kwargs)
       
   233 
       
   234     def __init__(self, format=0, compression_level=0, window_log=0, hash_log=0,
       
   235                  chain_log=0, search_log=0, min_match=0, target_length=0,
       
   236                  compression_strategy=0, write_content_size=1, write_checksum=0,
       
   237                  write_dict_id=0, job_size=0, overlap_size_log=0,
       
   238                  force_max_window=0, enable_ldm=0, ldm_hash_log=0,
       
   239                  ldm_min_match=0, ldm_bucket_size_log=0, ldm_hash_every_log=0,
       
   240                  threads=0):
       
   241 
       
   242         if threads < 0:
       
   243             threads = _cpu_count()
       
   244 
       
   245         self.format = format
       
   246         self.compression_level = compression_level
       
   247         self.window_log = window_log
       
   248         self.hash_log = hash_log
       
   249         self.chain_log = chain_log
       
   250         self.search_log = search_log
       
   251         self.min_match = min_match
       
   252         self.target_length = target_length
       
   253         self.compression_strategy = compression_strategy
       
   254         self.write_content_size = write_content_size
       
   255         self.write_checksum = write_checksum
       
   256         self.write_dict_id = write_dict_id
       
   257         self.job_size = job_size
       
   258         self.overlap_size_log = overlap_size_log
       
   259         self.force_max_window = force_max_window
       
   260         self.enable_ldm = enable_ldm
       
   261         self.ldm_hash_log = ldm_hash_log
       
   262         self.ldm_min_match = ldm_min_match
       
   263         self.ldm_bucket_size_log = ldm_bucket_size_log
       
   264         self.ldm_hash_every_log = ldm_hash_every_log
       
   265         self.threads = threads
       
   266 
       
   267         self.params = _make_cctx_params(self)
       
   268 
       
   269     def estimated_compression_context_size(self):
       
   270         return lib.ZSTD_estimateCCtxSize_usingCCtxParams(self.params)
       
   271 
       
   272 CompressionParameters = ZstdCompressionParameters
       
   273 
       
   274 def estimate_decompression_context_size():
       
   275     return lib.ZSTD_estimateDCtxSize()
       
   276 
       
   277 
       
   278 def _set_compression_parameter(params, param, value):
       
   279     zresult = lib.ZSTD_CCtxParam_setParameter(params, param,
       
   280                                               ffi.cast('unsigned', value))
       
   281     if lib.ZSTD_isError(zresult):
       
   282         raise ZstdError('unable to set compression context parameter: %s' %
       
   283                         _zstd_error(zresult))
       
   284 
       
   285 class ZstdCompressionWriter(object):
       
   286     def __init__(self, compressor, writer, source_size, write_size):
       
   287         self._compressor = compressor
       
   288         self._writer = writer
       
   289         self._source_size = source_size
       
   290         self._write_size = write_size
       
   291         self._entered = False
       
   292         self._bytes_compressed = 0
       
   293 
       
   294     def __enter__(self):
       
   295         if self._entered:
       
   296             raise ZstdError('cannot __enter__ multiple times')
       
   297 
       
   298         zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._compressor._cctx,
       
   299                                                   self._source_size)
       
   300         if lib.ZSTD_isError(zresult):
       
   301             raise ZstdError('error setting source size: %s' %
       
   302                             _zstd_error(zresult))
       
   303 
       
   304         self._entered = True
       
   305         return self
       
   306 
       
   307     def __exit__(self, exc_type, exc_value, exc_tb):
       
   308         self._entered = False
       
   309 
       
   310         if not exc_type and not exc_value and not exc_tb:
       
   311             dst_buffer = ffi.new('char[]', self._write_size)
       
   312 
       
   313             out_buffer = ffi.new('ZSTD_outBuffer *')
       
   314             in_buffer = ffi.new('ZSTD_inBuffer *')
       
   315 
       
   316             out_buffer.dst = dst_buffer
       
   317             out_buffer.size = len(dst_buffer)
       
   318             out_buffer.pos = 0
       
   319 
       
   320             in_buffer.src = ffi.NULL
       
   321             in_buffer.size = 0
       
   322             in_buffer.pos = 0
       
   323 
       
   324             while True:
       
   325                 zresult = lib.ZSTD_compress_generic(self._compressor._cctx,
       
   326                                                     out_buffer, in_buffer,
       
   327                                                     lib.ZSTD_e_end)
       
   328 
       
   329                 if lib.ZSTD_isError(zresult):
       
   330                     raise ZstdError('error ending compression stream: %s' %
       
   331                                     _zstd_error(zresult))
       
   332 
       
   333                 if out_buffer.pos:
       
   334                     self._writer.write(ffi.buffer(out_buffer.dst, out_buffer.pos)[:])
       
   335                     out_buffer.pos = 0
       
   336 
       
   337                 if zresult == 0:
       
   338                     break
       
   339 
       
   340         self._compressor = None
       
   341 
       
   342         return False
       
   343 
       
   344     def memory_size(self):
       
   345         if not self._entered:
       
   346             raise ZstdError('cannot determine size of an inactive compressor; '
       
   347                             'call when a context manager is active')
       
   348 
       
   349         return lib.ZSTD_sizeof_CCtx(self._compressor._cctx)
       
   350 
       
   351     def write(self, data):
       
   352         if not self._entered:
       
   353             raise ZstdError('write() must be called from an active context '
       
   354                             'manager')
       
   355 
       
   356         total_write = 0
       
   357 
       
   358         data_buffer = ffi.from_buffer(data)
       
   359 
       
   360         in_buffer = ffi.new('ZSTD_inBuffer *')
       
   361         in_buffer.src = data_buffer
       
   362         in_buffer.size = len(data_buffer)
       
   363         in_buffer.pos = 0
       
   364 
       
   365         out_buffer = ffi.new('ZSTD_outBuffer *')
       
   366         dst_buffer = ffi.new('char[]', self._write_size)
       
   367         out_buffer.dst = dst_buffer
       
   368         out_buffer.size = self._write_size
       
   369         out_buffer.pos = 0
       
   370 
       
   371         while in_buffer.pos < in_buffer.size:
       
   372             zresult = lib.ZSTD_compress_generic(self._compressor._cctx,
       
   373                                                 out_buffer, in_buffer,
       
   374                                                 lib.ZSTD_e_continue)
       
   375             if lib.ZSTD_isError(zresult):
       
   376                 raise ZstdError('zstd compress error: %s' %
       
   377                                 _zstd_error(zresult))
       
   378 
       
   379             if out_buffer.pos:
       
   380                 self._writer.write(ffi.buffer(out_buffer.dst, out_buffer.pos)[:])
       
   381                 total_write += out_buffer.pos
       
   382                 self._bytes_compressed += out_buffer.pos
       
   383                 out_buffer.pos = 0
       
   384 
       
   385         return total_write
       
   386 
       
   387     def flush(self):
       
   388         if not self._entered:
       
   389             raise ZstdError('flush must be called from an active context manager')
       
   390 
       
   391         total_write = 0
       
   392 
       
   393         out_buffer = ffi.new('ZSTD_outBuffer *')
       
   394         dst_buffer = ffi.new('char[]', self._write_size)
       
   395         out_buffer.dst = dst_buffer
       
   396         out_buffer.size = self._write_size
       
   397         out_buffer.pos = 0
       
   398 
       
   399         in_buffer = ffi.new('ZSTD_inBuffer *')
       
   400         in_buffer.src = ffi.NULL
       
   401         in_buffer.size = 0
       
   402         in_buffer.pos = 0
       
   403 
       
   404         while True:
       
   405             zresult = lib.ZSTD_compress_generic(self._compressor._cctx,
       
   406                                                 out_buffer, in_buffer,
       
   407                                                 lib.ZSTD_e_flush)
       
   408             if lib.ZSTD_isError(zresult):
       
   409                 raise ZstdError('zstd compress error: %s' %
       
   410                                 _zstd_error(zresult))
       
   411 
       
   412             if out_buffer.pos:
       
   413                 self._writer.write(ffi.buffer(out_buffer.dst, out_buffer.pos)[:])
       
   414                 total_write += out_buffer.pos
       
   415                 self._bytes_compressed += out_buffer.pos
       
   416                 out_buffer.pos = 0
       
   417 
       
   418             if not zresult:
       
   419                 break
       
   420 
       
   421         return total_write
       
   422 
       
   423     def tell(self):
       
   424         return self._bytes_compressed
       
   425 
       
   426 
       
   427 class ZstdCompressionObj(object):
       
   428     def compress(self, data):
       
   429         if self._finished:
       
   430             raise ZstdError('cannot call compress() after compressor finished')
       
   431 
       
   432         data_buffer = ffi.from_buffer(data)
       
   433         source = ffi.new('ZSTD_inBuffer *')
       
   434         source.src = data_buffer
       
   435         source.size = len(data_buffer)
       
   436         source.pos = 0
       
   437 
       
   438         chunks = []
       
   439 
       
   440         while source.pos < len(data):
       
   441             zresult = lib.ZSTD_compress_generic(self._compressor._cctx,
       
   442                                                 self._out,
       
   443                                                 source,
       
   444                                                 lib.ZSTD_e_continue)
       
   445             if lib.ZSTD_isError(zresult):
       
   446                 raise ZstdError('zstd compress error: %s' %
       
   447                                 _zstd_error(zresult))
       
   448 
       
   449             if self._out.pos:
       
   450                 chunks.append(ffi.buffer(self._out.dst, self._out.pos)[:])
       
   451                 self._out.pos = 0
       
   452 
       
   453         return b''.join(chunks)
       
   454 
       
   455     def flush(self, flush_mode=COMPRESSOBJ_FLUSH_FINISH):
       
   456         if flush_mode not in (COMPRESSOBJ_FLUSH_FINISH, COMPRESSOBJ_FLUSH_BLOCK):
       
   457             raise ValueError('flush mode not recognized')
       
   458 
       
   459         if self._finished:
       
   460             raise ZstdError('compressor object already finished')
       
   461 
       
   462         if flush_mode == COMPRESSOBJ_FLUSH_BLOCK:
       
   463             z_flush_mode = lib.ZSTD_e_flush
       
   464         elif flush_mode == COMPRESSOBJ_FLUSH_FINISH:
       
   465             z_flush_mode = lib.ZSTD_e_end
       
   466             self._finished = True
       
   467         else:
       
   468             raise ZstdError('unhandled flush mode')
       
   469 
       
   470         assert self._out.pos == 0
       
   471 
       
   472         in_buffer = ffi.new('ZSTD_inBuffer *')
       
   473         in_buffer.src = ffi.NULL
       
   474         in_buffer.size = 0
       
   475         in_buffer.pos = 0
       
   476 
       
   477         chunks = []
       
   478 
       
   479         while True:
       
   480             zresult = lib.ZSTD_compress_generic(self._compressor._cctx,
       
   481                                                 self._out,
       
   482                                                 in_buffer,
       
   483                                                 z_flush_mode)
       
   484             if lib.ZSTD_isError(zresult):
       
   485                 raise ZstdError('error ending compression stream: %s' %
       
   486                                 _zstd_error(zresult))
       
   487 
       
   488             if self._out.pos:
       
   489                 chunks.append(ffi.buffer(self._out.dst, self._out.pos)[:])
       
   490                 self._out.pos = 0
       
   491 
       
   492             if not zresult:
       
   493                 break
       
   494 
       
   495         return b''.join(chunks)
       
   496 
       
   497 
       
   498 class ZstdCompressionChunker(object):
       
   499     def __init__(self, compressor, chunk_size):
       
   500         self._compressor = compressor
       
   501         self._out = ffi.new('ZSTD_outBuffer *')
       
   502         self._dst_buffer = ffi.new('char[]', chunk_size)
       
   503         self._out.dst = self._dst_buffer
       
   504         self._out.size = chunk_size
       
   505         self._out.pos = 0
       
   506 
       
   507         self._in = ffi.new('ZSTD_inBuffer *')
       
   508         self._in.src = ffi.NULL
       
   509         self._in.size = 0
       
   510         self._in.pos = 0
       
   511         self._finished = False
       
   512 
       
   513     def compress(self, data):
       
   514         if self._finished:
       
   515             raise ZstdError('cannot call compress() after compression finished')
       
   516 
       
   517         if self._in.src != ffi.NULL:
       
   518             raise ZstdError('cannot perform operation before consuming output '
       
   519                             'from previous operation')
       
   520 
       
   521         data_buffer = ffi.from_buffer(data)
       
   522 
       
   523         if not len(data_buffer):
       
   524             return
       
   525 
       
   526         self._in.src = data_buffer
       
   527         self._in.size = len(data_buffer)
       
   528         self._in.pos = 0
       
   529 
       
   530         while self._in.pos < self._in.size:
       
   531             zresult = lib.ZSTD_compress_generic(self._compressor._cctx,
       
   532                                                 self._out,
       
   533                                                 self._in,
       
   534                                                 lib.ZSTD_e_continue)
       
   535 
       
   536             if self._in.pos == self._in.size:
       
   537                 self._in.src = ffi.NULL
       
   538                 self._in.size = 0
       
   539                 self._in.pos = 0
       
   540 
       
   541             if lib.ZSTD_isError(zresult):
       
   542                 raise ZstdError('zstd compress error: %s' %
       
   543                                 _zstd_error(zresult))
       
   544 
       
   545             if self._out.pos == self._out.size:
       
   546                 yield ffi.buffer(self._out.dst, self._out.pos)[:]
       
   547                 self._out.pos = 0
       
   548 
       
   549     def flush(self):
       
   550         if self._finished:
       
   551             raise ZstdError('cannot call flush() after compression finished')
       
   552 
       
   553         if self._in.src != ffi.NULL:
       
   554             raise ZstdError('cannot call flush() before consuming output from '
       
   555                             'previous operation')
       
   556 
       
   557         while True:
       
   558             zresult = lib.ZSTD_compress_generic(self._compressor._cctx,
       
   559                                                 self._out, self._in,
       
   560                                                 lib.ZSTD_e_flush)
       
   561             if lib.ZSTD_isError(zresult):
       
   562                 raise ZstdError('zstd compress error: %s' % _zstd_error(zresult))
       
   563 
       
   564             if self._out.pos:
       
   565                 yield ffi.buffer(self._out.dst, self._out.pos)[:]
       
   566                 self._out.pos = 0
       
   567 
       
   568             if not zresult:
       
   569                 return
       
   570 
       
   571     def finish(self):
       
   572         if self._finished:
       
   573             raise ZstdError('cannot call finish() after compression finished')
       
   574 
       
   575         if self._in.src != ffi.NULL:
       
   576             raise ZstdError('cannot call finish() before consuming output from '
       
   577                             'previous operation')
       
   578 
       
   579         while True:
       
   580             zresult = lib.ZSTD_compress_generic(self._compressor._cctx,
       
   581                                                 self._out, self._in,
       
   582                                                 lib.ZSTD_e_end)
       
   583             if lib.ZSTD_isError(zresult):
       
   584                 raise ZstdError('zstd compress error: %s' % _zstd_error(zresult))
       
   585 
       
   586             if self._out.pos:
       
   587                 yield ffi.buffer(self._out.dst, self._out.pos)[:]
       
   588                 self._out.pos = 0
       
   589 
       
   590             if not zresult:
       
   591                 self._finished = True
       
   592                 return
       
   593 
       
   594 
       
   595 class CompressionReader(object):
       
   596     def __init__(self, compressor, source, read_size):
       
   597         self._compressor = compressor
       
   598         self._source = source
       
   599         self._read_size = read_size
       
   600         self._entered = False
       
   601         self._closed = False
       
   602         self._bytes_compressed = 0
       
   603         self._finished_input = False
       
   604         self._finished_output = False
       
   605 
       
   606         self._in_buffer = ffi.new('ZSTD_inBuffer *')
       
   607         # Holds a ref so backing bytes in self._in_buffer stay alive.
       
   608         self._source_buffer = None
       
   609 
       
   610     def __enter__(self):
       
   611         if self._entered:
       
   612             raise ValueError('cannot __enter__ multiple times')
       
   613 
       
   614         self._entered = True
       
   615         return self
       
   616 
       
   617     def __exit__(self, exc_type, exc_value, exc_tb):
       
   618         self._entered = False
       
   619         self._closed = True
       
   620         self._source = None
       
   621         self._compressor = None
       
   622 
       
   623         return False
       
   624 
       
   625     def readable(self):
       
   626         return True
       
   627 
       
   628     def writable(self):
       
   629         return False
       
   630 
       
   631     def seekable(self):
       
   632         return False
       
   633 
       
   634     def readline(self):
       
   635         raise io.UnsupportedOperation()
       
   636 
       
   637     def readlines(self):
       
   638         raise io.UnsupportedOperation()
       
   639 
       
   640     def write(self, data):
       
   641         raise OSError('stream is not writable')
       
   642 
       
   643     def writelines(self, ignored):
       
   644         raise OSError('stream is not writable')
       
   645 
       
   646     def isatty(self):
       
   647         return False
       
   648 
       
   649     def flush(self):
       
   650         return None
       
   651 
       
   652     def close(self):
       
   653         self._closed = True
       
   654         return None
       
   655 
       
   656     @property
       
   657     def closed(self):
       
   658         return self._closed
       
   659 
       
   660     def tell(self):
       
   661         return self._bytes_compressed
       
   662 
       
   663     def readall(self):
       
   664         raise NotImplementedError()
       
   665 
       
   666     def __iter__(self):
       
   667         raise io.UnsupportedOperation()
       
   668 
       
   669     def __next__(self):
       
   670         raise io.UnsupportedOperation()
       
   671 
       
   672     next = __next__
       
   673 
       
   674     def read(self, size=-1):
       
   675         if self._closed:
       
   676             raise ValueError('stream is closed')
       
   677 
       
   678         if self._finished_output:
       
   679             return b''
       
   680 
       
   681         if size < 1:
       
   682             raise ValueError('cannot read negative or size 0 amounts')
       
   683 
       
   684         # Need a dedicated ref to dest buffer otherwise it gets collected.
       
   685         dst_buffer = ffi.new('char[]', size)
       
   686         out_buffer = ffi.new('ZSTD_outBuffer *')
       
   687         out_buffer.dst = dst_buffer
       
   688         out_buffer.size = size
       
   689         out_buffer.pos = 0
       
   690 
       
   691         def compress_input():
       
   692             if self._in_buffer.pos >= self._in_buffer.size:
       
   693                 return
       
   694 
       
   695             old_pos = out_buffer.pos
       
   696 
       
   697             zresult = lib.ZSTD_compress_generic(self._compressor._cctx,
       
   698                                                 out_buffer, self._in_buffer,
       
   699                                                 lib.ZSTD_e_continue)
       
   700 
       
   701             self._bytes_compressed += out_buffer.pos - old_pos
       
   702 
       
   703             if self._in_buffer.pos == self._in_buffer.size:
       
   704                 self._in_buffer.src = ffi.NULL
       
   705                 self._in_buffer.pos = 0
       
   706                 self._in_buffer.size = 0
       
   707                 self._source_buffer = None
       
   708 
       
   709                 if not hasattr(self._source, 'read'):
       
   710                     self._finished_input = True
       
   711 
       
   712             if lib.ZSTD_isError(zresult):
       
   713                 raise ZstdError('zstd compress error: %s',
       
   714                                 _zstd_error(zresult))
       
   715 
       
   716             if out_buffer.pos and out_buffer.pos == out_buffer.size:
       
   717                 return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
       
   718 
       
   719         def get_input():
       
   720             if self._finished_input:
       
   721                 return
       
   722 
       
   723             if hasattr(self._source, 'read'):
       
   724                 data = self._source.read(self._read_size)
       
   725 
       
   726                 if not data:
       
   727                     self._finished_input = True
       
   728                     return
       
   729 
       
   730                 self._source_buffer = ffi.from_buffer(data)
       
   731                 self._in_buffer.src = self._source_buffer
       
   732                 self._in_buffer.size = len(self._source_buffer)
       
   733                 self._in_buffer.pos = 0
       
   734             else:
       
   735                 self._source_buffer = ffi.from_buffer(self._source)
       
   736                 self._in_buffer.src = self._source_buffer
       
   737                 self._in_buffer.size = len(self._source_buffer)
       
   738                 self._in_buffer.pos = 0
       
   739 
       
   740         result = compress_input()
       
   741         if result:
       
   742             return result
       
   743 
       
   744         while not self._finished_input:
       
   745             get_input()
       
   746             result = compress_input()
       
   747             if result:
       
   748                 return result
       
   749 
       
   750         # EOF
       
   751         old_pos = out_buffer.pos
       
   752 
       
   753         zresult = lib.ZSTD_compress_generic(self._compressor._cctx,
       
   754                                             out_buffer, self._in_buffer,
       
   755                                             lib.ZSTD_e_end)
       
   756 
       
   757         self._bytes_compressed += out_buffer.pos - old_pos
       
   758 
       
   759         if lib.ZSTD_isError(zresult):
       
   760             raise ZstdError('error ending compression stream: %s',
       
   761                             _zstd_error(zresult))
       
   762 
       
   763         if zresult == 0:
       
   764             self._finished_output = True
       
   765 
       
   766         return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
       
   767 
       
   768 class ZstdCompressor(object):
       
   769     def __init__(self, level=3, dict_data=None, compression_params=None,
       
   770                  write_checksum=None, write_content_size=None,
       
   771                  write_dict_id=None, threads=0):
       
   772         if level > lib.ZSTD_maxCLevel():
       
   773             raise ValueError('level must be less than %d' % lib.ZSTD_maxCLevel())
       
   774 
       
   775         if threads < 0:
       
   776             threads = _cpu_count()
       
   777 
       
   778         if compression_params and write_checksum is not None:
       
   779             raise ValueError('cannot define compression_params and '
       
   780                              'write_checksum')
       
   781 
       
   782         if compression_params and write_content_size is not None:
       
   783             raise ValueError('cannot define compression_params and '
       
   784                              'write_content_size')
       
   785 
       
   786         if compression_params and write_dict_id is not None:
       
   787             raise ValueError('cannot define compression_params and '
       
   788                              'write_dict_id')
       
   789 
       
   790         if compression_params and threads:
       
   791             raise ValueError('cannot define compression_params and threads')
       
   792 
       
   793         if compression_params:
       
   794             self._params = _make_cctx_params(compression_params)
       
   795         else:
       
   796             if write_dict_id is None:
       
   797                 write_dict_id = True
       
   798 
       
   799             params = lib.ZSTD_createCCtxParams()
       
   800             if params == ffi.NULL:
       
   801                 raise MemoryError()
       
   802 
       
   803             self._params = ffi.gc(params, lib.ZSTD_freeCCtxParams)
       
   804 
       
   805             _set_compression_parameter(self._params,
       
   806                                        lib.ZSTD_p_compressionLevel,
       
   807                                        level)
       
   808 
       
   809             _set_compression_parameter(
       
   810                 self._params,
       
   811                 lib.ZSTD_p_contentSizeFlag,
       
   812                 write_content_size if write_content_size is not None else 1)
       
   813 
       
   814             _set_compression_parameter(self._params,
       
   815                                        lib.ZSTD_p_checksumFlag,
       
   816                                        1 if write_checksum else 0)
       
   817 
       
   818             _set_compression_parameter(self._params,
       
   819                                        lib.ZSTD_p_dictIDFlag,
       
   820                                        1 if write_dict_id else 0)
       
   821 
       
   822             if threads:
       
   823                 _set_compression_parameter(self._params,
       
   824                                            lib.ZSTD_p_nbWorkers,
       
   825                                            threads)
       
   826 
       
   827         cctx = lib.ZSTD_createCCtx()
       
   828         if cctx == ffi.NULL:
       
   829             raise MemoryError()
       
   830 
       
   831         self._cctx = cctx
       
   832         self._dict_data = dict_data
       
   833 
       
   834         # We defer setting up garbage collection until after calling
       
   835         # _setup_cctx() to ensure the memory size estimate is more accurate.
       
   836         try:
       
   837             self._setup_cctx()
       
   838         finally:
       
   839             self._cctx = ffi.gc(cctx, lib.ZSTD_freeCCtx,
       
   840                                 size=lib.ZSTD_sizeof_CCtx(cctx))
       
   841 
       
   842     def _setup_cctx(self):
       
   843         zresult = lib.ZSTD_CCtx_setParametersUsingCCtxParams(self._cctx,
       
   844                                                              self._params)
       
   845         if lib.ZSTD_isError(zresult):
       
   846             raise ZstdError('could not set compression parameters: %s' %
       
   847                             _zstd_error(zresult))
       
   848 
       
   849         dict_data = self._dict_data
       
   850 
       
   851         if dict_data:
       
   852             if dict_data._cdict:
       
   853                 zresult = lib.ZSTD_CCtx_refCDict(self._cctx, dict_data._cdict)
       
   854             else:
       
   855                 zresult = lib.ZSTD_CCtx_loadDictionary_advanced(
       
   856                     self._cctx, dict_data.as_bytes(), len(dict_data),
       
   857                     lib.ZSTD_dlm_byRef, dict_data._dict_type)
       
   858 
       
   859             if lib.ZSTD_isError(zresult):
       
   860                 raise ZstdError('could not load compression dictionary: %s' %
       
   861                                 _zstd_error(zresult))
       
   862 
       
   863     def memory_size(self):
       
   864         return lib.ZSTD_sizeof_CCtx(self._cctx)
       
   865 
       
   866     def compress(self, data):
       
   867         lib.ZSTD_CCtx_reset(self._cctx)
       
   868 
       
   869         data_buffer = ffi.from_buffer(data)
       
   870 
       
   871         dest_size = lib.ZSTD_compressBound(len(data_buffer))
       
   872         out = new_nonzero('char[]', dest_size)
       
   873 
       
   874         zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, len(data_buffer))
       
   875         if lib.ZSTD_isError(zresult):
       
   876             raise ZstdError('error setting source size: %s' %
       
   877                             _zstd_error(zresult))
       
   878 
       
   879         out_buffer = ffi.new('ZSTD_outBuffer *')
       
   880         in_buffer = ffi.new('ZSTD_inBuffer *')
       
   881 
       
   882         out_buffer.dst = out
       
   883         out_buffer.size = dest_size
       
   884         out_buffer.pos = 0
       
   885 
       
   886         in_buffer.src = data_buffer
       
   887         in_buffer.size = len(data_buffer)
       
   888         in_buffer.pos = 0
       
   889 
       
   890         zresult = lib.ZSTD_compress_generic(self._cctx,
       
   891                                             out_buffer,
       
   892                                             in_buffer,
       
   893                                             lib.ZSTD_e_end)
       
   894 
       
   895         if lib.ZSTD_isError(zresult):
       
   896             raise ZstdError('cannot compress: %s' %
       
   897                             _zstd_error(zresult))
       
   898         elif zresult:
       
   899             raise ZstdError('unexpected partial frame flush')
       
   900 
       
   901         return ffi.buffer(out, out_buffer.pos)[:]
       
   902 
       
   903     def compressobj(self, size=-1):
       
   904         lib.ZSTD_CCtx_reset(self._cctx)
       
   905 
       
   906         if size < 0:
       
   907             size = lib.ZSTD_CONTENTSIZE_UNKNOWN
       
   908 
       
   909         zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size)
       
   910         if lib.ZSTD_isError(zresult):
       
   911             raise ZstdError('error setting source size: %s' %
       
   912                             _zstd_error(zresult))
       
   913 
       
   914         cobj = ZstdCompressionObj()
       
   915         cobj._out = ffi.new('ZSTD_outBuffer *')
       
   916         cobj._dst_buffer = ffi.new('char[]', COMPRESSION_RECOMMENDED_OUTPUT_SIZE)
       
   917         cobj._out.dst = cobj._dst_buffer
       
   918         cobj._out.size = COMPRESSION_RECOMMENDED_OUTPUT_SIZE
       
   919         cobj._out.pos = 0
       
   920         cobj._compressor = self
       
   921         cobj._finished = False
       
   922 
       
   923         return cobj
       
   924 
       
   925     def chunker(self, size=-1, chunk_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE):
       
   926         lib.ZSTD_CCtx_reset(self._cctx)
       
   927 
       
   928         if size < 0:
       
   929             size = lib.ZSTD_CONTENTSIZE_UNKNOWN
       
   930 
       
   931         zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size)
       
   932         if lib.ZSTD_isError(zresult):
       
   933             raise ZstdError('error setting source size: %s' %
       
   934                             _zstd_error(zresult))
       
   935 
       
   936         return ZstdCompressionChunker(self, chunk_size=chunk_size)
       
   937 
       
   938     def copy_stream(self, ifh, ofh, size=-1,
       
   939                     read_size=COMPRESSION_RECOMMENDED_INPUT_SIZE,
       
   940                     write_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE):
       
   941 
       
   942         if not hasattr(ifh, 'read'):
       
   943             raise ValueError('first argument must have a read() method')
       
   944         if not hasattr(ofh, 'write'):
       
   945             raise ValueError('second argument must have a write() method')
       
   946 
       
   947         lib.ZSTD_CCtx_reset(self._cctx)
       
   948 
       
   949         if size < 0:
       
   950             size = lib.ZSTD_CONTENTSIZE_UNKNOWN
       
   951 
       
   952         zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size)
       
   953         if lib.ZSTD_isError(zresult):
       
   954             raise ZstdError('error setting source size: %s' %
       
   955                             _zstd_error(zresult))
       
   956 
       
   957         in_buffer = ffi.new('ZSTD_inBuffer *')
       
   958         out_buffer = ffi.new('ZSTD_outBuffer *')
       
   959 
       
   960         dst_buffer = ffi.new('char[]', write_size)
       
   961         out_buffer.dst = dst_buffer
       
   962         out_buffer.size = write_size
       
   963         out_buffer.pos = 0
       
   964 
       
   965         total_read, total_write = 0, 0
       
   966 
       
   967         while True:
       
   968             data = ifh.read(read_size)
       
   969             if not data:
       
   970                 break
       
   971 
       
   972             data_buffer = ffi.from_buffer(data)
       
   973             total_read += len(data_buffer)
       
   974             in_buffer.src = data_buffer
       
   975             in_buffer.size = len(data_buffer)
       
   976             in_buffer.pos = 0
       
   977 
       
   978             while in_buffer.pos < in_buffer.size:
       
   979                 zresult = lib.ZSTD_compress_generic(self._cctx,
       
   980                                                     out_buffer,
       
   981                                                     in_buffer,
       
   982                                                     lib.ZSTD_e_continue)
       
   983                 if lib.ZSTD_isError(zresult):
       
   984                     raise ZstdError('zstd compress error: %s' %
       
   985                                     _zstd_error(zresult))
       
   986 
       
   987                 if out_buffer.pos:
       
   988                     ofh.write(ffi.buffer(out_buffer.dst, out_buffer.pos))
       
   989                     total_write += out_buffer.pos
       
   990                     out_buffer.pos = 0
       
   991 
       
   992         # We've finished reading. Flush the compressor.
       
   993         while True:
       
   994             zresult = lib.ZSTD_compress_generic(self._cctx,
       
   995                                                 out_buffer,
       
   996                                                 in_buffer,
       
   997                                                 lib.ZSTD_e_end)
       
   998             if lib.ZSTD_isError(zresult):
       
   999                 raise ZstdError('error ending compression stream: %s' %
       
  1000                                 _zstd_error(zresult))
       
  1001 
       
  1002             if out_buffer.pos:
       
  1003                 ofh.write(ffi.buffer(out_buffer.dst, out_buffer.pos))
       
  1004                 total_write += out_buffer.pos
       
  1005                 out_buffer.pos = 0
       
  1006 
       
  1007             if zresult == 0:
       
  1008                 break
       
  1009 
       
  1010         return total_read, total_write
       
  1011 
       
  1012     def stream_reader(self, source, size=-1,
       
  1013                       read_size=COMPRESSION_RECOMMENDED_INPUT_SIZE):
       
  1014         lib.ZSTD_CCtx_reset(self._cctx)
       
  1015 
       
  1016         try:
       
  1017             size = len(source)
       
  1018         except Exception:
       
  1019             pass
       
  1020 
       
  1021         if size < 0:
       
  1022             size = lib.ZSTD_CONTENTSIZE_UNKNOWN
       
  1023 
       
  1024         zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size)
       
  1025         if lib.ZSTD_isError(zresult):
       
  1026             raise ZstdError('error setting source size: %s' %
       
  1027                             _zstd_error(zresult))
       
  1028 
       
  1029         return CompressionReader(self, source, read_size)
       
  1030 
       
  1031     def stream_writer(self, writer, size=-1,
       
  1032                  write_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE):
       
  1033 
       
  1034         if not hasattr(writer, 'write'):
       
  1035             raise ValueError('must pass an object with a write() method')
       
  1036 
       
  1037         lib.ZSTD_CCtx_reset(self._cctx)
       
  1038 
       
  1039         if size < 0:
       
  1040             size = lib.ZSTD_CONTENTSIZE_UNKNOWN
       
  1041 
       
  1042         return ZstdCompressionWriter(self, writer, size, write_size)
       
  1043 
       
  1044     write_to = stream_writer
       
  1045 
       
  1046     def read_to_iter(self, reader, size=-1,
       
  1047                      read_size=COMPRESSION_RECOMMENDED_INPUT_SIZE,
       
  1048                      write_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE):
       
  1049         if hasattr(reader, 'read'):
       
  1050             have_read = True
       
  1051         elif hasattr(reader, '__getitem__'):
       
  1052             have_read = False
       
  1053             buffer_offset = 0
       
  1054             size = len(reader)
       
  1055         else:
       
  1056             raise ValueError('must pass an object with a read() method or '
       
  1057                              'conforms to buffer protocol')
       
  1058 
       
  1059         lib.ZSTD_CCtx_reset(self._cctx)
       
  1060 
       
  1061         if size < 0:
       
  1062             size = lib.ZSTD_CONTENTSIZE_UNKNOWN
       
  1063 
       
  1064         zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size)
       
  1065         if lib.ZSTD_isError(zresult):
       
  1066             raise ZstdError('error setting source size: %s' %
       
  1067                             _zstd_error(zresult))
       
  1068 
       
  1069         in_buffer = ffi.new('ZSTD_inBuffer *')
       
  1070         out_buffer = ffi.new('ZSTD_outBuffer *')
       
  1071 
       
  1072         in_buffer.src = ffi.NULL
       
  1073         in_buffer.size = 0
       
  1074         in_buffer.pos = 0
       
  1075 
       
  1076         dst_buffer = ffi.new('char[]', write_size)
       
  1077         out_buffer.dst = dst_buffer
       
  1078         out_buffer.size = write_size
       
  1079         out_buffer.pos = 0
       
  1080 
       
  1081         while True:
       
  1082             # We should never have output data sitting around after a previous
       
  1083             # iteration.
       
  1084             assert out_buffer.pos == 0
       
  1085 
       
  1086             # Collect input data.
       
  1087             if have_read:
       
  1088                 read_result = reader.read(read_size)
       
  1089             else:
       
  1090                 remaining = len(reader) - buffer_offset
       
  1091                 slice_size = min(remaining, read_size)
       
  1092                 read_result = reader[buffer_offset:buffer_offset + slice_size]
       
  1093                 buffer_offset += slice_size
       
  1094 
       
  1095             # No new input data. Break out of the read loop.
       
  1096             if not read_result:
       
  1097                 break
       
  1098 
       
  1099             # Feed all read data into the compressor and emit output until
       
  1100             # exhausted.
       
  1101             read_buffer = ffi.from_buffer(read_result)
       
  1102             in_buffer.src = read_buffer
       
  1103             in_buffer.size = len(read_buffer)
       
  1104             in_buffer.pos = 0
       
  1105 
       
  1106             while in_buffer.pos < in_buffer.size:
       
  1107                 zresult = lib.ZSTD_compress_generic(self._cctx, out_buffer, in_buffer,
       
  1108                                                     lib.ZSTD_e_continue)
       
  1109                 if lib.ZSTD_isError(zresult):
       
  1110                     raise ZstdError('zstd compress error: %s' %
       
  1111                                     _zstd_error(zresult))
       
  1112 
       
  1113                 if out_buffer.pos:
       
  1114                     data = ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
       
  1115                     out_buffer.pos = 0
       
  1116                     yield data
       
  1117 
       
  1118             assert out_buffer.pos == 0
       
  1119 
       
  1120             # And repeat the loop to collect more data.
       
  1121             continue
       
  1122 
       
  1123         # If we get here, input is exhausted. End the stream and emit what
       
  1124         # remains.
       
  1125         while True:
       
  1126             assert out_buffer.pos == 0
       
  1127             zresult = lib.ZSTD_compress_generic(self._cctx,
       
  1128                                                 out_buffer,
       
  1129                                                 in_buffer,
       
  1130                                                 lib.ZSTD_e_end)
       
  1131             if lib.ZSTD_isError(zresult):
       
  1132                 raise ZstdError('error ending compression stream: %s' %
       
  1133                                 _zstd_error(zresult))
       
  1134 
       
  1135             if out_buffer.pos:
       
  1136                 data = ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
       
  1137                 out_buffer.pos = 0
       
  1138                 yield data
       
  1139 
       
  1140             if zresult == 0:
       
  1141                 break
       
  1142 
       
  1143     read_from = read_to_iter
       
  1144 
       
  1145     def frame_progression(self):
       
  1146         progression = lib.ZSTD_getFrameProgression(self._cctx)
       
  1147 
       
  1148         return progression.ingested, progression.consumed, progression.produced
       
  1149 
       
  1150 
       
  1151 class FrameParameters(object):
       
  1152     def __init__(self, fparams):
       
  1153         self.content_size = fparams.frameContentSize
       
  1154         self.window_size = fparams.windowSize
       
  1155         self.dict_id = fparams.dictID
       
  1156         self.has_checksum = bool(fparams.checksumFlag)
       
  1157 
       
  1158 
       
  1159 def frame_content_size(data):
       
  1160     data_buffer = ffi.from_buffer(data)
       
  1161 
       
  1162     size = lib.ZSTD_getFrameContentSize(data_buffer, len(data_buffer))
       
  1163 
       
  1164     if size == lib.ZSTD_CONTENTSIZE_ERROR:
       
  1165         raise ZstdError('error when determining content size')
       
  1166     elif size == lib.ZSTD_CONTENTSIZE_UNKNOWN:
       
  1167         return -1
       
  1168     else:
       
  1169         return size
       
  1170 
       
  1171 
       
  1172 def frame_header_size(data):
       
  1173     data_buffer = ffi.from_buffer(data)
       
  1174 
       
  1175     zresult = lib.ZSTD_frameHeaderSize(data_buffer, len(data_buffer))
       
  1176     if lib.ZSTD_isError(zresult):
       
  1177         raise ZstdError('could not determine frame header size: %s' %
       
  1178                         _zstd_error(zresult))
       
  1179 
       
  1180     return zresult
       
  1181 
       
  1182 
       
  1183 def get_frame_parameters(data):
       
  1184     params = ffi.new('ZSTD_frameHeader *')
       
  1185 
       
  1186     data_buffer = ffi.from_buffer(data)
       
  1187     zresult = lib.ZSTD_getFrameHeader(params, data_buffer, len(data_buffer))
       
  1188     if lib.ZSTD_isError(zresult):
       
  1189         raise ZstdError('cannot get frame parameters: %s' %
       
  1190                         _zstd_error(zresult))
       
  1191 
       
  1192     if zresult:
       
  1193         raise ZstdError('not enough data for frame parameters; need %d bytes' %
       
  1194                         zresult)
       
  1195 
       
  1196     return FrameParameters(params[0])
       
  1197 
       
  1198 
       
  1199 class ZstdCompressionDict(object):
       
  1200     def __init__(self, data, dict_type=DICT_TYPE_AUTO, k=0, d=0):
       
  1201         assert isinstance(data, bytes_type)
       
  1202         self._data = data
       
  1203         self.k = k
       
  1204         self.d = d
       
  1205 
       
  1206         if dict_type not in (DICT_TYPE_AUTO, DICT_TYPE_RAWCONTENT,
       
  1207                              DICT_TYPE_FULLDICT):
       
  1208             raise ValueError('invalid dictionary load mode: %d; must use '
       
  1209                              'DICT_TYPE_* constants')
       
  1210 
       
  1211         self._dict_type = dict_type
       
  1212         self._cdict = None
       
  1213 
       
  1214     def __len__(self):
       
  1215         return len(self._data)
       
  1216 
       
  1217     def dict_id(self):
       
  1218         return int_type(lib.ZDICT_getDictID(self._data, len(self._data)))
       
  1219 
       
  1220     def as_bytes(self):
       
  1221         return self._data
       
  1222 
       
  1223     def precompute_compress(self, level=0, compression_params=None):
       
  1224         if level and compression_params:
       
  1225             raise ValueError('must only specify one of level or '
       
  1226                              'compression_params')
       
  1227 
       
  1228         if not level and not compression_params:
       
  1229             raise ValueError('must specify one of level or compression_params')
       
  1230 
       
  1231         if level:
       
  1232             cparams = lib.ZSTD_getCParams(level, 0, len(self._data))
       
  1233         else:
       
  1234             cparams = ffi.new('ZSTD_compressionParameters')
       
  1235             cparams.chainLog = compression_params.chain_log
       
  1236             cparams.hashLog = compression_params.hash_log
       
  1237             cparams.searchLength = compression_params.min_match
       
  1238             cparams.searchLog = compression_params.search_log
       
  1239             cparams.strategy = compression_params.compression_strategy
       
  1240             cparams.targetLength = compression_params.target_length
       
  1241             cparams.windowLog = compression_params.window_log
       
  1242 
       
  1243         cdict = lib.ZSTD_createCDict_advanced(self._data, len(self._data),
       
  1244                                               lib.ZSTD_dlm_byRef,
       
  1245                                               self._dict_type,
       
  1246                                               cparams,
       
  1247                                               lib.ZSTD_defaultCMem)
       
  1248         if cdict == ffi.NULL:
       
  1249             raise ZstdError('unable to precompute dictionary')
       
  1250 
       
  1251         self._cdict = ffi.gc(cdict, lib.ZSTD_freeCDict,
       
  1252                              size=lib.ZSTD_sizeof_CDict(cdict))
       
  1253 
       
  1254     @property
       
  1255     def _ddict(self):
       
  1256         ddict = lib.ZSTD_createDDict_advanced(self._data, len(self._data),
       
  1257                                               lib.ZSTD_dlm_byRef,
       
  1258                                               self._dict_type,
       
  1259                                               lib.ZSTD_defaultCMem)
       
  1260 
       
  1261         if ddict == ffi.NULL:
       
  1262             raise ZstdError('could not create decompression dict')
       
  1263 
       
  1264         ddict = ffi.gc(ddict, lib.ZSTD_freeDDict,
       
  1265                        size=lib.ZSTD_sizeof_DDict(ddict))
       
  1266         self.__dict__['_ddict'] = ddict
       
  1267 
       
  1268         return ddict
       
  1269 
       
  1270 def train_dictionary(dict_size, samples, k=0, d=0, notifications=0, dict_id=0,
       
  1271                      level=0, steps=0, threads=0):
       
  1272     if not isinstance(samples, list):
       
  1273         raise TypeError('samples must be a list')
       
  1274 
       
  1275     if threads < 0:
       
  1276         threads = _cpu_count()
       
  1277 
       
  1278     total_size = sum(map(len, samples))
       
  1279 
       
  1280     samples_buffer = new_nonzero('char[]', total_size)
       
  1281     sample_sizes = new_nonzero('size_t[]', len(samples))
       
  1282 
       
  1283     offset = 0
       
  1284     for i, sample in enumerate(samples):
       
  1285         if not isinstance(sample, bytes_type):
       
  1286             raise ValueError('samples must be bytes')
       
  1287 
       
  1288         l = len(sample)
       
  1289         ffi.memmove(samples_buffer + offset, sample, l)
       
  1290         offset += l
       
  1291         sample_sizes[i] = l
       
  1292 
       
  1293     dict_data = new_nonzero('char[]', dict_size)
       
  1294 
       
  1295     dparams = ffi.new('ZDICT_cover_params_t *')[0]
       
  1296     dparams.k = k
       
  1297     dparams.d = d
       
  1298     dparams.steps = steps
       
  1299     dparams.nbThreads = threads
       
  1300     dparams.zParams.notificationLevel = notifications
       
  1301     dparams.zParams.dictID = dict_id
       
  1302     dparams.zParams.compressionLevel = level
       
  1303 
       
  1304     if (not dparams.k and not dparams.d and not dparams.steps
       
  1305         and not dparams.nbThreads and not dparams.zParams.notificationLevel
       
  1306         and not dparams.zParams.dictID
       
  1307         and not dparams.zParams.compressionLevel):
       
  1308         zresult = lib.ZDICT_trainFromBuffer(
       
  1309             ffi.addressof(dict_data), dict_size,
       
  1310             ffi.addressof(samples_buffer),
       
  1311             ffi.addressof(sample_sizes, 0), len(samples))
       
  1312     elif dparams.steps or dparams.nbThreads:
       
  1313         zresult = lib.ZDICT_optimizeTrainFromBuffer_cover(
       
  1314             ffi.addressof(dict_data), dict_size,
       
  1315             ffi.addressof(samples_buffer),
       
  1316             ffi.addressof(sample_sizes, 0), len(samples),
       
  1317             ffi.addressof(dparams))
       
  1318     else:
       
  1319         zresult = lib.ZDICT_trainFromBuffer_cover(
       
  1320             ffi.addressof(dict_data), dict_size,
       
  1321             ffi.addressof(samples_buffer),
       
  1322             ffi.addressof(sample_sizes, 0), len(samples),
       
  1323             dparams)
       
  1324 
       
  1325     if lib.ZDICT_isError(zresult):
       
  1326         msg = ffi.string(lib.ZDICT_getErrorName(zresult)).decode('utf-8')
       
  1327         raise ZstdError('cannot train dict: %s' % msg)
       
  1328 
       
  1329     return ZstdCompressionDict(ffi.buffer(dict_data, zresult)[:],
       
  1330                                dict_type=DICT_TYPE_FULLDICT,
       
  1331                                k=dparams.k, d=dparams.d)
       
  1332 
       
  1333 
       
  1334 class ZstdDecompressionObj(object):
       
  1335     def __init__(self, decompressor, write_size):
       
  1336         self._decompressor = decompressor
       
  1337         self._write_size = write_size
       
  1338         self._finished = False
       
  1339 
       
  1340     def decompress(self, data):
       
  1341         if self._finished:
       
  1342             raise ZstdError('cannot use a decompressobj multiple times')
       
  1343 
       
  1344         in_buffer = ffi.new('ZSTD_inBuffer *')
       
  1345         out_buffer = ffi.new('ZSTD_outBuffer *')
       
  1346 
       
  1347         data_buffer = ffi.from_buffer(data)
       
  1348         in_buffer.src = data_buffer
       
  1349         in_buffer.size = len(data_buffer)
       
  1350         in_buffer.pos = 0
       
  1351 
       
  1352         dst_buffer = ffi.new('char[]', self._write_size)
       
  1353         out_buffer.dst = dst_buffer
       
  1354         out_buffer.size = len(dst_buffer)
       
  1355         out_buffer.pos = 0
       
  1356 
       
  1357         chunks = []
       
  1358 
       
  1359         while True:
       
  1360             zresult = lib.ZSTD_decompress_generic(self._decompressor._dctx,
       
  1361                                                   out_buffer, in_buffer)
       
  1362             if lib.ZSTD_isError(zresult):
       
  1363                 raise ZstdError('zstd decompressor error: %s' %
       
  1364                                 _zstd_error(zresult))
       
  1365 
       
  1366             if zresult == 0:
       
  1367                 self._finished = True
       
  1368                 self._decompressor = None
       
  1369 
       
  1370             if out_buffer.pos:
       
  1371                 chunks.append(ffi.buffer(out_buffer.dst, out_buffer.pos)[:])
       
  1372 
       
  1373             if (zresult == 0 or
       
  1374                     (in_buffer.pos == in_buffer.size and out_buffer.pos == 0)):
       
  1375                 break
       
  1376 
       
  1377             out_buffer.pos = 0
       
  1378 
       
  1379         return b''.join(chunks)
       
  1380 
       
  1381 
       
  1382 class DecompressionReader(object):
       
  1383     def __init__(self, decompressor, source, read_size):
       
  1384         self._decompressor = decompressor
       
  1385         self._source = source
       
  1386         self._read_size = read_size
       
  1387         self._entered = False
       
  1388         self._closed = False
       
  1389         self._bytes_decompressed = 0
       
  1390         self._finished_input = False
       
  1391         self._finished_output = False
       
  1392         self._in_buffer = ffi.new('ZSTD_inBuffer *')
       
  1393         # Holds a ref to self._in_buffer.src.
       
  1394         self._source_buffer = None
       
  1395 
       
  1396     def __enter__(self):
       
  1397         if self._entered:
       
  1398             raise ValueError('cannot __enter__ multiple times')
       
  1399 
       
  1400         self._entered = True
       
  1401         return self
       
  1402 
       
  1403     def __exit__(self, exc_type, exc_value, exc_tb):
       
  1404         self._entered = False
       
  1405         self._closed = True
       
  1406         self._source = None
       
  1407         self._decompressor = None
       
  1408 
       
  1409         return False
       
  1410 
       
  1411     def readable(self):
       
  1412         return True
       
  1413 
       
  1414     def writable(self):
       
  1415         return False
       
  1416 
       
  1417     def seekable(self):
       
  1418         return True
       
  1419 
       
  1420     def readline(self):
       
  1421         raise NotImplementedError()
       
  1422 
       
  1423     def readlines(self):
       
  1424         raise NotImplementedError()
       
  1425 
       
  1426     def write(self, data):
       
  1427         raise io.UnsupportedOperation()
       
  1428 
       
  1429     def writelines(self, lines):
       
  1430         raise io.UnsupportedOperation()
       
  1431 
       
  1432     def isatty(self):
       
  1433         return False
       
  1434 
       
  1435     def flush(self):
       
  1436         return None
       
  1437 
       
  1438     def close(self):
       
  1439         self._closed = True
       
  1440         return None
       
  1441 
       
  1442     @property
       
  1443     def closed(self):
       
  1444         return self._closed
       
  1445 
       
  1446     def tell(self):
       
  1447         return self._bytes_decompressed
       
  1448 
       
  1449     def readall(self):
       
  1450         raise NotImplementedError()
       
  1451 
       
  1452     def __iter__(self):
       
  1453         raise NotImplementedError()
       
  1454 
       
  1455     def __next__(self):
       
  1456         raise NotImplementedError()
       
  1457 
       
  1458     next = __next__
       
  1459 
       
  1460     def read(self, size):
       
  1461         if self._closed:
       
  1462             raise ValueError('stream is closed')
       
  1463 
       
  1464         if self._finished_output:
       
  1465             return b''
       
  1466 
       
  1467         if size < 1:
       
  1468             raise ValueError('cannot read negative or size 0 amounts')
       
  1469 
       
  1470         dst_buffer = ffi.new('char[]', size)
       
  1471         out_buffer = ffi.new('ZSTD_outBuffer *')
       
  1472         out_buffer.dst = dst_buffer
       
  1473         out_buffer.size = size
       
  1474         out_buffer.pos = 0
       
  1475 
       
  1476         def decompress():
       
  1477             zresult = lib.ZSTD_decompress_generic(self._decompressor._dctx,
       
  1478                                                   out_buffer, self._in_buffer)
       
  1479 
       
  1480             if self._in_buffer.pos == self._in_buffer.size:
       
  1481                 self._in_buffer.src = ffi.NULL
       
  1482                 self._in_buffer.pos = 0
       
  1483                 self._in_buffer.size = 0
       
  1484                 self._source_buffer = None
       
  1485 
       
  1486                 if not hasattr(self._source, 'read'):
       
  1487                     self._finished_input = True
       
  1488 
       
  1489             if lib.ZSTD_isError(zresult):
       
  1490                 raise ZstdError('zstd decompress error: %s',
       
  1491                                 _zstd_error(zresult))
       
  1492             elif zresult == 0:
       
  1493                 self._finished_output = True
       
  1494 
       
  1495             if out_buffer.pos and out_buffer.pos == out_buffer.size:
       
  1496                 self._bytes_decompressed += out_buffer.size
       
  1497                 return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
       
  1498 
       
  1499         def get_input():
       
  1500             if self._finished_input:
       
  1501                 return
       
  1502 
       
  1503             if hasattr(self._source, 'read'):
       
  1504                 data = self._source.read(self._read_size)
       
  1505 
       
  1506                 if not data:
       
  1507                     self._finished_input = True
       
  1508                     return
       
  1509 
       
  1510                 self._source_buffer = ffi.from_buffer(data)
       
  1511                 self._in_buffer.src = self._source_buffer
       
  1512                 self._in_buffer.size = len(self._source_buffer)
       
  1513                 self._in_buffer.pos = 0
       
  1514             else:
       
  1515                 self._source_buffer = ffi.from_buffer(self._source)
       
  1516                 self._in_buffer.src = self._source_buffer
       
  1517                 self._in_buffer.size = len(self._source_buffer)
       
  1518                 self._in_buffer.pos = 0
       
  1519 
       
  1520         get_input()
       
  1521         result = decompress()
       
  1522         if result:
       
  1523             return result
       
  1524 
       
  1525         while not self._finished_input:
       
  1526             get_input()
       
  1527             result = decompress()
       
  1528             if result:
       
  1529                 return result
       
  1530 
       
  1531         self._bytes_decompressed += out_buffer.pos
       
  1532         return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
       
  1533 
       
  1534     def seek(self, pos, whence=os.SEEK_SET):
       
  1535         if self._closed:
       
  1536             raise ValueError('stream is closed')
       
  1537 
       
  1538         read_amount = 0
       
  1539 
       
  1540         if whence == os.SEEK_SET:
       
  1541             if pos < 0:
       
  1542                 raise ValueError('cannot seek to negative position with SEEK_SET')
       
  1543 
       
  1544             if pos < self._bytes_decompressed:
       
  1545                 raise ValueError('cannot seek zstd decompression stream '
       
  1546                                  'backwards')
       
  1547 
       
  1548             read_amount = pos - self._bytes_decompressed
       
  1549 
       
  1550         elif whence == os.SEEK_CUR:
       
  1551             if pos < 0:
       
  1552                 raise ValueError('cannot seek zstd decompression stream '
       
  1553                                  'backwards')
       
  1554 
       
  1555             read_amount = pos
       
  1556         elif whence == os.SEEK_END:
       
  1557             raise ValueError('zstd decompression streams cannot be seeked '
       
  1558                              'with SEEK_END')
       
  1559 
       
  1560         while read_amount:
       
  1561             result = self.read(min(read_amount,
       
  1562                                    DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE))
       
  1563 
       
  1564             if not result:
       
  1565                 break
       
  1566 
       
  1567             read_amount -= len(result)
       
  1568 
       
  1569         return self._bytes_decompressed
       
  1570 
       
  1571 class ZstdDecompressionWriter(object):
       
  1572     def __init__(self, decompressor, writer, write_size):
       
  1573         self._decompressor = decompressor
       
  1574         self._writer = writer
       
  1575         self._write_size = write_size
       
  1576         self._entered = False
       
  1577 
       
  1578     def __enter__(self):
       
  1579         if self._entered:
       
  1580             raise ZstdError('cannot __enter__ multiple times')
       
  1581 
       
  1582         self._decompressor._ensure_dctx()
       
  1583         self._entered = True
       
  1584 
       
  1585         return self
       
  1586 
       
  1587     def __exit__(self, exc_type, exc_value, exc_tb):
       
  1588         self._entered = False
       
  1589 
       
  1590     def memory_size(self):
       
  1591         if not self._decompressor._dctx:
       
  1592             raise ZstdError('cannot determine size of inactive decompressor '
       
  1593                             'call when context manager is active')
       
  1594 
       
  1595         return lib.ZSTD_sizeof_DCtx(self._decompressor._dctx)
       
  1596 
       
  1597     def write(self, data):
       
  1598         if not self._entered:
       
  1599             raise ZstdError('write must be called from an active context manager')
       
  1600 
       
  1601         total_write = 0
       
  1602 
       
  1603         in_buffer = ffi.new('ZSTD_inBuffer *')
       
  1604         out_buffer = ffi.new('ZSTD_outBuffer *')
       
  1605 
       
  1606         data_buffer = ffi.from_buffer(data)
       
  1607         in_buffer.src = data_buffer
       
  1608         in_buffer.size = len(data_buffer)
       
  1609         in_buffer.pos = 0
       
  1610 
       
  1611         dst_buffer = ffi.new('char[]', self._write_size)
       
  1612         out_buffer.dst = dst_buffer
       
  1613         out_buffer.size = len(dst_buffer)
       
  1614         out_buffer.pos = 0
       
  1615 
       
  1616         dctx = self._decompressor._dctx
       
  1617 
       
  1618         while in_buffer.pos < in_buffer.size:
       
  1619             zresult = lib.ZSTD_decompress_generic(dctx, out_buffer, in_buffer)
       
  1620             if lib.ZSTD_isError(zresult):
       
  1621                 raise ZstdError('zstd decompress error: %s' %
       
  1622                                 _zstd_error(zresult))
       
  1623 
       
  1624             if out_buffer.pos:
       
  1625                 self._writer.write(ffi.buffer(out_buffer.dst, out_buffer.pos)[:])
       
  1626                 total_write += out_buffer.pos
       
  1627                 out_buffer.pos = 0
       
  1628 
       
  1629         return total_write
       
  1630 
       
  1631 
       
  1632 class ZstdDecompressor(object):
       
  1633     def __init__(self, dict_data=None, max_window_size=0, format=FORMAT_ZSTD1):
       
  1634         self._dict_data = dict_data
       
  1635         self._max_window_size = max_window_size
       
  1636         self._format = format
       
  1637 
       
  1638         dctx = lib.ZSTD_createDCtx()
       
  1639         if dctx == ffi.NULL:
       
  1640             raise MemoryError()
       
  1641 
       
  1642         self._dctx = dctx
       
  1643 
       
  1644         # Defer setting up garbage collection until full state is loaded so
       
  1645         # the memory size is more accurate.
       
  1646         try:
       
  1647             self._ensure_dctx()
       
  1648         finally:
       
  1649             self._dctx = ffi.gc(dctx, lib.ZSTD_freeDCtx,
       
  1650                                 size=lib.ZSTD_sizeof_DCtx(dctx))
       
  1651 
       
  1652     def memory_size(self):
       
  1653         return lib.ZSTD_sizeof_DCtx(self._dctx)
       
  1654 
       
  1655     def decompress(self, data, max_output_size=0):
       
  1656         self._ensure_dctx()
       
  1657 
       
  1658         data_buffer = ffi.from_buffer(data)
       
  1659 
       
  1660         output_size = lib.ZSTD_getFrameContentSize(data_buffer, len(data_buffer))
       
  1661 
       
  1662         if output_size == lib.ZSTD_CONTENTSIZE_ERROR:
       
  1663             raise ZstdError('error determining content size from frame header')
       
  1664         elif output_size == 0:
       
  1665             return b''
       
  1666         elif output_size == lib.ZSTD_CONTENTSIZE_UNKNOWN:
       
  1667             if not max_output_size:
       
  1668                 raise ZstdError('could not determine content size in frame header')
       
  1669 
       
  1670             result_buffer = ffi.new('char[]', max_output_size)
       
  1671             result_size = max_output_size
       
  1672             output_size = 0
       
  1673         else:
       
  1674             result_buffer = ffi.new('char[]', output_size)
       
  1675             result_size = output_size
       
  1676 
       
  1677         out_buffer = ffi.new('ZSTD_outBuffer *')
       
  1678         out_buffer.dst = result_buffer
       
  1679         out_buffer.size = result_size
       
  1680         out_buffer.pos = 0
       
  1681 
       
  1682         in_buffer = ffi.new('ZSTD_inBuffer *')
       
  1683         in_buffer.src = data_buffer
       
  1684         in_buffer.size = len(data_buffer)
       
  1685         in_buffer.pos = 0
       
  1686 
       
  1687         zresult = lib.ZSTD_decompress_generic(self._dctx, out_buffer, in_buffer)
       
  1688         if lib.ZSTD_isError(zresult):
       
  1689             raise ZstdError('decompression error: %s' %
       
  1690                             _zstd_error(zresult))
       
  1691         elif zresult:
       
  1692             raise ZstdError('decompression error: did not decompress full frame')
       
  1693         elif output_size and out_buffer.pos != output_size:
       
  1694             raise ZstdError('decompression error: decompressed %d bytes; expected %d' %
       
  1695                             (zresult, output_size))
       
  1696 
       
  1697         return ffi.buffer(result_buffer, out_buffer.pos)[:]
       
  1698 
       
  1699     def stream_reader(self, source, read_size=DECOMPRESSION_RECOMMENDED_INPUT_SIZE):
       
  1700         self._ensure_dctx()
       
  1701         return DecompressionReader(self, source, read_size)
       
  1702 
       
  1703     def decompressobj(self, write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE):
       
  1704         if write_size < 1:
       
  1705             raise ValueError('write_size must be positive')
       
  1706 
       
  1707         self._ensure_dctx()
       
  1708         return ZstdDecompressionObj(self, write_size=write_size)
       
  1709 
       
  1710     def read_to_iter(self, reader, read_size=DECOMPRESSION_RECOMMENDED_INPUT_SIZE,
       
  1711                      write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE,
       
  1712                      skip_bytes=0):
       
  1713         if skip_bytes >= read_size:
       
  1714             raise ValueError('skip_bytes must be smaller than read_size')
       
  1715 
       
  1716         if hasattr(reader, 'read'):
       
  1717             have_read = True
       
  1718         elif hasattr(reader, '__getitem__'):
       
  1719             have_read = False
       
  1720             buffer_offset = 0
       
  1721             size = len(reader)
       
  1722         else:
       
  1723             raise ValueError('must pass an object with a read() method or '
       
  1724                              'conforms to buffer protocol')
       
  1725 
       
  1726         if skip_bytes:
       
  1727             if have_read:
       
  1728                 reader.read(skip_bytes)
       
  1729             else:
       
  1730                 if skip_bytes > size:
       
  1731                     raise ValueError('skip_bytes larger than first input chunk')
       
  1732 
       
  1733                 buffer_offset = skip_bytes
       
  1734 
       
  1735         self._ensure_dctx()
       
  1736 
       
  1737         in_buffer = ffi.new('ZSTD_inBuffer *')
       
  1738         out_buffer = ffi.new('ZSTD_outBuffer *')
       
  1739 
       
  1740         dst_buffer = ffi.new('char[]', write_size)
       
  1741         out_buffer.dst = dst_buffer
       
  1742         out_buffer.size = len(dst_buffer)
       
  1743         out_buffer.pos = 0
       
  1744 
       
  1745         while True:
       
  1746             assert out_buffer.pos == 0
       
  1747 
       
  1748             if have_read:
       
  1749                 read_result = reader.read(read_size)
       
  1750             else:
       
  1751                 remaining = size - buffer_offset
       
  1752                 slice_size = min(remaining, read_size)
       
  1753                 read_result = reader[buffer_offset:buffer_offset + slice_size]
       
  1754                 buffer_offset += slice_size
       
  1755 
       
  1756             # No new input. Break out of read loop.
       
  1757             if not read_result:
       
  1758                 break
       
  1759 
       
  1760             # Feed all read data into decompressor and emit output until
       
  1761             # exhausted.
       
  1762             read_buffer = ffi.from_buffer(read_result)
       
  1763             in_buffer.src = read_buffer
       
  1764             in_buffer.size = len(read_buffer)
       
  1765             in_buffer.pos = 0
       
  1766 
       
  1767             while in_buffer.pos < in_buffer.size:
       
  1768                 assert out_buffer.pos == 0
       
  1769 
       
  1770                 zresult = lib.ZSTD_decompress_generic(self._dctx, out_buffer, in_buffer)
       
  1771                 if lib.ZSTD_isError(zresult):
       
  1772                     raise ZstdError('zstd decompress error: %s' %
       
  1773                                     _zstd_error(zresult))
       
  1774 
       
  1775                 if out_buffer.pos:
       
  1776                     data = ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
       
  1777                     out_buffer.pos = 0
       
  1778                     yield data
       
  1779 
       
  1780                 if zresult == 0:
       
  1781                     return
       
  1782 
       
  1783             # Repeat loop to collect more input data.
       
  1784             continue
       
  1785 
       
  1786         # If we get here, input is exhausted.
       
  1787 
       
  1788     read_from = read_to_iter
       
  1789 
       
  1790     def stream_writer(self, writer, write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE):
       
  1791         if not hasattr(writer, 'write'):
       
  1792             raise ValueError('must pass an object with a write() method')
       
  1793 
       
  1794         return ZstdDecompressionWriter(self, writer, write_size)
       
  1795 
       
  1796     write_to = stream_writer
       
  1797 
       
  1798     def copy_stream(self, ifh, ofh,
       
  1799                     read_size=DECOMPRESSION_RECOMMENDED_INPUT_SIZE,
       
  1800                     write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE):
       
  1801         if not hasattr(ifh, 'read'):
       
  1802             raise ValueError('first argument must have a read() method')
       
  1803         if not hasattr(ofh, 'write'):
       
  1804             raise ValueError('second argument must have a write() method')
       
  1805 
       
  1806         self._ensure_dctx()
       
  1807 
       
  1808         in_buffer = ffi.new('ZSTD_inBuffer *')
       
  1809         out_buffer = ffi.new('ZSTD_outBuffer *')
       
  1810 
       
  1811         dst_buffer = ffi.new('char[]', write_size)
       
  1812         out_buffer.dst = dst_buffer
       
  1813         out_buffer.size = write_size
       
  1814         out_buffer.pos = 0
       
  1815 
       
  1816         total_read, total_write = 0, 0
       
  1817 
       
  1818         # Read all available input.
       
  1819         while True:
       
  1820             data = ifh.read(read_size)
       
  1821             if not data:
       
  1822                 break
       
  1823 
       
  1824             data_buffer = ffi.from_buffer(data)
       
  1825             total_read += len(data_buffer)
       
  1826             in_buffer.src = data_buffer
       
  1827             in_buffer.size = len(data_buffer)
       
  1828             in_buffer.pos = 0
       
  1829 
       
  1830             # Flush all read data to output.
       
  1831             while in_buffer.pos < in_buffer.size:
       
  1832                 zresult = lib.ZSTD_decompress_generic(self._dctx, out_buffer, in_buffer)
       
  1833                 if lib.ZSTD_isError(zresult):
       
  1834                     raise ZstdError('zstd decompressor error: %s' %
       
  1835                                     _zstd_error(zresult))
       
  1836 
       
  1837                 if out_buffer.pos:
       
  1838                     ofh.write(ffi.buffer(out_buffer.dst, out_buffer.pos))
       
  1839                     total_write += out_buffer.pos
       
  1840                     out_buffer.pos = 0
       
  1841 
       
  1842             # Continue loop to keep reading.
       
  1843 
       
  1844         return total_read, total_write
       
  1845 
       
  1846     def decompress_content_dict_chain(self, frames):
       
  1847         if not isinstance(frames, list):
       
  1848             raise TypeError('argument must be a list')
       
  1849 
       
  1850         if not frames:
       
  1851             raise ValueError('empty input chain')
       
  1852 
       
  1853         # First chunk should not be using a dictionary. We handle it specially.
       
  1854         chunk = frames[0]
       
  1855         if not isinstance(chunk, bytes_type):
       
  1856             raise ValueError('chunk 0 must be bytes')
       
  1857 
       
  1858         # All chunks should be zstd frames and should have content size set.
       
  1859         chunk_buffer = ffi.from_buffer(chunk)
       
  1860         params = ffi.new('ZSTD_frameHeader *')
       
  1861         zresult = lib.ZSTD_getFrameHeader(params, chunk_buffer, len(chunk_buffer))
       
  1862         if lib.ZSTD_isError(zresult):
       
  1863             raise ValueError('chunk 0 is not a valid zstd frame')
       
  1864         elif zresult:
       
  1865             raise ValueError('chunk 0 is too small to contain a zstd frame')
       
  1866 
       
  1867         if params.frameContentSize == lib.ZSTD_CONTENTSIZE_UNKNOWN:
       
  1868             raise ValueError('chunk 0 missing content size in frame')
       
  1869 
       
  1870         self._ensure_dctx(load_dict=False)
       
  1871 
       
  1872         last_buffer = ffi.new('char[]', params.frameContentSize)
       
  1873 
       
  1874         out_buffer = ffi.new('ZSTD_outBuffer *')
       
  1875         out_buffer.dst = last_buffer
       
  1876         out_buffer.size = len(last_buffer)
       
  1877         out_buffer.pos = 0
       
  1878 
       
  1879         in_buffer = ffi.new('ZSTD_inBuffer *')
       
  1880         in_buffer.src = chunk_buffer
       
  1881         in_buffer.size = len(chunk_buffer)
       
  1882         in_buffer.pos = 0
       
  1883 
       
  1884         zresult = lib.ZSTD_decompress_generic(self._dctx, out_buffer, in_buffer)
       
  1885         if lib.ZSTD_isError(zresult):
       
  1886             raise ZstdError('could not decompress chunk 0: %s' %
       
  1887                             _zstd_error(zresult))
       
  1888         elif zresult:
       
  1889             raise ZstdError('chunk 0 did not decompress full frame')
       
  1890 
       
  1891         # Special case of chain length of 1
       
  1892         if len(frames) == 1:
       
  1893             return ffi.buffer(last_buffer, len(last_buffer))[:]
       
  1894 
       
  1895         i = 1
       
  1896         while i < len(frames):
       
  1897             chunk = frames[i]
       
  1898             if not isinstance(chunk, bytes_type):
       
  1899                 raise ValueError('chunk %d must be bytes' % i)
       
  1900 
       
  1901             chunk_buffer = ffi.from_buffer(chunk)
       
  1902             zresult = lib.ZSTD_getFrameHeader(params, chunk_buffer, len(chunk_buffer))
       
  1903             if lib.ZSTD_isError(zresult):
       
  1904                 raise ValueError('chunk %d is not a valid zstd frame' % i)
       
  1905             elif zresult:
       
  1906                 raise ValueError('chunk %d is too small to contain a zstd frame' % i)
       
  1907 
       
  1908             if params.frameContentSize == lib.ZSTD_CONTENTSIZE_UNKNOWN:
       
  1909                 raise ValueError('chunk %d missing content size in frame' % i)
       
  1910 
       
  1911             dest_buffer = ffi.new('char[]', params.frameContentSize)
       
  1912 
       
  1913             out_buffer.dst = dest_buffer
       
  1914             out_buffer.size = len(dest_buffer)
       
  1915             out_buffer.pos = 0
       
  1916 
       
  1917             in_buffer.src = chunk_buffer
       
  1918             in_buffer.size = len(chunk_buffer)
       
  1919             in_buffer.pos = 0
       
  1920 
       
  1921             zresult = lib.ZSTD_decompress_generic(self._dctx, out_buffer, in_buffer)
       
  1922             if lib.ZSTD_isError(zresult):
       
  1923                 raise ZstdError('could not decompress chunk %d: %s' %
       
  1924                                 _zstd_error(zresult))
       
  1925             elif zresult:
       
  1926                 raise ZstdError('chunk %d did not decompress full frame' % i)
       
  1927 
       
  1928             last_buffer = dest_buffer
       
  1929             i += 1
       
  1930 
       
  1931         return ffi.buffer(last_buffer, len(last_buffer))[:]
       
  1932 
       
    def _ensure_dctx(self, load_dict=True):
        """Reset the decompression context and re-apply this object's settings.

        Called at the start of each operation (decompress, stream_reader,
        decompressobj, read_to_iter, copy_stream, ...) so that every operation
        starts from a freshly reset context.

        :param load_dict: when true and ``dict_data`` was supplied, reference
            its prepared ``_ddict`` on the context.
        :raises ZstdError: if zstd rejects the window size, the format, or the
            dictionary.
        """
        # NOTE(review): the return value of ZSTD_DCtx_reset is not checked.
        lib.ZSTD_DCtx_reset(self._dctx)

        # Only override the window limit when one was configured (non-zero).
        if self._max_window_size:
            zresult = lib.ZSTD_DCtx_setMaxWindowSize(self._dctx,
                                                     self._max_window_size)
            if lib.ZSTD_isError(zresult):
                raise ZstdError('unable to set max window size: %s' %
                                _zstd_error(zresult))

        zresult = lib.ZSTD_DCtx_setFormat(self._dctx, self._format)
        if lib.ZSTD_isError(zresult):
            raise ZstdError('unable to set decoding format: %s' %
                            _zstd_error(zresult))

        if self._dict_data and load_dict:
            zresult = lib.ZSTD_DCtx_refDDict(self._dctx, self._dict_data._ddict)
            if lib.ZSTD_isError(zresult):
                raise ZstdError('unable to reference prepared dictionary: %s' %
                                _zstd_error(zresult))