11import sys
22from itertools import zip_longest
3- from typing import List , Union
3+ from typing import List , Set , Union
44
55from flake8_annotations .enums import AnnotationType , ClassDecoratorType , FunctionType
66
1717
1818 PY_GTE_38 = False
1919
20- __version__ = "2.0.0 "
20+ __version__ = "2.0.1 "
2121
2222AST_ARG_TYPES = ("args" , "vararg" , "kwonlyargs" , "kwarg" )
2323if PY_GTE_38 :
2424 # Positional-only args introduced in Python 3.8
2525 AST_ARG_TYPES += ("posonlyargs" ,)
2626
2727AST_FUNCTION_TYPES = Union [ast .FunctionDef , ast .AsyncFunctionDef ]
28+ AST_DEF_NODES = Union [ast .FunctionDef , ast .AsyncFunctionDef , ast .ClassDef ]
2829
2930
3031class Argument :
@@ -60,8 +61,9 @@ def __str__(self) -> str:
6061 def __repr__ (self ) -> str :
6162 """Format the Argument object into its "official" representation."""
6263 return (
63- f"Argument({ self .argname !r} , { self .lineno } , { self .col_offset } , { self .annotation_type } , "
64- f"{ self .has_type_annotation } , { self .has_3107_annotation } , { self .has_type_comment } )"
64+ f"Argument(argname={ self .argname !r} , lineno={ self .lineno } , col_offset={ self .col_offset } , " # noqa: E501
65+ f"annotation_type={ self .annotation_type } , has_type_annotation={ self .has_type_annotation } , " # noqa: E501
66+ f"has_3107_annotation={ self .has_3107_annotation } , has_type_comment={ self .has_type_comment } )" # noqa: E501
6567 )
6668
6769 @classmethod
@@ -147,9 +149,12 @@ def __str__(self) -> str:
147149 def __repr__ (self ) -> str :
148150 """Format the Function object into its "official" representation."""
149151 return (
150- f"Function({ self .name !r} , { self .lineno } , { self .col_offset } , { self .function_type } , "
151- f"{ self .is_class_method } , { self .class_decorator_type } , { self .is_return_annotated } , "
152- f"{ self .has_type_comment } , { self .has_only_none_returns } , { self .args } )"
152+ f"Function(name={ self .name !r} , lineno={ self .lineno } , col_offset={ self .col_offset } , "
153+ f"function_type={ self .function_type } , is_class_method={ self .is_class_method } , "
154+ f"class_decorator_type={ self .class_decorator_type } , "
155+ f"is_return_annotated={ self .is_return_annotated } , "
156+ f"has_type_comment={ self .has_type_comment } , "
157+ f"has_only_none_returns={ self .has_only_none_returns } , args={ self .args } )"
153158 )
154159
155160 @classmethod
@@ -192,7 +197,9 @@ def from_function_node(cls, node: AST_FUNCTION_TYPES, lines: List[str], **kwargs
192197 while True :
193198 # To account for multiline docstrings, rewind through the lines until we find the line
194199 # containing the :
195- colon_loc = lines [def_end_lineno - 1 ].find (":" )
200+ # Use str.rfind() to account for annotations on the same line, definition closure should
201+ # be the last : on the line
202+ colon_loc = lines [def_end_lineno - 1 ].rfind (":" )
196203 if colon_loc == - 1 :
197204 def_end_lineno -= 1
198205 else :
@@ -215,7 +222,7 @@ def from_function_node(cls, node: AST_FUNCTION_TYPES, lines: List[str], **kwargs
215222 new_function = cls .try_type_comment (new_function , node )
216223
217224 # Check for the presence of non-`None` returns using the special-case return node visitor
218- return_visitor = ReturnVisitor ()
225+ return_visitor = ReturnVisitor (node )
219226 return_visitor .visit (node )
220227 new_function .has_only_none_returns = return_visitor .has_only_none_returns
221228
@@ -298,49 +305,35 @@ class FunctionVisitor(ast.NodeVisitor):
298305
299306 def __init__ (self , lines : List [str ]):
300307 self .lines = lines
301- self .function_definitions = []
308+ self .function_definitions : List [Function ] = []
309+ self ._context : List [AST_DEF_NODES ] = []
302310
303- def visit_FunctionDef (self , node : ast . FunctionDef ) -> None :
311+ def switch_context (self , node : AST_DEF_NODES ) -> None :
304312 """
305- Handle a visit to a function definition .
313+ Utilize a context switcher as a generic function visitor in order to track function context .
306314
307- Note: This will not contain class methods, these are included in the body of ClassDef
308- statements
309- """
310- self .function_definitions .append (Function .from_function_node (node , self .lines ))
311- self .generic_visit (node ) # Walk through any nested functions
312-
313- def visit_AsyncFunctionDef (self , node : ast .AsyncFunctionDef ) -> None :
314- """
315- Handle a visit to a coroutine definition.
316-
317- Note: This will not contain class methods, these are included in the body of ClassDef
318- statements
319- """
320- self .function_definitions .append (Function .from_function_node (node , self .lines ))
321- self .generic_visit (node ) # Walk through any nested functions
315+ Without keeping track of context, it's challenging to reliably differentiate class methods
316+ from "regular" functions, especially in the case of nested classes.
322317
323- def visit_ClassDef ( self , node : ast . ClassDef ) -> None :
318+ Thank you for the inspiration @isidentical :)
324319 """
325- Handle a visit to a class definition.
320+ if isinstance (node , (ast .FunctionDef , ast .AsyncFunctionDef )):
321+ # Check for non-empty context first to prevent IndexErrors for non-nested nodes
322+ if self ._context and isinstance (self ._context [- 1 ], ast .ClassDef ):
323+ # Check if current context is a ClassDef node & pass the appropriate flag
324+ self .function_definitions .append (
325+ Function .from_function_node (node , self .lines , is_class_method = True )
326+ )
327+ else :
328+ self .function_definitions .append (Function .from_function_node (node , self .lines ))
326329
327- Class methods will all be contained in the body of the node
328- """
329- method_nodes = [
330- child_node
331- for child_node in node .body
332- if isinstance (child_node , (ast .FunctionDef , ast .AsyncFunctionDef ))
333- ]
334- self .function_definitions .extend (
335- [
336- Function .from_function_node (method_node , self .lines , is_class_method = True )
337- for method_node in method_nodes
338- ]
339- )
330+ self ._context .append (node )
331+ self .generic_visit (node )
332+ self ._context .pop ()
340333
341- # Use ast.NodeVisitor.generic_visit to start down the nested method chain
342- for sub_node in node . body :
343- self . generic_visit ( sub_node )
334+ visit_FunctionDef = switch_context
335+ visit_AsyncFunctionDef = switch_context
336+ visit_ClassDef = switch_context
344337
345338
346339class ReturnVisitor (ast .NodeVisitor ):
@@ -353,13 +346,29 @@ class ReturnVisitor(ast.NodeVisitor):
353346 If the function node being visited has no return statement, or contains only return
354347 statement(s) that explicitly return `None`, the `instance.has_only_none_returns` flag will be
355348 set to `True`.
349+
350+ Due to the generic visiting being done, we need to keep track of the context in which a
351+ non-`None` return node is found. These functions are added to a set that is checked to see
352+ whether nor not the parent node is present.
356353 """
357354
358- def __init__ (self ):
359- self .has_only_none_returns = True
355+ def __init__ (self , parent_node : AST_FUNCTION_TYPES ):
356+ self .parent_node = parent_node
357+ self ._context : List [AST_FUNCTION_TYPES ] = []
358+ self ._non_none_return_nodes : Set [AST_FUNCTION_TYPES ] = set ()
359+
360+ @property
361+ def has_only_none_returns (self ) -> bool :
362+ """Return `True` if the parent node isn't in the visited nodes that don't return `None`."""
363+ return self .parent_node not in self ._non_none_return_nodes
360364
361365 def visit_Return (self , node : ast .Return ) -> None :
362- """Check each Return node to see if it returns anything other than `None`."""
366+ """
367+ Check each Return node to see if it returns anything other than `None`.
368+
369+ If the node being visited returns anything other than `None`, its parent context is added to
370+ the set of non-returning child nodes of the parent node.
371+ """
363372 if node .value is not None :
364373 # In the event of an explicit `None` return (`return None`), the node body will be an
365374 # instance of either `ast.Constant` (3.8+) or `ast.NameConstant`, which we need to check
@@ -368,4 +377,21 @@ def visit_Return(self, node: ast.Return) -> None:
368377 if node .value .value is None :
369378 return
370379
371- self .has_only_none_returns = False
380+ self ._non_none_return_nodes .add (self ._context [- 1 ])
381+
382+ def switch_context (self , node : AST_FUNCTION_TYPES ) -> None :
383+ """
384+ Utilize a context switcher as a generic visitor in order to properly track function context.
385+
386+ Using a traditional `ast.generic_visit` setup, return nodes of nested functions are visited
387+ without any knowledge of their context, causing the top-level function to potentially be
388+ mis-classified.
389+
390+ Thank you for the inspiration @isidentical :)
391+ """
392+ self ._context .append (node )
393+ self .generic_visit (node )
394+ self ._context .pop ()
395+
396+ visit_FunctionDef = switch_context
397+ visit_AsyncFunctionDef = switch_context
0 commit comments