contrib/python-zstandard/c-ext/decompressor.c
changeset 42070 675775c33ab6
parent 40121 73fef626dae3
child 42812 a4e32fd539ab
--- a/contrib/python-zstandard/c-ext/decompressor.c	Thu Apr 04 15:24:03 2019 -0700
+++ b/contrib/python-zstandard/c-ext/decompressor.c	Thu Apr 04 17:34:43 2019 -0700
@@ -17,7 +17,7 @@
 int ensure_dctx(ZstdDecompressor* decompressor, int loadDict) {
 	size_t zresult;
 
-	ZSTD_DCtx_reset(decompressor->dctx);
+	ZSTD_DCtx_reset(decompressor->dctx, ZSTD_reset_session_only);
 
 	if (decompressor->maxWindowSize) {
 		zresult = ZSTD_DCtx_setMaxWindowSize(decompressor->dctx, decompressor->maxWindowSize);
@@ -229,7 +229,7 @@
 
 		while (input.pos < input.size) {
 			Py_BEGIN_ALLOW_THREADS
-			zresult = ZSTD_decompress_generic(self->dctx, &output, &input);
+			zresult = ZSTD_decompressStream(self->dctx, &output, &input);
 			Py_END_ALLOW_THREADS
 
 			if (ZSTD_isError(zresult)) {
@@ -379,7 +379,7 @@
 	inBuffer.pos = 0;
 
 	Py_BEGIN_ALLOW_THREADS
-	zresult = ZSTD_decompress_generic(self->dctx, &outBuffer, &inBuffer);
+	zresult = ZSTD_decompressStream(self->dctx, &outBuffer, &inBuffer);
 	Py_END_ALLOW_THREADS
 
 	if (ZSTD_isError(zresult)) {
@@ -550,28 +550,35 @@
 }
 
 PyDoc_STRVAR(Decompressor_stream_reader__doc__,
-"stream_reader(source, [read_size=default])\n"
+"stream_reader(source, [read_size=default, [read_across_frames=False]])\n"
 "\n"
 "Obtain an object that behaves like an I/O stream that can be used for\n"
 "reading decompressed output from an object.\n"
 "\n"
 "The source object can be any object with a ``read(size)`` method or that\n"
 "conforms to the buffer protocol.\n"
+"\n"
+"``read_across_frames`` controls the behavior of ``read()`` when the end\n"
+"of a zstd frame is reached. When ``True``, ``read()`` can potentially\n"
+"return data belonging to multiple zstd frames. When ``False``, ``read()``\n"
+"will return when the end of a frame is reached.\n"
 );
 
 static ZstdDecompressionReader* Decompressor_stream_reader(ZstdDecompressor* self, PyObject* args, PyObject* kwargs) {
 	static char* kwlist[] = {
 		"source",
 		"read_size",
+		"read_across_frames",
 		NULL
 	};
 
 	PyObject* source;
 	size_t readSize = ZSTD_DStreamInSize();
+	PyObject* readAcrossFrames = NULL;
 	ZstdDecompressionReader* result;
 
-	if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|k:stream_reader", kwlist,
-		&source, &readSize)) {
+	if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|kO:stream_reader", kwlist,
+		&source, &readSize, &readAcrossFrames)) {
 		return NULL;
 	}
 
@@ -604,6 +611,7 @@
 
 	result->decompressor = self;
 	Py_INCREF(self);
+	result->readAcrossFrames = readAcrossFrames ? PyObject_IsTrue(readAcrossFrames) : 0;
 
 	return result;
 }
@@ -625,15 +633,17 @@
 	static char* kwlist[] = {
 		"writer",
 		"write_size",
+		"write_return_read",
 		NULL
 	};
 
 	PyObject* writer;
 	size_t outSize = ZSTD_DStreamOutSize();
+	PyObject* writeReturnRead = NULL;
 	ZstdDecompressionWriter* result;
 
-	if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|k:stream_writer", kwlist,
-		&writer, &outSize)) {
+	if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|kO:stream_writer", kwlist,
+		&writer, &outSize, &writeReturnRead)) {
 		return NULL;
 	}
 
@@ -642,6 +652,10 @@
 		return NULL;
 	}
 
+	if (ensure_dctx(self, 1)) {
+		return NULL;
+	}
+
 	result = (ZstdDecompressionWriter*)PyObject_CallObject((PyObject*)&ZstdDecompressionWriterType, NULL);
 	if (!result) {
 		return NULL;
@@ -654,6 +668,7 @@
 	Py_INCREF(result->writer);
 
 	result->outSize = outSize;
+	result->writeReturnRead = writeReturnRead ? PyObject_IsTrue(writeReturnRead) : 0;
 
 	return result;
 }
@@ -756,7 +771,7 @@
 	inBuffer.pos = 0;
 
 	Py_BEGIN_ALLOW_THREADS
-	zresult = ZSTD_decompress_generic(self->dctx, &outBuffer, &inBuffer);
+	zresult = ZSTD_decompressStream(self->dctx, &outBuffer, &inBuffer);
 	Py_END_ALLOW_THREADS
 	if (ZSTD_isError(zresult)) {
 		PyErr_Format(ZstdError, "could not decompress chunk 0: %s", ZSTD_getErrorName(zresult));
@@ -852,7 +867,7 @@
 			outBuffer.pos = 0;
 
 			Py_BEGIN_ALLOW_THREADS
-			zresult = ZSTD_decompress_generic(self->dctx, &outBuffer, &inBuffer);
+			zresult = ZSTD_decompressStream(self->dctx, &outBuffer, &inBuffer);
 			Py_END_ALLOW_THREADS
 			if (ZSTD_isError(zresult)) {
 				PyErr_Format(ZstdError, "could not decompress chunk %zd: %s",
@@ -892,7 +907,7 @@
 			outBuffer.pos = 0;
 
 			Py_BEGIN_ALLOW_THREADS
-			zresult = ZSTD_decompress_generic(self->dctx, &outBuffer, &inBuffer);
+			zresult = ZSTD_decompressStream(self->dctx, &outBuffer, &inBuffer);
 			Py_END_ALLOW_THREADS
 			if (ZSTD_isError(zresult)) {
 				PyErr_Format(ZstdError, "could not decompress chunk %zd: %s",
@@ -1176,7 +1191,7 @@
 		inBuffer.size = sourceSize;
 		inBuffer.pos = 0;
 
-		zresult = ZSTD_decompress_generic(state->dctx, &outBuffer, &inBuffer);
+		zresult = ZSTD_decompressStream(state->dctx, &outBuffer, &inBuffer);
 		if (ZSTD_isError(zresult)) {
 			state->error = WorkerError_zstd;
 			state->zresult = zresult;