contrib/import-checker.py
changeset 43075 57875cf423c9
parent 42813 268662aac075
child 43084 c2e284cee333
--- a/contrib/import-checker.py	Fri Oct 04 15:53:45 2019 -0400
+++ b/contrib/import-checker.py	Sat Oct 05 10:29:34 2019 -0400
@@ -10,7 +10,7 @@
 # Import a minimal set of stdlib modules needed for list_stdlib_modules()
 # to work when run from a virtualenv.  The modules were chosen empirically
 # so that the return value matches the return value without virtualenv.
-if True: # disable lexical sorting checks
+if True:  # disable lexical sorting checks
     try:
         import BaseHTTPServer as basehttpserver
     except ImportError:
@@ -47,9 +47,7 @@
 )
 
 # Whitelist of symbols that can be directly imported.
-directsymbols = (
-    'demandimport',
-)
+directsymbols = ('demandimport',)
 
 # Modules that must be aliased because they are commonly confused with
 # common variables and can create aliasing and readability issues.
@@ -57,6 +55,7 @@
     'ui': 'uimod',
 }
 
+
 def usingabsolute(root):
     """Whether absolute imports are being used."""
     if sys.version_info[0] >= 3:
@@ -71,6 +70,7 @@
 
     return False
 
+
 def walklocal(root):
     """Recursively yield all descendant nodes but not in a different scope"""
     todo = collections.deque(ast.iter_child_nodes(root))
@@ -82,6 +82,7 @@
             todo.extend(ast.iter_child_nodes(node))
         yield node, newscope
 
+
 def dotted_name_of_path(path):
     """Given a relative path to a source file, return its dotted module name.
 
@@ -91,11 +92,12 @@
     'zlib'
     """
     parts = path.replace(os.sep, '/').split('/')
-    parts[-1] = parts[-1].split('.', 1)[0] # remove .py and .so and .ARCH.so
+    parts[-1] = parts[-1].split('.', 1)[0]  # remove .py and .so and .ARCH.so
     if parts[-1].endswith('module'):
         parts[-1] = parts[-1][:-6]
     return '.'.join(parts)
 
+
 def fromlocalfunc(modulename, localmods):
     """Get a function to examine which locally defined module the
     target source imports via a specified name.
@@ -164,6 +166,7 @@
     prefix = '.'.join(modulename.split('.')[:-1])
     if prefix:
         prefix += '.'
+
     def fromlocal(name, level=0):
         # name is false value when relative imports are used.
         if not name:
@@ -175,8 +178,9 @@
                 # Check relative name first.
                 candidates = [prefix + name, name]
             else:
-                candidates = ['.'.join(modulename.split('.')[:-level]) +
-                              '.' + name]
+                candidates = [
+                    '.'.join(modulename.split('.')[:-level]) + '.' + name
+                ]
 
         for n in candidates:
             if n in localmods:
@@ -185,18 +189,21 @@
             if dottedpath in localmods:
                 return (n, dottedpath, True)
         return False
+
     return fromlocal
 
+
 def populateextmods(localmods):
     """Populate C extension modules based on pure modules"""
     newlocalmods = set(localmods)
     for n in localmods:
         if n.startswith('mercurial.pure.'):
-            m = n[len('mercurial.pure.'):]
+            m = n[len('mercurial.pure.') :]
             newlocalmods.add('mercurial.cext.' + m)
             newlocalmods.add('mercurial.cffi._' + m)
     return newlocalmods
 
+
 def list_stdlib_modules():
     """List the modules present in the stdlib.
 
@@ -232,13 +239,13 @@
     for m in ['msvcrt', '_winreg']:
         yield m
     yield '__builtin__'
-    yield 'builtins' # python3 only
-    yield 'importlib.abc' # python3 only
-    yield 'importlib.machinery' # python3 only
-    yield 'importlib.util' # python3 only
+    yield 'builtins'  # python3 only
+    yield 'importlib.abc'  # python3 only
+    yield 'importlib.machinery'  # python3 only
+    yield 'importlib.util'  # python3 only
     for m in 'fcntl', 'grp', 'pwd', 'termios':  # Unix only
         yield m
-    for m in 'cPickle', 'datetime': # in Python (not C) on PyPy
+    for m in 'cPickle', 'datetime':  # in Python (not C) on PyPy
         yield m
     for m in ['cffi']:
         yield m
@@ -264,14 +271,17 @@
     for libpath in sys.path:
         # We want to walk everything in sys.path that starts with something in
         # stdlib_prefixes, but not directories from the hg sources.
-        if (os.path.abspath(libpath).startswith(sourceroot)
-            or not any(libpath.startswith(p) for p in stdlib_prefixes)):
+        if os.path.abspath(libpath).startswith(sourceroot) or not any(
+            libpath.startswith(p) for p in stdlib_prefixes
+        ):
             continue
         for top, dirs, files in os.walk(libpath):
             for i, d in reversed(list(enumerate(dirs))):
-                if (not os.path.exists(os.path.join(top, d, '__init__.py'))
-                    or top == libpath and d in ('hgdemandimport', 'hgext',
-                                                'mercurial')):
+                if (
+                    not os.path.exists(os.path.join(top, d, '__init__.py'))
+                    or top == libpath
+                    and d in ('hgdemandimport', 'hgext', 'mercurial')
+                ):
                     del dirs[i]
             for name in files:
                 if not name.endswith(('.py', '.so', '.pyc', '.pyo', '.pyd')):
@@ -280,12 +290,14 @@
                     full_path = top
                 else:
                     full_path = os.path.join(top, name)
-                rel_path = full_path[len(libpath) + 1:]
+                rel_path = full_path[len(libpath) + 1 :]
                 mod = dotted_name_of_path(rel_path)
                 yield mod
 
+
 stdlib_modules = set(list_stdlib_modules())
 
+
 def imported_modules(source, modulename, f, localmods, ignore_nested=False):
     """Given the source of a file as a string, yield the names
     imported by that file.
@@ -383,6 +395,7 @@
                 # lookup
                 yield dottedpath
 
+
 def verify_import_convention(module, source, localmods):
     """Verify imports match our established coding convention.
 
@@ -400,6 +413,7 @@
     else:
         return verify_stdlib_on_own_line(root)
 
+
 def verify_modern_convention(module, root, localmods, root_col_offset=0):
     """Verify a file conforms to the modern import convention rules.
 
@@ -443,19 +457,24 @@
     seenlevels = set()
 
     for node, newscope in walklocal(root):
+
         def msg(fmt, *args):
             return (fmt % args, node.lineno)
+
         if newscope:
             # Check for local imports in function
-            for r in verify_modern_convention(module, node, localmods,
-                                              node.col_offset + 4):
+            for r in verify_modern_convention(
+                module, node, localmods, node.col_offset + 4
+            ):
                 yield r
         elif isinstance(node, ast.Import):
             # Disallow "import foo, bar" and require separate imports
             # for each module.
             if len(node.names) > 1:
-                yield msg('multiple imported names: %s',
-                          ', '.join(n.name for n in node.names))
+                yield msg(
+                    'multiple imported names: %s',
+                    ', '.join(n.name for n in node.names),
+                )
 
             name = node.names[0].name
             asname = node.names[0].asname
@@ -465,16 +484,20 @@
             # Ignore sorting rules on imports inside blocks.
             if node.col_offset == root_col_offset:
                 if lastname and name < lastname and laststdlib == stdlib:
-                    yield msg('imports not lexically sorted: %s < %s',
-                              name, lastname)
+                    yield msg(
+                        'imports not lexically sorted: %s < %s', name, lastname
+                    )
 
             lastname = name
             laststdlib = stdlib
 
             # stdlib imports should be before local imports.
             if stdlib and seenlocal and node.col_offset == root_col_offset:
-                yield msg('stdlib import "%s" follows local import: %s',
-                          name, seenlocal)
+                yield msg(
+                    'stdlib import "%s" follows local import: %s',
+                    name,
+                    seenlocal,
+                )
 
             if not stdlib:
                 seenlocal = name
@@ -485,13 +508,16 @@
                 yield msg('import should be relative: %s', name)
 
             if name in requirealias and asname != requirealias[name]:
-                yield msg('%s module must be "as" aliased to %s',
-                          name, requirealias[name])
+                yield msg(
+                    '%s module must be "as" aliased to %s',
+                    name,
+                    requirealias[name],
+                )
 
         elif isinstance(node, ast.ImportFrom):
             # Resolve the full imported module name.
             if node.level > 0:
-                fullname = '.'.join(module.split('.')[:-node.level])
+                fullname = '.'.join(module.split('.')[: -node.level])
                 if node.module:
                     fullname += '.%s' % node.module
             else:
@@ -508,7 +534,8 @@
                 if not fullname or (
                     fullname in stdlib_modules
                     and fullname not in localmods
-                    and fullname + '.__init__' not in localmods):
+                    and fullname + '.__init__' not in localmods
+                ):
                     yield msg('relative import of stdlib module')
                 else:
                     seenlocal = fullname
@@ -518,19 +545,24 @@
             found = fromlocal(node.module, node.level)
             if found and found[2]:  # node.module is a package
                 prefix = found[0] + '.'
-                symbols = (n.name for n in node.names
-                           if not fromlocal(prefix + n.name))
+                symbols = (
+                    n.name for n in node.names if not fromlocal(prefix + n.name)
+                )
             else:
                 symbols = (n.name for n in node.names)
             symbols = [sym for sym in symbols if sym not in directsymbols]
             if node.module and node.col_offset == root_col_offset:
                 if symbols and fullname not in allowsymbolimports:
-                    yield msg('direct symbol import %s from %s',
-                              ', '.join(symbols), fullname)
+                    yield msg(
+                        'direct symbol import %s from %s',
+                        ', '.join(symbols),
+                        fullname,
+                    )
 
                 if symbols and seennonsymbollocal:
-                    yield msg('symbol import follows non-symbol import: %s',
-                              fullname)
+                    yield msg(
+                        'symbol import follows non-symbol import: %s', fullname
+                    )
             if not symbols and fullname not in stdlib_modules:
                 seennonsymbollocal = True
 
@@ -538,15 +570,19 @@
                 assert node.level
 
                 # Only allow 1 group per level.
-                if (node.level in seenlevels
-                    and node.col_offset == root_col_offset):
-                    yield msg('multiple "from %s import" statements',
-                              '.' * node.level)
+                if (
+                    node.level in seenlevels
+                    and node.col_offset == root_col_offset
+                ):
+                    yield msg(
+                        'multiple "from %s import" statements', '.' * node.level
+                    )
 
                 # Higher-level groups come before lower-level groups.
                 if any(node.level > l for l in seenlevels):
-                    yield msg('higher-level import should come first: %s',
-                              fullname)
+                    yield msg(
+                        'higher-level import should come first: %s', fullname
+                    )
 
                 seenlevels.add(node.level)
 
@@ -556,14 +592,23 @@
 
             for n in node.names:
                 if lastentryname and n.name < lastentryname:
-                    yield msg('imports from %s not lexically sorted: %s < %s',
-                              fullname, n.name, lastentryname)
+                    yield msg(
+                        'imports from %s not lexically sorted: %s < %s',
+                        fullname,
+                        n.name,
+                        lastentryname,
+                    )
 
                 lastentryname = n.name
 
                 if n.name in requirealias and n.asname != requirealias[n.name]:
-                    yield msg('%s from %s must be "as" aliased to %s',
-                              n.name, fullname, requirealias[n.name])
+                    yield msg(
+                        '%s from %s must be "as" aliased to %s',
+                        n.name,
+                        fullname,
+                        requirealias[n.name],
+                    )
+
 
 def verify_stdlib_on_own_line(root):
     """Given some python source, verify that stdlib imports are done
@@ -582,13 +627,20 @@
             for n in node.names:
                 from_stdlib[n.name in stdlib_modules].append(n.name)
             if from_stdlib[True] and from_stdlib[False]:
-                yield ('mixed imports\n   stdlib:    %s\n   relative:  %s' %
-                       (', '.join(sorted(from_stdlib[True])),
-                        ', '.join(sorted(from_stdlib[False]))), node.lineno)
+                yield (
+                    'mixed imports\n   stdlib:    %s\n   relative:  %s'
+                    % (
+                        ', '.join(sorted(from_stdlib[True])),
+                        ', '.join(sorted(from_stdlib[False])),
+                    ),
+                    node.lineno,
+                )
+
 
 class CircularImport(Exception):
     pass
 
+
 def checkmod(mod, imports):
     shortest = {}
     visit = [[mod]]
@@ -603,6 +655,7 @@
                     continue
                 visit.append(path + [i])
 
+
 def rotatecycle(cycle):
     """arrange a cycle so that the lexicographically first module listed first
 
@@ -613,6 +666,7 @@
     idx = cycle.index(lowest)
     return cycle[idx:] + cycle[:idx] + [lowest]
 
+
 def find_cycles(imports):
     """Find cycles in an already-loaded import graph.
 
@@ -636,9 +690,11 @@
             cycles.add(" -> ".join(rotatecycle(cycle)))
     return cycles
 
+
 def _cycle_sortkey(c):
     return len(c), c
 
+
 def embedded(f, modname, src):
     """Extract embedded python code
 
@@ -680,6 +736,7 @@
             modname = modname.decode('utf8')
         yield code, "%s[%d]" % (modname, starts), name, starts - 1
 
+
 def sources(f, modname):
     """Yields possibly multiple sources from a filepath
 
@@ -700,6 +757,7 @@
             for script, modname, t, line in embedded(f, modname, src):
                 yield script, modname.encode('utf8'), t, line
 
+
 def main(argv):
     if len(argv) < 2 or (argv[1] == '-' and len(argv) > 2):
         print('Usage: %s {-|file [file] [file] ...}')
@@ -721,15 +779,19 @@
         for src, modname, name, line in sources(source_path, localmodname):
             try:
                 used_imports[modname] = sorted(
-                    imported_modules(src, modname, name, localmods,
-                                     ignore_nested=True))
-                for error, lineno in verify_import_convention(modname, src,
-                                                              localmods):
+                    imported_modules(
+                        src, modname, name, localmods, ignore_nested=True
+                    )
+                )
+                for error, lineno in verify_import_convention(
+                    modname, src, localmods
+                ):
                     any_errors = True
                     print('%s:%d: %s' % (source_path, lineno + line, error))
             except SyntaxError as e:
-                print('%s:%d: SyntaxError: %s' %
-                      (source_path, e.lineno + line, e))
+                print(
+                    '%s:%d: SyntaxError: %s' % (source_path, e.lineno + line, e)
+                )
     cycles = find_cycles(used_imports)
     if cycles:
         firstmods = set()
@@ -745,5 +807,6 @@
         any_errors = True
     return any_errors != 0
 
+
 if __name__ == '__main__':
     sys.exit(int(main(sys.argv)))