contrib/import-checker.py
changeset 26965 1fa66d3ad28d
parent 26964 5abba2c92da3
child 27018 e5be48dd8215
equal deleted inserted replaced
26964:5abba2c92da3 26965:1fa66d3ad28d
     1 #!/usr/bin/env python
     1 #!/usr/bin/env python
     2 
     2 
     3 import ast
     3 import ast
       
     4 import collections
     4 import os
     5 import os
     5 import sys
     6 import sys
     6 
     7 
     7 # Import a minimal set of stdlib modules needed for list_stdlib_modules()
     8 # Import a minimal set of stdlib modules needed for list_stdlib_modules()
     8 # to work when run from a virtualenv.  The modules were chosen empirically
     9 # to work when run from a virtualenv.  The modules were chosen empirically
    34                 for n in node.names:
    35                 for n in node.names:
    35                     if n.name == 'absolute_import':
    36                     if n.name == 'absolute_import':
    36                         return True
    37                         return True
    37 
    38 
    38     return False
    39     return False
       
    40 
       
    41 def walklocal(root):
       
    42     """Recursively yield all descendant nodes but not in a different scope"""
       
    43     todo = collections.deque(ast.iter_child_nodes(root))
       
    44     yield root, False
       
    45     while todo:
       
    46         node = todo.popleft()
       
    47         newscope = isinstance(node, ast.FunctionDef)
       
    48         if not newscope:
       
    49             todo.extend(ast.iter_child_nodes(node))
       
    50         yield node, newscope
    39 
    51 
    40 def dotted_name_of_path(path, trimpure=False):
    52 def dotted_name_of_path(path, trimpure=False):
    41     """Given a relative path to a source file, return its dotted module name.
    53     """Given a relative path to a source file, return its dotted module name.
    42 
    54 
    43     >>> dotted_name_of_path('mercurial/error.py')
    55     >>> dotted_name_of_path('mercurial/error.py')
   322     if absolute:
   334     if absolute:
   323         return verify_modern_convention(module, root)
   335         return verify_modern_convention(module, root)
   324     else:
   336     else:
   325         return verify_stdlib_on_own_line(root)
   337         return verify_stdlib_on_own_line(root)
   326 
   338 
   327 def verify_modern_convention(module, root):
   339 def verify_modern_convention(module, root, root_col_offset=0):
   328     """Verify a file conforms to the modern import convention rules.
   340     """Verify a file conforms to the modern import convention rules.
   329 
   341 
   330     The rules of the modern convention are:
   342     The rules of the modern convention are:
   331 
   343 
   332     * Ordering is stdlib followed by local imports. Each group is lexically
   344     * Ordering is stdlib followed by local imports. Each group is lexically
   359     # The last name to be imported (for sorting).
   371     # The last name to be imported (for sorting).
   360     lastname = None
   372     lastname = None
   361     # Relative import levels encountered so far.
   373     # Relative import levels encountered so far.
   362     seenlevels = set()
   374     seenlevels = set()
   363 
   375 
   364     for node in ast.walk(root):
   376     for node, newscope in walklocal(root):
   365         def msg(fmt, *args):
   377         def msg(fmt, *args):
   366             return (fmt % args, node.lineno)
   378             return (fmt % args, node.lineno)
   367         if isinstance(node, ast.Import):
   379         if newscope:
       
   380             # Check for local imports in function
       
   381             for r in verify_modern_convention(module, node,
       
   382                                               node.col_offset + 4):
       
   383                 yield r
       
   384         elif isinstance(node, ast.Import):
   368             # Disallow "import foo, bar" and require separate imports
   385             # Disallow "import foo, bar" and require separate imports
   369             # for each module.
   386             # for each module.
   370             if len(node.names) > 1:
   387             if len(node.names) > 1:
   371                 yield msg('multiple imported names: %s',
   388                 yield msg('multiple imported names: %s',
   372                           ', '.join(n.name for n in node.names))
   389                           ', '.join(n.name for n in node.names))
   373 
   390 
   374             name = node.names[0].name
   391             name = node.names[0].name
   375             asname = node.names[0].asname
   392             asname = node.names[0].asname
   376 
   393 
   377             # Ignore sorting rules on imports inside blocks.
   394             # Ignore sorting rules on imports inside blocks.
   378             if node.col_offset == 0:
   395             if node.col_offset == root_col_offset:
   379                 if lastname and name < lastname:
   396                 if lastname and name < lastname:
   380                     yield msg('imports not lexically sorted: %s < %s',
   397                     yield msg('imports not lexically sorted: %s < %s',
   381                               name, lastname)
   398                               name, lastname)
   382 
   399 
   383                 lastname = name
   400                 lastname = name
   384 
   401 
   385             # stdlib imports should be before local imports.
   402             # stdlib imports should be before local imports.
   386             stdlib = name in stdlib_modules
   403             stdlib = name in stdlib_modules
   387             if stdlib and seenlocal and node.col_offset == 0:
   404             if stdlib and seenlocal and node.col_offset == root_col_offset:
   388                 yield msg('stdlib import follows local import: %s', name)
   405                 yield msg('stdlib import follows local import: %s', name)
   389 
   406 
   390             if not stdlib:
   407             if not stdlib:
   391                 seenlocal = True
   408                 seenlocal = True
   392 
   409 
   421                 else:
   438                 else:
   422                     seenlocal = True
   439                     seenlocal = True
   423 
   440 
   424             # Direct symbol import is only allowed from certain modules and
   441             # Direct symbol import is only allowed from certain modules and
   425             # must occur before non-symbol imports.
   442             # must occur before non-symbol imports.
   426             if node.module and node.col_offset == 0:
   443             if node.module and node.col_offset == root_col_offset:
   427                 if fullname not in allowsymbolimports:
   444                 if fullname not in allowsymbolimports:
   428                     yield msg('direct symbol import from %s', fullname)
   445                     yield msg('direct symbol import from %s', fullname)
   429 
   446 
   430                 if seennonsymbolrelative:
   447                 if seennonsymbolrelative:
   431                     yield msg('symbol import follows non-symbol import: %s',
   448                     yield msg('symbol import follows non-symbol import: %s',
   434             if not node.module:
   451             if not node.module:
   435                 assert node.level
   452                 assert node.level
   436                 seennonsymbolrelative = True
   453                 seennonsymbolrelative = True
   437 
   454 
   438                 # Only allow 1 group per level.
   455                 # Only allow 1 group per level.
   439                 if node.level in seenlevels and node.col_offset == 0:
   456                 if (node.level in seenlevels
       
   457                     and node.col_offset == root_col_offset):
   440                     yield msg('multiple "from %s import" statements',
   458                     yield msg('multiple "from %s import" statements',
   441                               '.' * node.level)
   459                               '.' * node.level)
   442 
   460 
   443                 # Higher-level groups come before lower-level groups.
   461                 # Higher-level groups come before lower-level groups.
   444                 if any(node.level > l for l in seenlevels):
   462                 if any(node.level > l for l in seenlevels):