mercurial/util.py
changeset 30761 7283719e2bfd
parent 30745 c1b7b2285522
child 30773 c390b40fe1d7
--- a/mercurial/util.py	Sat Dec 24 13:56:36 2016 -0700
+++ b/mercurial/util.py	Sat Dec 24 13:51:12 2016 -0700
@@ -2957,6 +2957,13 @@
 
 # compression code
 
+SERVERROLE = 'server'
+CLIENTROLE = 'client'
+
+compewireprotosupport = collections.namedtuple(u'compenginewireprotosupport',
+                                               (u'name', u'serverpriority',
+                                                u'clientpriority'))
+
 class compressormanager(object):
     """Holds registrations of various compression engines.
 
@@ -2973,6 +2980,8 @@
         self._bundlenames = {}
         # Internal bundle identifier to engine name.
         self._bundletypes = {}
+        # Wire proto identifier to engine name.
+        self._wiretypes = {}
 
     def __getitem__(self, key):
         return self._engines[key]
@@ -3014,6 +3023,16 @@
 
             self._bundletypes[bundletype] = name
 
+        wiresupport = engine.wireprotosupport()
+        if wiresupport:
+            wiretype = wiresupport.name
+            if wiretype in self._wiretypes:
+                raise error.Abort(_('wire protocol compression %s already '
+                                    'registered by %s') %
+                                  (wiretype, self._wiretypes[wiretype]))
+
+            self._wiretypes[wiretype] = name
+
         self._engines[name] = engine
 
     @property
@@ -3050,6 +3069,38 @@
                               engine.name())
         return engine
 
+    def supportedwireengines(self, role, onlyavailable=True):
+        """Obtain compression engines that support the wire protocol.
+
+        Returns a list of engines in prioritized order, most desired first.
+
+        If ``onlyavailable`` is set, filter out engines that can't be
+        loaded.
+        """
+        assert role in (SERVERROLE, CLIENTROLE)
+
+        attr = 'serverpriority' if role == SERVERROLE else 'clientpriority'
+
+        engines = [self._engines[e] for e in self._wiretypes.values()]
+        if onlyavailable:
+            engines = [e for e in engines if e.available()]
+
+        def getkey(e):
+            # Sort first by priority, highest first. In case of tie, sort
+            # alphabetically. This is arbitrary, but ensures output is
+            # stable.
+            w = e.wireprotosupport()
+            return -1 * getattr(w, attr), w.name
+
+        return list(sorted(engines, key=getkey))
+
+    def forwiretype(self, wiretype):
+        engine = self._engines[self._wiretypes[wiretype]]
+        if not engine.available():
+            raise error.Abort(_('compression engine %s could not be loaded') %
+                              engine.name())
+        return engine
+
 compengines = compressormanager()
 
 class compressionengine(object):
@@ -3090,6 +3141,31 @@
         """
         return None
 
+    def wireprotosupport(self):
+        """Declare support for this compression format on the wire protocol.
+
+        If this compression engine isn't supported for compressing wire
+        protocol payloads, returns None.
+
+        Otherwise, returns ``compenginewireprotosupport`` with the following
+        fields:
+
+        * String format identifier
+        * Integer priority for the server
+        * Integer priority for the client
+
+        The integer priorities are used to order the advertisement of format
+        support by server and client. The highest integer is advertised
+        first. Integers with non-positive values aren't advertised.
+
+        The priority values are somewhat arbitrary and only used for default
+        ordering. The relative order can be changed via config options.
+
+        If wire protocol compression is supported, the class must also implement
+        ``compressstream`` and ``decompressorreader``.
+        """
+        return None
+
     def compressstream(self, it, opts=None):
         """Compress an iterator of chunks.
 
@@ -3118,6 +3194,9 @@
     def bundletype(self):
         return 'gzip', 'GZ'
 
+    def wireprotosupport(self):
+        return compewireprotosupport('zlib', 20, 20)
+
     def compressstream(self, it, opts=None):
         opts = opts or {}
 
@@ -3151,6 +3230,11 @@
     def bundletype(self):
         return 'bzip2', 'BZ'
 
+    # We declare a protocol name but don't advertise by default because
+    # it is slow.
+    def wireprotosupport(self):
+        return compewireprotosupport('bzip2', 0, 0)
+
     def compressstream(self, it, opts=None):
         opts = opts or {}
         z = bz2.BZ2Compressor(opts.get('level', 9))
@@ -3199,6 +3283,12 @@
     def bundletype(self):
         return 'none', 'UN'
 
+    # Clients always support uncompressed payloads. Servers don't because
+    # unless you are on a fast network, uncompressed payloads can easily
+    # saturate your network pipe.
+    def wireprotosupport(self):
+        return compewireprotosupport('none', 0, 10)
+
     def compressstream(self, it, opts=None):
         return it
 
@@ -3229,6 +3319,9 @@
     def bundletype(self):
         return 'zstd', 'ZS'
 
+    def wireprotosupport(self):
+        return compewireprotosupport('zstd', 50, 50)
+
     def compressstream(self, it, opts=None):
         opts = opts or {}
         # zstd level 3 is almost always significantly faster than zlib