Skip to content

Commit 7ee6fdc

Browse files
authored
Merge pull request #302 from python-ellar/feat/nested-forward-ref-resolution
feat: implement recursive forward ref resolution in test module dependencies
2 parents f451bb1 + a486a83 commit 7ee6fdc

File tree

4 files changed

+162
-26
lines changed

4 files changed

+162
-26
lines changed

ellar/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
"""Ellar - Python ASGI web framework for building fast, efficient, and scalable RESTful APIs and server-side applications."""
22

3-
__version__ = "0.9.2"
3+
__version__ = "0.9.3"

ellar/testing/dependency_analyzer.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
if t.TYPE_CHECKING: # pragma: no cover
1414
from ellar.common import ControllerBase
15-
from ellar.core import ForwardRefModule, ModuleBase
15+
from ellar.core import ForwardRefModule, ModuleBase, ModuleSetup
1616
from ellar.di import ModuleTreeManager
1717

1818

@@ -60,6 +60,13 @@ def __init__(self, application_module: t.Union[t.Type["ModuleBase"], str]):
6060

6161
self._module_tree = self._build_module_tree()
6262

63+
def get_application_module_providers(self) -> t.List[t.Type]:
64+
"""Get all provider types from the ApplicationModule tree"""
65+
mod_data = self._module_tree.get_app_module()
66+
if mod_data:
67+
return list(mod_data.providers.values())
68+
return []
69+
6370
def _build_module_tree(self) -> "ModuleTreeManager":
6471
"""Build complete module tree for ApplicationModule"""
6572
from ellar.app import AppFactory
@@ -164,7 +171,7 @@ def collect_dependencies(mod: t.Type["ModuleBase"]) -> None:
164171

165172
def resolve_forward_ref(
166173
self, forward_ref: "ForwardRefModule"
167-
) -> t.Optional[t.Type["ModuleBase"]]:
174+
) -> t.Optional["ModuleSetup"]:
168175
"""
169176
Resolve a ForwardRefModule to its actual module from ApplicationModule tree
170177
@@ -181,7 +188,7 @@ def resolve_forward_ref(
181188
filter_item=lambda data: True,
182189
find_predicate=lambda data: data.name == forward_ref.module_name,
183190
)
184-
return t.cast(t.Type["ModuleBase"], result.value.module) if result else None
191+
return t.cast("ModuleSetup", result.value) if result else None
185192

186193
elif hasattr(forward_ref, "module") and forward_ref.module:
187194
# Module can be a Type or a string import path
@@ -197,12 +204,8 @@ def resolve_forward_ref(
197204

198205
# Search for this module type in the tree
199206
module_data = self._module_tree.get_module(module_cls)
200-
return (
201-
t.cast(t.Type["ModuleBase"], module_data.value.module)
202-
if module_data
203-
else None
204-
)
205-
207+
if module_data:
208+
return t.cast("ModuleSetup", module_data.value)
206209
return None
207210

208211

ellar/testing/module.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
constants,
1212
)
1313
from ellar.common.types import T
14-
from ellar.core import ModuleBase
14+
from ellar.core import ModuleBase, ModuleSetup
1515
from ellar.core.routing import EllarControllerMount
1616
from ellar.di import ProviderConfig
1717
from ellar.reflect import reflect
@@ -166,16 +166,12 @@ def create_test_module(
166166
app_analyzer = ApplicationModuleDependencyAnalyzer(application_module)
167167
controller_analyzer = ControllerDependencyAnalyzer()
168168

169-
# 1. Resolve ForwardRefs in registered modules
170-
resolved_modules = cls._resolve_forward_refs(modules_list, app_analyzer)
171-
modules_list = resolved_modules
172-
173169
# 2. Analyze controllers and find required modules (with recursive dependencies)
174170
required_modules = cls._analyze_and_resolve_controller_dependencies(
175171
controllers, controller_analyzer, app_analyzer
176172
)
177173

178-
# 3. Add required modules that aren't already registered
174+
# 2. Add required modules that aren't already registered
179175
# Use type comparison to avoid duplicates
180176
existing_module_types = {
181177
m if isinstance(m, type) else m.module if hasattr(m, "module") else m
@@ -186,6 +182,15 @@ def create_test_module(
186182
modules_list.append(required_module)
187183
existing_module_types.add(required_module)
188184

185+
# 4. Resolve ForwardRefs in registered modules
186+
resolved_modules = cls._resolve_forward_refs(modules_list, app_analyzer)
187+
modules_list.extend(resolved_modules)
188+
189+
providers = list(providers)
190+
# 5. Add application module providers, since this is the root module
191+
# and it will be used to resolve dependencies
192+
providers.extend(app_analyzer.get_application_module_providers())
193+
189194
# Create the module with complete dependency list
190195
module = Module(
191196
modules=modules_list,
@@ -229,20 +234,30 @@ def _resolve_forward_refs(
229234
modules: t.List[t.Any],
230235
app_analyzer: "ApplicationModuleDependencyAnalyzer",
231236
) -> t.List[t.Any]:
232-
"""Resolve ForwardRefModule instances from ApplicationModule"""
237+
"""Resolve ForwardRefModule instances from ApplicationModule recursively"""
233238
from ellar.core import ForwardRefModule
234239

235240
resolved = []
236241
for module in modules:
242+
# Resolve current module if it's a ForwardRefModule
237243
if isinstance(module, ForwardRefModule):
238244
actual_module = app_analyzer.resolve_forward_ref(module)
239-
if actual_module:
240-
resolved.append(actual_module)
241-
else:
242-
# Keep original if can't resolve (might be test-specific)
243-
resolved.append(module)
245+
current_module = actual_module.module
246+
resolved.append(actual_module)
247+
elif isinstance(module, ModuleSetup):
248+
current_module = module.module
244249
else:
245-
resolved.append(module)
250+
current_module = module
251+
252+
# Recursively resolve forward refs in module's dependencies
253+
registered_modules = (
254+
reflect.get_metadata(constants.MODULE_METADATA.MODULES, current_module)
255+
or []
256+
)
257+
if registered_modules:
258+
resolved.extend(
259+
cls._resolve_forward_refs(registered_modules, app_analyzer)
260+
)
246261

247262
return resolved
248263

tests/test_testing_dependency_resolution.py

Lines changed: 121 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -276,13 +276,39 @@ def test_application_module_analyzer_get_module_dependencies_none():
276276
assert len(dependencies) == 0
277277

278278

279+
def test_application_module_analyzer_get_application_module_providers():
280+
"""Test getting providers from ApplicationModule"""
281+
from ellar.di import ProviderConfig
282+
283+
@injectable
284+
class AppLevelService:
285+
pass
286+
287+
@Module(
288+
name="TestAppModuleWithProviders",
289+
modules=[AuthModule],
290+
providers=[ProviderConfig(AppLevelService, use_class=AppLevelService)],
291+
)
292+
class TestAppModuleWithProviders(ModuleBase):
293+
pass
294+
295+
analyzer = ApplicationModuleDependencyAnalyzer(TestAppModuleWithProviders)
296+
providers = analyzer.get_application_module_providers()
297+
298+
# Should include AppLevelService
299+
assert AppLevelService in providers or any(
300+
hasattr(p, "get_type") and p.get_type() == AppLevelService for p in providers
301+
)
302+
303+
279304
# ============================================================================
280305
# Unit Tests: ForwardRefModule Resolution
281306
# ============================================================================
282307

283308

284309
def test_forward_ref_resolution_by_type():
285310
"""Test resolving ForwardRefModule by type"""
311+
from ellar.core.modules import ModuleSetup
286312

287313
# Need to have DatabaseModule actually registered in the application tree
288314
@Module(
@@ -298,7 +324,9 @@ class ForwardRefTestModule(ModuleBase):
298324
forward_ref = ForwardRefModule(module=DatabaseModule)
299325
resolved = analyzer.resolve_forward_ref(forward_ref)
300326

301-
assert resolved == DatabaseModule
327+
# Should return a ModuleSetup instance
328+
assert isinstance(resolved, ModuleSetup)
329+
assert resolved.module == DatabaseModule
302330

303331

304332
def test_forward_ref_resolution_by_name():
@@ -318,7 +346,8 @@ class ForwardRefTestModule2(ModuleBase):
318346
forward_ref = ForwardRefModule(module_name="DatabaseModule")
319347
resolved = analyzer.resolve_forward_ref(forward_ref)
320348

321-
assert resolved == DatabaseModule
349+
# When resolving by name, it returns the module type directly
350+
assert resolved.module == DatabaseModule
322351

323352

324353
def test_forward_ref_resolution_not_found():
@@ -336,6 +365,52 @@ class ForwardRefTestModule3(ModuleBase):
336365
assert resolved is None
337366

338367

368+
def test_resolve_forward_refs_handles_module_setup():
369+
"""Test that _resolve_forward_refs properly handles ModuleSetup instances"""
370+
from ellar.testing.module import Test
371+
372+
@Module(name="TestModuleForSetup", modules=[AuthModule, DatabaseModule])
373+
class TestModuleForSetup(ModuleBase):
374+
pass
375+
376+
analyzer = ApplicationModuleDependencyAnalyzer(TestModuleForSetup)
377+
378+
# Pass ForwardRefModule instances that will be resolved
379+
forward_ref_auth = ForwardRefModule(module=AuthModule)
380+
forward_ref_db = ForwardRefModule(module=DatabaseModule)
381+
382+
modules = [forward_ref_auth, forward_ref_db]
383+
resolved = Test._resolve_forward_refs(modules, analyzer)
384+
385+
# Should resolve both ForwardRefModules (and potentially their dependencies)
386+
assert len(resolved) >= 2
387+
388+
389+
def test_resolve_forward_refs_recursive_extension():
390+
"""Test that _resolve_forward_refs recursively extends with nested modules"""
391+
from ellar.testing.module import Test
392+
393+
# DatabaseModule has LoggingModule as dependency
394+
@Module(
395+
name="TestModuleForRecursive",
396+
modules=[DatabaseModule, AuthModule],
397+
)
398+
class TestModuleForRecursive(ModuleBase):
399+
pass
400+
401+
analyzer = ApplicationModuleDependencyAnalyzer(TestModuleForRecursive)
402+
403+
# Start with ForwardRefModule to DatabaseModule (which has LoggingModule as dependency)
404+
forward_ref_db = ForwardRefModule(module=DatabaseModule)
405+
modules = [forward_ref_db]
406+
resolved = Test._resolve_forward_refs(modules, analyzer)
407+
408+
# Should return resolved DatabaseModule (and potentially nested dependencies)
409+
# The exact count depends on whether DatabaseModule's LoggingModule dependency
410+
# has any ForwardRefModules in its metadata
411+
assert len(resolved) >= 1
412+
413+
339414
# ============================================================================
340415
# Integration Tests: Test.create_test_module()
341416
# ============================================================================
@@ -469,7 +544,7 @@ class TestAppWithForwardRef(ModuleBase):
469544

470545
tm = Test.create_test_module(
471546
controllers=[UserController],
472-
modules=[ModuleWithForwardRef], # Contains ForwardRef to AuthModule
547+
# Don't manually add ForwardRef module - let auto-resolution handle it
473548
application_module=TestAppWithForwardRef,
474549
)
475550

@@ -622,6 +697,49 @@ def test_create_test_module_with_import_string_application_module(reflect_contex
622697
assert isinstance(controller.auth_service, IAuthService)
623698

624699

700+
def test_create_test_module_includes_application_module_providers(reflect_context):
701+
"""Test that test module includes providers from ApplicationModule"""
702+
703+
@injectable
704+
class AppLevelService:
705+
def get_value(self):
706+
return "app_level"
707+
708+
@Module(
709+
name="AppModuleWithProviders",
710+
modules=[AuthModule],
711+
providers=[ProviderConfig(AppLevelService, use_class=AppLevelService)],
712+
)
713+
class AppModuleWithProviders(ModuleBase):
714+
pass
715+
716+
@Controller()
717+
class ControllerUsingAppService:
718+
def __init__(self, app_service: AppLevelService):
719+
self.app_service = app_service
720+
721+
@get("/test")
722+
def test_endpoint(self):
723+
return {"value": self.app_service.get_value()}
724+
725+
tm = Test.create_test_module(
726+
controllers=[ControllerUsingAppService],
727+
application_module=AppModuleWithProviders,
728+
)
729+
730+
tm.create_application()
731+
732+
# Should be able to get the app-level service
733+
app_service = tm.get(AppLevelService)
734+
assert app_service is not None
735+
assert app_service.get_value() == "app_level"
736+
737+
# Controller should also work
738+
controller = tm.get(ControllerUsingAppService)
739+
assert controller is not None
740+
assert isinstance(controller.app_service, AppLevelService)
741+
742+
625743
def test_application_module_analyzer_with_import_string():
626744
"""Test that ApplicationModuleDependencyAnalyzer accepts import strings"""
627745
import_string = "tests.test_testing_dependency_resolution:ApplicationModule"

0 commit comments

Comments
 (0)