contrib/python-zstandard/zstd_cffi.py
author Augie Fackler <augie@google.com>
Wed, 12 Apr 2017 11:23:55 -0700
branch stable
changeset 32050 77eaf9539499
parent 30435 b86a448a2965
child 30895 c32454d69b85
permissions -rw-r--r--
dispatch: protect against malicious 'hg serve --stdio' invocations (sec). Some shared-ssh installations assume that 'hg serve --stdio' is a safe command to run for minimally trusted users. Unfortunately, the messy implementation of argument parsing here meant that trying to access a repo named '--debugger' would give the user a pdb prompt, thereby sidestepping any hoped-for sandboxing. Serving repositories over HTTP(S) is unaffected. We're not currently hardening any subcommands other than 'serve'. If your service exposes other commands to users with arbitrary repository names, it is imperative that you defend against repository names of '--debugger' and anything starting with '--config'. The read-only mode of hg-ssh stopped working because it provided its hook configuration to "hg serve --stdio" via the --config parameter. This is now banned for security reasons. This patch switches it to directly call ui.setconfig(). If your custom hosting infrastructure relies on passing --config to "hg serve --stdio", you'll need to find a different way to get that configuration into Mercurial, either by using ui.setconfig() as hg-ssh does in this patch, or by placing an hgrc file someplace where Mercurial will read it. mitrandir@fb.com provided some extra fixes for the dispatch code and for hg-ssh in places that I overlooked.

# Copyright (c) 2016-present, Gregory Szorc
# All rights reserved.
#
# This software may be modified and distributed under the terms
# of the BSD license. See the LICENSE file for details.

"""Python interface to the Zstandard (zstd) compression library."""

from __future__ import absolute_import, unicode_literals

import io

from _zstd_cffi import (
    ffi,
    lib,
)


# Streaming buffer sizes queried once from the zstd library at import time.
# _CSTREAM_IN_SIZE is used as the chunk size when reading input in
# ZstdCompressor.copy_stream(); _CSTREAM_OUT_SIZE is the capacity of every
# output buffer allocated in this module.
_CSTREAM_IN_SIZE = lib.ZSTD_CStreamInSize()
_CSTREAM_OUT_SIZE = lib.ZSTD_CStreamOutSize()


class _ZstdCompressionWriter(object):
    """Context manager that compresses data and forwards it to ``writer``.

    Bytes passed to ``write()`` are fed through the zstd compression stream
    and every chunk of compressed output is handed to ``writer.write()``.
    When the ``with`` block exits without an exception, the stream is ended
    and any remaining output is flushed to ``writer``.
    """

    def __init__(self, cstream, writer):
        """Remember the compression stream and the destination.

        cstream: stream object passed to ``ZSTD_compressStream`` and
            ``ZSTD_endStream``.
        writer: object whose ``write()`` method receives compressed chunks.
        """
        self._cstream = cstream
        self._writer = writer

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        """End the stream and flush remaining output on a clean exit.

        Nothing is flushed when the ``with`` body raised. Exceptions are
        never suppressed (always returns False).

        Raises Exception if zstd reports an error while ending the stream.
        """
        if not exc_type and not exc_value and not exc_tb:
            out_buffer, dst = self._new_out_buffer()

            # ZSTD_endStream is called until it returns 0; each pass may
            # leave more output in the buffer that has to be drained.
            while True:
                res = lib.ZSTD_endStream(self._cstream, out_buffer)
                if lib.ZSTD_isError(res):
                    raise Exception('error ending compression stream: %s' %
                                    self._error_name(res))

                if out_buffer.pos:
                    # [:] copies to bytes so the writer never holds a view
                    # onto memory that is overwritten on the next pass.
                    self._writer.write(ffi.buffer(dst, out_buffer.pos)[:])
                    out_buffer.pos = 0

                if res == 0:
                    break

        return False

    def write(self, data):
        """Compress ``data`` and pass any produced output to the writer.

        data: bytes to compress.

        Raises Exception if zstd reports a compression error.
        """
        out_buffer, dst = self._new_out_buffer()

        # TODO can we reuse existing memory?
        # ``src`` must stay referenced for as long as in_buffer.src points
        # at it: cffi frees ffi.new() memory once the owning object dies.
        src = ffi.new('char[]', data)
        in_buffer = ffi.new('ZSTD_inBuffer *')
        in_buffer.src = src
        in_buffer.size = len(data)
        in_buffer.pos = 0

        while in_buffer.pos < in_buffer.size:
            res = lib.ZSTD_compressStream(self._cstream, out_buffer, in_buffer)
            if lib.ZSTD_isError(res):
                raise Exception('zstd compress error: %s' %
                                self._error_name(res))

            if out_buffer.pos:
                self._writer.write(ffi.buffer(dst, out_buffer.pos)[:])
                out_buffer.pos = 0

    @staticmethod
    def _new_out_buffer():
        """Allocate an empty output buffer of _CSTREAM_OUT_SIZE bytes.

        Returns ``(out_buffer, dst)``. The caller must keep ``dst``
        referenced while ``out_buffer`` is in use: assigning a cdata to a
        struct field stores only the raw pointer, not a reference to the
        owning object.
        """
        dst = ffi.new('char[]', _CSTREAM_OUT_SIZE)
        out_buffer = ffi.new('ZSTD_outBuffer *')
        out_buffer.dst = dst
        out_buffer.size = _CSTREAM_OUT_SIZE
        out_buffer.pos = 0
        return out_buffer, dst

    @staticmethod
    def _error_name(code):
        """Return the zstd error name for ``code`` as text."""
        return ffi.string(lib.ZSTD_getErrorName(code)).decode('utf-8')


class ZstdCompressor(object):
    """Compress data with zstd at a fixed compression level."""

    def __init__(self, level=3, dict_data=None, compression_params=None):
        """Configure the compressor.

        level: compression level passed to ``ZSTD_initCStream``.
        dict_data: not supported yet; any truthy value raises Exception.
        compression_params: not supported yet; any truthy value raises
            Exception.
        """
        if dict_data:
            raise Exception('dict_data not yet supported')
        if compression_params:
            raise Exception('compression_params not yet supported')

        self._compression_level = level

    def compress(self, data):
        """Compress ``data`` in one call and return the compressed bytes."""
        # Just use the stream API for now.
        output = io.BytesIO()
        with self.write_to(output) as compressor:
            compressor.write(data)
        return output.getvalue()

    def copy_stream(self, ifh, ofh):
        """Compress everything readable from ``ifh`` into ``ofh``.

        ifh: object with a ``read(size)`` method returning bytes; a falsy
            return value ends the input.
        ofh: object with a ``write(data)`` method receiving compressed
            chunks.

        Returns a ``(total_read, total_write)`` tuple: the number of bytes
        read from ``ifh`` and the number of bytes written to ``ofh``.

        Raises Exception if zstd reports an error.
        """
        cstream = self._get_cstream()

        in_buffer = ffi.new('ZSTD_inBuffer *')
        out_buffer = ffi.new('ZSTD_outBuffer *')

        # Keep the owning cdata in a local: assigning it to a struct field
        # stores only the raw pointer, and cffi frees the memory as soon as
        # the owning object is garbage-collected.
        dst = ffi.new('char[]', _CSTREAM_OUT_SIZE)
        out_buffer.dst = dst
        out_buffer.size = _CSTREAM_OUT_SIZE
        out_buffer.pos = 0

        total_read, total_write = 0, 0

        while True:
            data = ifh.read(_CSTREAM_IN_SIZE)
            if not data:
                break

            total_read += len(data)

            # ``src`` stays alive until it is rebound on the next chunk,
            # i.e. for the whole time in_buffer.src points at it.
            src = ffi.new('char[]', data)
            in_buffer.src = src
            in_buffer.size = len(data)
            in_buffer.pos = 0

            while in_buffer.pos < in_buffer.size:
                res = lib.ZSTD_compressStream(cstream, out_buffer, in_buffer)
                if lib.ZSTD_isError(res):
                    raise Exception('zstd compress error: %s' %
                                    self._error_name(res))

                if out_buffer.pos:
                    # [:] copies to bytes so ofh never holds a view onto
                    # memory that is overwritten on the next pass.
                    ofh.write(ffi.buffer(dst, out_buffer.pos)[:])
                    # Accumulate; plain assignment would report only the
                    # size of the last chunk.
                    total_write += out_buffer.pos
                    out_buffer.pos = 0

        # We've finished reading. Flush the compressor. ZSTD_endStream is
        # called until it returns 0.
        while True:
            res = lib.ZSTD_endStream(cstream, out_buffer)
            if lib.ZSTD_isError(res):
                raise Exception('error ending compression stream: %s' %
                                self._error_name(res))

            if out_buffer.pos:
                ofh.write(ffi.buffer(dst, out_buffer.pos)[:])
                total_write += out_buffer.pos
                out_buffer.pos = 0

            if res == 0:
                break

        return total_read, total_write

    def write_to(self, writer):
        """Return a context manager that compresses writes into ``writer``."""
        return _ZstdCompressionWriter(self._get_cstream(), writer)

    def _get_cstream(self):
        """Create and initialize a new compression stream.

        The stream is freed with ``ZSTD_freeCStream`` when the returned
        object is garbage-collected.

        Raises MemoryError if the stream cannot be allocated and Exception
        if zstd fails to initialize it.
        """
        cstream = lib.ZSTD_createCStream()
        if cstream == ffi.NULL:
            raise MemoryError()
        cstream = ffi.gc(cstream, lib.ZSTD_freeCStream)

        res = lib.ZSTD_initCStream(cstream, self._compression_level)
        if lib.ZSTD_isError(res):
            raise Exception('cannot init CStream: %s' %
                            self._error_name(res))

        return cstream

    @staticmethod
    def _error_name(code):
        """Return the zstd error name for ``code`` as text."""
        return ffi.string(lib.ZSTD_getErrorName(code)).decode('utf-8')