From 334642055d5e2b65f2a55acc37580f2ee145e693 Mon Sep 17 00:00:00 2001 From: lucemia Date: Mon, 10 Mar 2025 17:53:59 +0800 Subject: [PATCH 1/2] add existing only mode --- src/auto_detect_exceptions/cli.py | 16 ++++++++++--- src/auto_detect_exceptions/docstring_utils.py | 22 ++++++++++++++---- src/auto_detect_exceptions/tests/test_cli.py | 23 +++++++++++-------- .../tests/test_docstring_utils.py | 17 ++++++++++---- 4 files changed, 57 insertions(+), 21 deletions(-) diff --git a/src/auto_detect_exceptions/cli.py b/src/auto_detect_exceptions/cli.py index cce9c3c..41b2f4a 100644 --- a/src/auto_detect_exceptions/cli.py +++ b/src/auto_detect_exceptions/cli.py @@ -10,13 +10,14 @@ from .docstring_utils import update_function_docstrings -def process_directory(directory: str, modify: bool) -> None: +def process_directory(directory: str, modify: bool, only_existing: bool) -> None: """ Process a directory, analyzing Python files and optionally modifying them. Args: directory (str): The directory to process. modify (bool): If True, modifies files; otherwise, generates a report. + only_existing (bool): If True, only updates functions that already have docstrings. """ python_files = find_python_files(directory) missing_exceptions = {} @@ -40,7 +41,9 @@ def process_directory(directory: str, modify: bool) -> None: if modify: updated_code = update_function_docstrings( - source_code, function_exceptions + source_code, + function_exceptions, + only_existing_docstrings=only_existing, ) write_python_file(file_path, updated_code) @@ -84,10 +87,17 @@ def main(): action="store_true", help="Modify files to add missing exception docstrings", ) + parser.add_argument( + "--only-existing", + action="store_true", + help="Only update functions that already have docstrings", + ) args = parser.parse_args() - process_directory(args.directory, modify=args.update) + process_directory( + args.directory, modify=args.update, only_existing=args.only_existing + ) if __name__ == "__main__": diff --git a/src/auto_detect_exceptions/docstring_utils.py b/src/auto_detect_exceptions/docstring_utils.py index e3daf49..dab5ad9 100644 --- a/src/auto_detect_exceptions/docstring_utils.py +++ b/src/auto_detect_exceptions/docstring_utils.py @@ -7,8 +7,13 @@ class DocstringUpdater(cst.CSTTransformer): Transformer that updates function docstrings to include missing exceptions. """ - def __init__(self, function_exceptions: Dict[str, Set[str]]): + def __init__( + self, + function_exceptions: Dict[str, Set[str]], + only_update_existing_docstrings: bool = False, + ): self.function_exceptions = function_exceptions + self.only_update_existing_docstrings = only_update_existing_docstrings def leave_FunctionDef( self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef @@ -28,6 +33,10 @@ def leave_FunctionDef( "\"'" ) # Strip triple quotes + # If the flag is enabled and the function has no existing docstring, skip modification + if self.only_update_existing_docstrings and not existing_docstring: + return updated_node + # Generate new exceptions section exception_lines = ["Raises:"] for exc in sorted(self.function_exceptions[func_name]): @@ -47,7 +56,7 @@ def leave_FunctionDef( body=[cst.Expr(value=cst.SimpleString(f'"""{new_docstring}"""'))] ) - # Insert new docstring into the function body + # Ensure updated_node.body.body is a list new_body = ( [new_docstring_node] + list(updated_node.body.body[1:]) if existing_docstring @@ -58,7 +67,9 @@ def leave_FunctionDef( def update_function_docstrings( - source_code: str, function_exceptions: Dict[str, Set[str]] + source_code: str, + function_exceptions: Dict[str, Set[str]], + only_update_existing_docstrings: bool = False, ) -> str: """ Uses `libcst` to update function docstrings in a Python source file. @@ -66,10 +77,13 @@ def update_function_docstrings( Args: source_code (str): The original source code. function_exceptions (Dict[str, Set[str]]): A mapping of function names to their exceptions. + only_update_existing_docstrings (bool): If True, only add Raises to functions with existing docstrings. Returns: str: The modified source code. """ tree = cst.parse_module(source_code) - updated_tree = tree.visit(DocstringUpdater(function_exceptions)) + updated_tree = tree.visit( + DocstringUpdater(function_exceptions, only_update_existing_docstrings) + ) return updated_tree.code diff --git a/src/auto_detect_exceptions/tests/test_cli.py b/src/auto_detect_exceptions/tests/test_cli.py index b390141..f784030 100644 --- a/src/auto_detect_exceptions/tests/test_cli.py +++ b/src/auto_detect_exceptions/tests/test_cli.py @@ -9,26 +9,31 @@ def setUp(self): """Create a temporary directory and Python files for testing.""" self.temp_dir = tempfile.TemporaryDirectory() self.test_file = Path(self.temp_dir.name) / "test_script.py" - self.test_file.write_text(""" + self.test_file.write_text(''' def foo(): + """Existing docstring.""" raise ValueError("An error occurred") -""") + +def bar(): + raise TypeError("Another error") +''') def tearDown(self): """Cleanup the temporary directory.""" self.temp_dir.cleanup() - def test_process_directory_report(self): - """Test processing a directory without modifying files.""" - process_directory(self.temp_dir.name, modify=False) - - def test_process_directory_update(self): - """Test modifying files to add exception docstrings.""" - process_directory(self.temp_dir.name, modify=True) + def test_process_directory_update_existing_only(self): + """Test modifying only functions that have existing docstrings.""" + process_directory(self.temp_dir.name, modify=True, only_existing=True) content = self.test_file.read_text() + + # Function `foo` should be updated self.assertIn("Raises:", content) self.assertIn("ValueError", content) + # Function `bar` should remain unchanged because it had no docstring + self.assertNotIn("TypeError", content) + if __name__ == "__main__": unittest.main() diff --git a/src/auto_detect_exceptions/tests/test_docstring_utils.py b/src/auto_detect_exceptions/tests/test_docstring_utils.py index b3318b9..809f7b2 100644 --- a/src/auto_detect_exceptions/tests/test_docstring_utils.py +++ b/src/auto_detect_exceptions/tests/test_docstring_utils.py @@ -3,19 +3,26 @@ class TestDocstringUtils(unittest.TestCase): - def test_update_function_docstrings(self): - """Test modifying a function's docstring to include missing exceptions.""" + def test_update_function_docstrings_existing_only(self): + """Test that it only updates functions with existing docstrings when the option is enabled.""" source_code = ''' def foo(): """This function does something.""" raise ValueError("Error occurred") + +def bar(): + raise TypeError("Another error") ''' - function_exceptions = {"foo": {"ValueError"}} - updated_code = update_function_docstrings(source_code, function_exceptions) + function_exceptions = {"foo": {"ValueError"}, "bar": {"TypeError"}} + + updated_code = update_function_docstrings( + source_code, function_exceptions, only_update_existing_docstrings=True + ) - self.assertIn("Raises:", updated_code) + self.assertIn("Raises:", updated_code) # Should update `foo` self.assertIn("ValueError", updated_code) + self.assertNotIn("TypeError", updated_code) # `bar` should NOT be updated if __name__ == "__main__": From 199c70877732acc0ab7b3a76e6ed7901fd7a128c Mon Sep 17 00:00:00 2001 From: lucemia Date: Mon, 10 Mar 2025 17:59:02 +0800 Subject: [PATCH 2/2] fix --- src/auto_detect_exceptions/cli.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/auto_detect_exceptions/cli.py b/src/auto_detect_exceptions/cli.py index 41b2f4a..a458bd8 100644 --- a/src/auto_detect_exceptions/cli.py +++ b/src/auto_detect_exceptions/cli.py @@ -1,4 +1,5 @@ import argparse +from pathlib import Path from .file_utils import find_python_files, read_python_file, write_python_file from .ast_utils import ( parse_python_code, @@ -43,7 +44,7 @@ def process_directory(directory: str, modify: bool, only_existing: bool) -> None updated_code = update_function_docstrings( source_code, function_exceptions, - only_existing_docstrings=only_existing, + only_update_existing_docstrings=only_existing, ) write_python_file(file_path, updated_code) @@ -51,7 +52,7 @@ def process_directory(directory: str, modify: bool, only_existing: bool) -> None generate_report(missing_exceptions) -def generate_report(missing_exceptions: dict) -> None: +def generate_report(missing_exceptions: dict[Path, dict[str, set[str]]]) -> None: """ Prints a report of functions missing exception documentation. @@ -71,7 +72,7 @@ def generate_report(missing_exceptions: dict) -> None: print(f" Expected exceptions: {', '.join(exceptions)}") -def main(): +def main() -> None: """ Entry point for the CLI tool. """