Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions src/auto_detect_exceptions/cli.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
from pathlib import Path
from .file_utils import find_python_files, read_python_file, write_python_file
from .ast_utils import (
parse_python_code,
Expand All @@ -10,13 +11,14 @@
from .docstring_utils import update_function_docstrings


def process_directory(directory: str, modify: bool) -> None:
def process_directory(directory: str, modify: bool, only_existing: bool) -> None:
"""
Process a directory, analyzing Python files and optionally modifying them.

Args:
directory (str): The directory to process.
modify (bool): If True, modifies files; otherwise, generates a report.
only_existing (bool): If True, only updates functions that already have docstrings.
"""
python_files = find_python_files(directory)
missing_exceptions = {}
Expand All @@ -40,15 +42,17 @@ def process_directory(directory: str, modify: bool) -> None:

if modify:
updated_code = update_function_docstrings(
source_code, function_exceptions
source_code,
function_exceptions,
only_update_existing_docstrings=only_existing,
)
write_python_file(file_path, updated_code)

if not modify:
generate_report(missing_exceptions)


def generate_report(missing_exceptions: dict) -> None:
def generate_report(missing_exceptions: dict[Path, dict[str, set[str]]]) -> None:
"""
Prints a report of functions missing exception documentation.

Expand All @@ -68,7 +72,7 @@ def generate_report(missing_exceptions: dict) -> None:
print(f" Expected exceptions: {', '.join(exceptions)}")


def main():
def main() -> None:
"""
Entry point for the CLI tool.
"""
Expand All @@ -84,10 +88,17 @@ def main():
action="store_true",
help="Modify files to add missing exception docstrings",
)
parser.add_argument(
"--only-existing",
action="store_true",
help="Only update functions that already have docstrings",
)

args = parser.parse_args()

process_directory(args.directory, modify=args.update)
process_directory(
args.directory, modify=args.update, only_existing=args.only_existing
)


if __name__ == "__main__":
Expand Down
22 changes: 18 additions & 4 deletions src/auto_detect_exceptions/docstring_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,13 @@ class DocstringUpdater(cst.CSTTransformer):
Transformer that updates function docstrings to include missing exceptions.
"""

def __init__(self, function_exceptions: Dict[str, Set[str]]):
def __init__(
self,
function_exceptions: Dict[str, Set[str]],
only_update_existing_docstrings: bool = False,
):
self.function_exceptions = function_exceptions
self.only_update_existing_docstrings = only_update_existing_docstrings

def leave_FunctionDef(
self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
Expand All @@ -28,6 +33,10 @@ def leave_FunctionDef(
"\"'"
) # Strip triple quotes

# If the flag is enabled and the function has no existing docstring, skip modification
if self.only_update_existing_docstrings and not existing_docstring:
return updated_node

# Generate new exceptions section
exception_lines = ["Raises:"]
for exc in sorted(self.function_exceptions[func_name]):
Expand All @@ -47,7 +56,7 @@ def leave_FunctionDef(
body=[cst.Expr(value=cst.SimpleString(f'"""{new_docstring}"""'))]
)

# Insert new docstring into the function body
# Ensure updated_node.body.body is a list
new_body = (
[new_docstring_node] + list(updated_node.body.body[1:])
if existing_docstring
Expand All @@ -58,18 +67,23 @@ def leave_FunctionDef(


def update_function_docstrings(
source_code: str, function_exceptions: Dict[str, Set[str]]
source_code: str,
function_exceptions: Dict[str, Set[str]],
only_update_existing_docstrings: bool = False,
) -> str:
"""
Uses `libcst` to update function docstrings in a Python source file.

Args:
source_code (str): The original source code.
function_exceptions (Dict[str, Set[str]]): A mapping of function names to their exceptions.
only_update_existing_docstrings (bool): If True, only add Raises to functions with existing docstrings.

Returns:
str: The modified source code.
"""
tree = cst.parse_module(source_code)
updated_tree = tree.visit(DocstringUpdater(function_exceptions))
updated_tree = tree.visit(
DocstringUpdater(function_exceptions, only_update_existing_docstrings)
)
return updated_tree.code
23 changes: 14 additions & 9 deletions src/auto_detect_exceptions/tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,31 @@ def setUp(self):
"""Create a temporary directory and Python files for testing."""
self.temp_dir = tempfile.TemporaryDirectory()
self.test_file = Path(self.temp_dir.name) / "test_script.py"
self.test_file.write_text("""
self.test_file.write_text('''
def foo():
"""Existing docstring."""
raise ValueError("An error occurred")
""")

def bar():
raise TypeError("Another error")
''')

def tearDown(self):
"""Cleanup the temporary directory."""
self.temp_dir.cleanup()

def test_process_directory_report(self):
"""Test processing a directory without modifying files."""
process_directory(self.temp_dir.name, modify=False)

def test_process_directory_update(self):
"""Test modifying files to add exception docstrings."""
process_directory(self.temp_dir.name, modify=True)
def test_process_directory_update_existing_only(self):
"""Test modifying only functions that have existing docstrings."""
process_directory(self.temp_dir.name, modify=True, only_existing=True)
content = self.test_file.read_text()

# Function `foo` should be updated
self.assertIn("Raises:", content)
self.assertIn("ValueError", content)

# Function `bar` should remain unchanged because it had no docstring
self.assertNotIn("TypeError", content)


if __name__ == "__main__":
unittest.main()
17 changes: 12 additions & 5 deletions src/auto_detect_exceptions/tests/test_docstring_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,26 @@


class TestDocstringUtils(unittest.TestCase):
def test_update_function_docstrings(self):
"""Test modifying a function's docstring to include missing exceptions."""
def test_update_function_docstrings_existing_only(self):
"""Test that it only updates functions with existing docstrings when the option is enabled."""
source_code = '''
def foo():
"""This function does something."""
raise ValueError("Error occurred")

def bar():
raise TypeError("Another error")
'''

function_exceptions = {"foo": {"ValueError"}}
updated_code = update_function_docstrings(source_code, function_exceptions)
function_exceptions = {"foo": {"ValueError"}, "bar": {"TypeError"}}

updated_code = update_function_docstrings(
source_code, function_exceptions, only_update_existing_docstrings=True
)

self.assertIn("Raises:", updated_code)
self.assertIn("Raises:", updated_code) # Should update `foo`
self.assertIn("ValueError", updated_code)
self.assertNotIn("TypeError", updated_code) # `bar` should NOT be updated


if __name__ == "__main__":
Expand Down