mercurial/util.py
changeset 38713 27391d74aaa2
parent 38575 152f4822d210
child 38797 8751d1e2a7ff
--- a/mercurial/util.py	Mon Jul 16 16:46:32 2018 +0200
+++ b/mercurial/util.py	Thu Jul 12 18:46:10 2018 +0200
@@ -322,6 +322,11 @@
             self._fillbuffer()
         return self._frombuffer(size)
 
+    def unbufferedread(self, size):
+        if not self._eof and self._lenbuf == 0:
+            self._fillbuffer(max(size, _chunksize))
+        return self._frombuffer(min(self._lenbuf, size))
+
     def readline(self, *args, **kwargs):
         if 1 < len(self._buffer):
             # this should not happen because both read and readline end with a
@@ -363,9 +368,9 @@
             self._lenbuf = 0
         return data
 
-    def _fillbuffer(self):
+    def _fillbuffer(self, size=_chunksize):
         """read data to the buffer"""
-        data = os.read(self._input.fileno(), _chunksize)
+        data = os.read(self._input.fileno(), size)
         if not data:
             self._eof = True
         else:
@@ -3302,6 +3307,104 @@
         """
         raise NotImplementedError()
 
+class _CompressedStreamReader(object):
+    def __init__(self, fh):
+        if safehasattr(fh, 'unbufferedread'):
+            self._reader = fh.unbufferedread
+        else:
+            self._reader = fh.read
+        self._pending = []
+        self._pos = 0
+        self._eof = False
+
+    def _decompress(self, chunk):
+        raise NotImplementedError()
+
+    def read(self, l):
+        buf = []
+        while True:
+            while self._pending:
+                if len(self._pending[0]) > l + self._pos:
+                    newbuf = self._pending[0]
+                    buf.append(newbuf[self._pos:self._pos + l])
+                    self._pos += l
+                    return ''.join(buf)
+
+                newbuf = self._pending.pop(0)
+                if self._pos:
+                    buf.append(newbuf[self._pos:])
+                    l -= len(newbuf) - self._pos
+                else:
+                    buf.append(newbuf)
+                    l -= len(newbuf)
+                self._pos = 0
+
+            if self._eof:
+                return ''.join(buf)
+            chunk = self._reader(65536)
+            self._decompress(chunk)
+
+class _GzipCompressedStreamReader(_CompressedStreamReader):
+    def __init__(self, fh):
+        super(_GzipCompressedStreamReader, self).__init__(fh)
+        self._decompobj = zlib.decompressobj()
+    def _decompress(self, chunk):
+        newbuf = self._decompobj.decompress(chunk)
+        if newbuf:
+            self._pending.append(newbuf)
+        d = self._decompobj.copy()
+        try:
+            d.decompress('x')
+            d.flush()
+            if d.unused_data == 'x':
+                self._eof = True
+        except zlib.error:
+            pass
+
+class _BZ2CompressedStreamReader(_CompressedStreamReader):
+    def __init__(self, fh):
+        super(_BZ2CompressedStreamReader, self).__init__(fh)
+        self._decompobj = bz2.BZ2Decompressor()
+    def _decompress(self, chunk):
+        newbuf = self._decompobj.decompress(chunk)
+        if newbuf:
+            self._pending.append(newbuf)
+        try:
+            while True:
+                newbuf = self._decompobj.decompress('')
+                if newbuf:
+                    self._pending.append(newbuf)
+                else:
+                    break
+        except EOFError:
+            self._eof = True
+
+class _TruncatedBZ2CompressedStreamReader(_BZ2CompressedStreamReader):
+    def __init__(self, fh):
+        super(_TruncatedBZ2CompressedStreamReader, self).__init__(fh)
+        newbuf = self._decompobj.decompress('BZ')
+        if newbuf:
+            self._pending.append(newbuf)
+
+class _ZstdCompressedStreamReader(_CompressedStreamReader):
+    def __init__(self, fh, zstd):
+        super(_ZstdCompressedStreamReader, self).__init__(fh)
+        self._zstd = zstd
+        self._decompobj = zstd.ZstdDecompressor().decompressobj()
+    def _decompress(self, chunk):
+        newbuf = self._decompobj.decompress(chunk)
+        if newbuf:
+            self._pending.append(newbuf)
+        try:
+            while True:
+                newbuf = self._decompobj.decompress('')
+                if newbuf:
+                    self._pending.append(newbuf)
+                else:
+                    break
+        except self._zstd.ZstdError:
+            self._eof = True
+
 class _zlibengine(compressionengine):
     def name(self):
         return 'zlib'
@@ -3335,15 +3438,7 @@
         yield z.flush()
 
     def decompressorreader(self, fh):
-        def gen():
-            d = zlib.decompressobj()
-            for chunk in filechunkiter(fh):
-                while chunk:
-                    # Limit output size to limit memory.
-                    yield d.decompress(chunk, 2 ** 18)
-                    chunk = d.unconsumed_tail
-
-        return chunkbuffer(gen())
+        return _GzipCompressedStreamReader(fh)
 
     class zlibrevlogcompressor(object):
         def compress(self, data):
@@ -3423,12 +3518,7 @@
         yield z.flush()
 
     def decompressorreader(self, fh):
-        def gen():
-            d = bz2.BZ2Decompressor()
-            for chunk in filechunkiter(fh):
-                yield d.decompress(chunk)
-
-        return chunkbuffer(gen())
+        return _BZ2CompressedStreamReader(fh)
 
 compengines.register(_bz2engine())
 
@@ -3442,14 +3532,7 @@
     # We don't implement compressstream because it is hackily handled elsewhere.
 
     def decompressorreader(self, fh):
-        def gen():
-            # The input stream doesn't have the 'BZ' header. So add it back.
-            d = bz2.BZ2Decompressor()
-            d.decompress('BZ')
-            for chunk in filechunkiter(fh):
-                yield d.decompress(chunk)
-
-        return chunkbuffer(gen())
+        return _TruncatedBZ2CompressedStreamReader(fh)
 
 compengines.register(_truncatedbz2engine())
 
@@ -3544,9 +3627,7 @@
         yield z.flush()
 
     def decompressorreader(self, fh):
-        zstd = self._module
-        dctx = zstd.ZstdDecompressor()
-        return chunkbuffer(dctx.read_from(fh))
+        return _ZstdCompressedStreamReader(fh, self._module)
 
     class zstdrevlogcompressor(object):
         def __init__(self, zstd, level=3):