mercurial/templateutil.py
changeset 38448 dae829b4de78
parent 38447 b6294c113794
child 38449 bc8d925342f0
--- a/mercurial/templateutil.py	Sun Jun 17 16:10:38 2018 +0900
+++ b/mercurial/templateutil.py	Thu Jun 14 22:33:26 2018 +0900
@@ -63,6 +63,14 @@
         value depending on the self type"""
 
     @abc.abstractmethod
+    def filter(self, context, mapping, select):
+        """Return new container of the same type which includes only the
+        selected elements
+
+        select() takes each item as a wrapped object and returns True/False.
+        """
+
+    @abc.abstractmethod
     def itermaps(self, context):
         """Yield each template mapping"""
 
@@ -130,6 +138,10 @@
             raise error.ParseError(_('empty string'))
         return func(pycompat.iterbytestr(self._value))
 
+    def filter(self, context, mapping, select):
+        raise error.ParseError(_('%r is not filterable')
+                               % pycompat.bytestr(self._value))
+
     def itermaps(self, context):
         raise error.ParseError(_('%r is not iterable of mappings')
                                % pycompat.bytestr(self._value))
@@ -164,6 +176,9 @@
     def getmax(self, context, mapping):
         raise error.ParseError(_("%r is not iterable") % self._value)
 
+    def filter(self, context, mapping, select):
+        raise error.ParseError(_("%r is not iterable") % self._value)
+
     def itermaps(self, context):
         raise error.ParseError(_('%r is not iterable of mappings')
                                % self._value)
@@ -208,6 +223,9 @@
     def getmax(self, context, mapping):
         raise error.ParseError(_('date is not iterable'))
 
+    def filter(self, context, mapping, select):
+        raise error.ParseError(_('date is not iterable'))
+
     def join(self, context, mapping, sep):
         raise error.ParseError(_("date is not iterable"))
 
@@ -273,6 +291,14 @@
             return val
         return hybriditem(None, key, val, self._makemap)
 
+    def filter(self, context, mapping, select):
+        if util.safehasattr(self._values, 'get'):
+            values = {k: v for k, v in self._values.iteritems()
+                      if select(self._wrapvalue(k, v))}
+        else:
+            values = [v for v in self._values if select(self._wrapvalue(v, v))]
+        return hybrid(None, values, self._makemap, self._joinfmt, self._keytype)
+
     def itermaps(self, context):
         makemap = self._makemap
         for x in self._values:
@@ -336,6 +362,10 @@
         w = makewrapped(context, mapping, self._value)
         return w.getmax(context, mapping)
 
+    def filter(self, context, mapping, select):
+        w = makewrapped(context, mapping, self._value)
+        return w.filter(context, mapping, select)
+
     def join(self, context, mapping, sep):
         w = makewrapped(context, mapping, self._value)
         return w.join(context, mapping, sep)
@@ -384,6 +414,9 @@
     def getmax(self, context, mapping):
         raise error.ParseError(_('not comparable'))
 
+    def filter(self, context, mapping, select):
+        raise error.ParseError(_('not filterable without template'))
+
     def join(self, context, mapping, sep):
         mapsiter = _iteroverlaymaps(context, mapping, self.itermaps(context))
         if self._name:
@@ -472,6 +505,17 @@
             raise error.ParseError(_('empty sequence'))
         return func(xs)
 
+    @staticmethod
+    def _filteredgen(context, mapping, make, args, select):
+        for x in make(context, *args):
+            s = stringify(context, mapping, x)
+            if select(wrappedbytes(s)):
+                yield s
+
+    def filter(self, context, mapping, select):
+        args = (mapping, self._make, self._args, select)
+        return mappedgenerator(self._filteredgen, args)
+
     def itermaps(self, context):
         raise error.ParseError(_('list of strings is not mappable'))