diff --git a/CHANGELOG.md b/CHANGELOG.md index f2d4e97..1d00889 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,31 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + +- Add opt-in support for context-managed services via the `manage_context=True` + flag on `add_scoped`, `add_transient`, `add_scoped_by_factory`, + `add_transient_by_factory`, `register`, `register_factory` and `bind_types`. + When set, rodi enters the resolved instance as a context manager and exits it + in LIFO order when the owning `ActivationScope` ends. Default is `False`, so + existing behavior is preserved. +- Add `AsyncActivationScope`, `Services.create_async_scope` and + `Services.aget` / `AsyncActivationScope.aget` to manage async context + managers (`__aenter__` / `__aexit__`) and mixed sync/async dependency graphs. + Sync and async managed instances entered through the same `AsyncActivationScope` + share a single `AsyncExitStack`, so cross-protocol exit order is strictly LIFO + regardless of whether each instance was resolved via `get` or `aget`. +- Add `AsyncContainerProtocol` and `Container.aresolve(obj_type, scope)` as the + async counterpart to `Container.resolve`. `aresolve` requires an explicit + `AsyncActivationScope` and raises `TypeError` otherwise, so frameworks that + integrate rodi's async support get a single, framework-facing entry point + without having to reach into `AsyncActivationScope.aget` directly. The + existing `ContainerProtocol` is unchanged; sync-only containers and + third-party implementations remain compatible. +- Reject `manage_context=True` for `SINGLETON` services with a new + `InvalidContextManagerRegistration` exception. Singleton lifecycle hooks are + intentionally out of scope; track that follow-up separately. + ## [2.1.0] - 2026-03-08 :woman: - Improve `resolve()` typing, by @sobolevn. diff --git a/rodi/__init__.py b/rodi/__init__.py index eda42ac..9750d15 100644 --- a/rodi/__init__.py +++ b/rodi/__init__.py @@ -3,6 +3,7 @@ import re import sys from collections import defaultdict +from contextlib import AsyncExitStack, ExitStack from enum import Enum from inspect import Signature, _empty, isabstract, isclass, iscoroutinefunction from typing import ( @@ -50,6 +51,27 @@ def __contains__(self, item) -> bool: """ +class AsyncContainerProtocol(Protocol): + """ + Optional interface for DI containers that support asynchronous resolution + of context-managed services. Implementing this is opt-in; sync-only + containers continue to satisfy `ContainerProtocol` alone. + """ + + @overload + async def aresolve(self, obj_type: Type[T], *args, **kwargs) -> T: ... + + @overload + async def aresolve(self, obj_type: str, *args, **kwargs) -> Any: ... + + async def aresolve(self, obj_type: Type[T] | str, *args, **kwargs) -> Any: + """ + Activates an instance of the given type asynchronously, entering and + exiting any sync or async context managers in the resolved dependency + graph through the supplied scope. + """ + + AliasesTypeHint = dict[str, Type] @@ -230,6 +252,17 @@ def __init__(self, base_type, decorator_type): ) +class InvalidContextManagerRegistration(DIException): + """Exception raised when manage_context is used with an unsupported lifestyle.""" + + def __init__(self, concrete_type): + super().__init__( + f"manage_context=True is not supported for SINGLETON services " + f"('{class_name(concrete_type)}'). Use SCOPED or TRANSIENT, or manage " + f"the context manually." + ) + + class ServiceLifeStyle(Enum): TRANSIENT = 1 SCOPED = 2 @@ -247,7 +280,7 @@ def _get_factory_annotations_or_throw(factory): class ActivationScope: - __slots__ = ("scoped_services", "provider") + __slots__ = ("scoped_services", "provider", "_exit_stack") def __init__( self, @@ -256,6 +289,7 @@ def __init__( ): self.provider = provider or Services() self.scoped_services = scoped_services or {} + self._exit_stack: ExitStack | None = None def __enter__(self): if self.scoped_services is None: @@ -263,6 +297,9 @@ def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): + if self._exit_stack is not None: + exit_stack, self._exit_stack = self._exit_stack, None + exit_stack.__exit__(exc_type, exc_val, exc_tb) self.dispose() def get( @@ -276,6 +313,25 @@ def get( raise TypeError("This scope is disposed.") return self.provider.get(desired_type, scope or self, default=default) + def _enter_context(self, instance): + if not (hasattr(instance, "__enter__") and hasattr(instance, "__exit__")): + if hasattr(instance, "__aenter__") and hasattr(instance, "__aexit__"): + raise TypeError( + f"{type(instance).__name__} is registered with manage_context=True " + f"but does not implement the synchronous context manager protocol. " + f"Use AsyncActivationScope to manage async context managers." + ) + raise TypeError( + f"{type(instance).__name__} is registered with manage_context=True " + f"but does not implement the context manager protocol. " + ) + self._register_sync_context(instance) + + def _register_sync_context(self, instance): + if self._exit_stack is None: + self._exit_stack = ExitStack() + self._exit_stack.enter_context(instance) + def dispose(self): if self.provider: self.provider = None @@ -320,13 +376,86 @@ def __exit__(self, exc_type, exc_val, exc_tb): # Pop this scope from the stack stack = self._active_scopes.get() self._active_scopes.set(stack[:-1]) - self.dispose() + super().__exit__(exc_type, exc_val, exc_tb) def dispose(self): if self.provider: self.provider = None +class AsyncActivationScope(ActivationScope): + """ + ActivationScope that supports both sync and async context-managed services. + Use `async with provider.create_async_scope() as scope` and + `await scope.aget(T)` to resolve services whose registrations use + `manage_context=True` and that implement `__aenter__`/`__aexit__`. + """ + + __slots__ = ("_async_exit_stack", "_async_mode", "_pending_aenter") + + def __init__(self, provider=None, scoped_services=None): + super().__init__(provider, scoped_services) + self._async_exit_stack: AsyncExitStack | None = None + self._async_mode: bool = False + self._pending_aenter: list = [] + + async def __aenter__(self): + if self.scoped_services is None: + self.scoped_services = {} + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if self._async_exit_stack is not None: + stack, self._async_exit_stack = self._async_exit_stack, None + await stack.__aexit__(exc_type, exc_val, exc_tb) + self.dispose() + + def _enter_context(self, instance): + if self._async_mode: + self._pending_aenter.append(instance) + else: + super()._enter_context(instance) + + def _register_sync_context(self, instance): + if self._async_exit_stack is None: + self._async_exit_stack = AsyncExitStack() + self._async_exit_stack.enter_context(instance) + + async def aget( + self, + desired_type: Type[T] | str, + *, + default: Any = ..., + ) -> T: + if self.provider is None: + raise TypeError("This scope is disposed.") + + self._async_mode = True + try: + instance = self.provider.get(desired_type, self, default=default) + finally: + self._async_mode = False + + if self._pending_aenter: + if self._async_exit_stack is None: + self._async_exit_stack = AsyncExitStack() + + pending = self._pending_aenter + self._pending_aenter = [] + for obj in pending: + if hasattr(obj, "__aenter__") and hasattr(obj, "__aexit__"): + await self._async_exit_stack.enter_async_context(obj) + elif hasattr(obj, "__enter__") and hasattr(obj, "__exit__"): + self._async_exit_stack.enter_context(obj) + else: + raise TypeError( + f"{type(obj).__name__} is registered with " + f"manage_context=True but does not implement any " + f"context manager protocol." + ) + return instance + + class ResolutionContext: __slots__ = ("resolved", "dynamic_chain") __deletable__ = ("resolved",) @@ -469,11 +598,40 @@ def __call__(self, context, parent_type): return self._instance +class ManagedScopedProvider: + __slots__ = ("_type", "_inner") + + def __init__(self, _type, inner): + self._type = _type + self._inner = inner + + def __call__(self, context: ActivationScope, parent_type): + if self._type in context.scoped_services: + return context.scoped_services[self._type] + instance = self._inner(context, parent_type) + context._enter_context(instance) + context.scoped_services[self._type] = instance + return instance + + +class ManagedTransientProvider: + __slots__ = ("_inner",) + + def __init__(self, inner): + self._inner = inner + + def __call__(self, context: ActivationScope, parent_type): + instance = self._inner(context, parent_type) + context._enter_context(instance) + return instance + + def get_annotations_type_provider( concrete_type: Type, resolvers: Mapping[str, Callable], life_style: ServiceLifeStyle, resolver_context: ResolutionContext, + manage_context: bool = False, ): def factory(context, parent_type): instance = concrete_type() @@ -481,7 +639,9 @@ def factory(context, parent_type): setattr(instance, name, resolver(context, parent_type)) return instance - return FactoryResolver(concrete_type, factory, life_style)(resolver_context) + return FactoryResolver(concrete_type, factory, life_style, manage_context)( + resolver_context + ) def get_mixed_type_provider( @@ -490,6 +650,7 @@ def get_mixed_type_provider( annotation_resolvers: Mapping[str, Callable], life_style: ServiceLifeStyle, resolver_context: ResolutionContext, + manage_context: bool = False, ): """ Provider that combines __init__ argument injection with class-level annotation @@ -503,7 +664,9 @@ def factory(context, parent_type): setattr(instance, name, resolver(context, parent_type)) return instance - return FactoryResolver(concrete_type, factory, life_style)(resolver_context) + return FactoryResolver(concrete_type, factory, life_style, manage_context)( + resolver_context + ) def _get_plain_class_factory(concrete_type: Type): @@ -535,15 +698,19 @@ def __init__(self, name, annotation): class DynamicResolver: - __slots__ = ("_concrete_type", "services", "life_style") + __slots__ = ("_concrete_type", "services", "life_style", "manage_context") - def __init__(self, concrete_type, services, life_style): + def __init__(self, concrete_type, services, life_style, manage_context=False): assert isclass(concrete_type) assert not isabstract(concrete_type) + if manage_context and life_style is ServiceLifeStyle.SINGLETON: + raise InvalidContextManagerRegistration(concrete_type) + self._concrete_type = concrete_type self.services = services self.life_style = life_style + self.manage_context = manage_context @property def concrete_type(self) -> Type: @@ -633,8 +800,14 @@ def _resolve_by_init_method(self, context: ResolutionContext): return SingletonTypeProvider(concrete_type, None) if self.life_style == ServiceLifeStyle.SCOPED: + if self.manage_context: + return ManagedScopedProvider( + concrete_type, TypeProvider(concrete_type) + ) return ScopedTypeProvider(concrete_type) + if self.manage_context: + return ManagedTransientProvider(TypeProvider(concrete_type)) return TypeProvider(concrete_type) fns = self._get_resolvers_for_parameters(concrete_type, context, params) @@ -643,8 +816,14 @@ def _resolve_by_init_method(self, context: ResolutionContext): return SingletonTypeProvider(concrete_type, fns) if self.life_style == ServiceLifeStyle.SCOPED: + if self.manage_context: + return ManagedScopedProvider( + concrete_type, ArgsTypeProvider(concrete_type, fns) + ) return ScopedArgsTypeProvider(concrete_type, fns) + if self.manage_context: + return ManagedTransientProvider(ArgsTypeProvider(concrete_type, fns)) return ArgsTypeProvider(concrete_type, fns) def _ignore_class_attribute(self, key: str, value) -> bool: @@ -685,7 +864,7 @@ def _resolve_by_annotations( resolvers[name] = fns[i] return get_annotations_type_provider( - self.concrete_type, resolvers, self.life_style, context + self.concrete_type, resolvers, self.life_style, context, self.manage_context ) def _resolve_by_init_and_annotations( @@ -727,7 +906,12 @@ def _resolve_by_init_and_annotations( } return get_mixed_type_provider( - concrete_type, init_fns, annotation_resolvers, self.life_style, context + concrete_type, + init_fns, + annotation_resolvers, + self.life_style, + context, + self.manage_context, ) def __call__(self, context: ResolutionContext): @@ -752,7 +936,10 @@ def __call__(self, context: ResolutionContext): raise CircularDependencyException(chain[0], concrete_type) return FactoryResolver( - concrete_type, _get_plain_class_factory(concrete_type), self.life_style + concrete_type, + _get_plain_class_factory(concrete_type), + self.life_style, + self.manage_context, )(context) # Custom __init__: also check for class-level annotations to inject as @@ -791,20 +978,32 @@ def __call__(self, context: ResolutionContext): class FactoryResolver: - __slots__ = ("concrete_type", "factory", "params", "life_style") + __slots__ = ("concrete_type", "factory", "params", "life_style", "manage_context") - def __init__(self, concrete_type, factory, life_style): + def __init__(self, concrete_type, factory, life_style, manage_context=False): + if manage_context and life_style is ServiceLifeStyle.SINGLETON: + raise InvalidContextManagerRegistration(concrete_type) self.factory = factory self.concrete_type = concrete_type self.life_style = life_style + self.manage_context = manage_context def __call__(self, context: ResolutionContext): if self.life_style == ServiceLifeStyle.SINGLETON: return SingletonFactoryTypeProvider(self.concrete_type, self.factory) if self.life_style == ServiceLifeStyle.SCOPED: + if self.manage_context: + return ManagedScopedProvider( + self.concrete_type, + FactoryTypeProvider(self.concrete_type, self.factory), + ) return ScopedFactoryTypeProvider(self.concrete_type, self.factory) + if self.manage_context: + return ManagedTransientProvider( + FactoryTypeProvider(self.concrete_type, self.factory) + ) return FactoryTypeProvider(self.concrete_type, self.factory) @@ -988,6 +1187,11 @@ def create_scope( ) -> ActivationScope: return self._scope_cls(self, scoped) + def create_async_scope( + self, scoped: dict[Type | str, Any] | None = None + ) -> "AsyncActivationScope": + return AsyncActivationScope(self, scoped) + def set(self, new_type: Type | str, value: Any): """ Sets a new service of desired type, as singleton. @@ -1037,6 +1241,21 @@ def get( return cast(T, scoped_service or resolver(scope, desired_type)) + async def aget( + self, + desired_type: Type[T] | str, + scope: "AsyncActivationScope | None" = None, + *, + default: Any = ..., + ) -> T: + """ + Awaitable counterpart of `get` that supports async context-managed + services. Requires (and creates if missing) an `AsyncActivationScope`. + """ + if scope is None: + scope = self.create_async_scope() + return await scope.aget(desired_type, default=default) + def _get_getter(self, key, param): if param.annotation is _empty: @@ -1130,7 +1349,7 @@ def __call__(self, context, activating_type): _ContainerSelf = TypeVar("_ContainerSelf", bound="Container") -class Container(ContainerProtocol): +class Container(ContainerProtocol, AsyncContainerProtocol): """ Configuration class for a collection of services. """ @@ -1167,6 +1386,7 @@ def bind_types( obj_type: Any, concrete_type: Any = None, life_style: ServiceLifeStyle = ServiceLifeStyle.TRANSIENT, + manage_context: bool = False, ) -> _ContainerSelf: try: assert issubclass(concrete_type, obj_type), ( @@ -1176,7 +1396,10 @@ def bind_types( except TypeError: # ignore, this happens with generic types pass - self._bind(obj_type, DynamicResolver(concrete_type, self, life_style)) + self._bind( + obj_type, + DynamicResolver(concrete_type, self, life_style, manage_context), + ) return self def register( @@ -1190,14 +1413,16 @@ def register( """ Registers a type in this container. """ + manage_context = kwargs.pop("manage_context", False) + if instance is not None: self.add_instance(instance, declared_class=obj_type) return self if sub_type is None: - self._add_exact_transient(obj_type) + self._add_exact_transient(obj_type, manage_context=manage_context) else: - self.add_transient(obj_type, sub_type) + self.add_transient(obj_type, sub_type, manage_context=manage_context) return self @overload @@ -1230,6 +1455,45 @@ def resolve( """ return self.provider.get(obj_type, scope=scope) + @overload + async def aresolve( + self, + obj_type: Type[T], + scope: AsyncActivationScope, + *args, + **kwargs, + ) -> T: ... + + @overload + async def aresolve( + self, + obj_type: str, + scope: AsyncActivationScope, + *args, + **kwargs, + ) -> Any: ... + + async def aresolve( + self, + obj_type: Type[T] | str, + scope: AsyncActivationScope, + *args, + **kwargs, + ) -> Any: + """ + Asynchronous counterpart to `resolve`. Requires an explicit + `AsyncActivationScope` so that any sync or async context-managed + services in the resolved graph are entered into the scope's exit + stack and exited in LIFO order on `__aexit__`. + """ + if not isinstance(scope, AsyncActivationScope): + raise TypeError( + "aresolve requires an AsyncActivationScope. Use `resolve` for " + "synchronous resolution, or pass an instance returned by " + "`provider.create_async_scope()`." + ) + return await scope.aget(obj_type) + def add_alias( self: _ContainerSelf, name: str, @@ -1331,7 +1595,11 @@ def add_instance( return self def add_singleton( - self: _ContainerSelf, base_type: Type, concrete_type: Type | None = None + self: _ContainerSelf, + base_type: Type, + concrete_type: Type | None = None, + *, + manage_context: bool = False, ) -> _ContainerSelf: """ Registers a type by base type, to be instantiated with singleton lifetime. @@ -1340,8 +1608,12 @@ def add_singleton( :param base_type: registered type. If a concrete type is provided, it must inherit the base type. :param concrete_type: concrete class + :param manage_context: not supported for singletons. :return: the service collection itself """ + if manage_context: + raise InvalidContextManagerRegistration(concrete_type or base_type) + if concrete_type is None: return self._add_exact_singleton(base_type) @@ -1351,6 +1623,8 @@ def add_scoped( self: _ContainerSelf, base_type: Type, concrete_type: Type | None = None, + *, + manage_context: bool = False, ) -> _ContainerSelf: """ Registers a type by base type, to be instantiated with scoped lifetime. @@ -1359,17 +1633,23 @@ def add_scoped( :param base_type: registered type. If a concrete type is provided, it must inherit the base type. :param concrete_type: concrete class + :param manage_context: when True, rodi enters/exits the resolved instance + as a context manager bound to the activation scope. :return: the service collection itself """ if concrete_type is None: - return self._add_exact_scoped(base_type) + return self._add_exact_scoped(base_type, manage_context=manage_context) - return self.bind_types(base_type, concrete_type, ServiceLifeStyle.SCOPED) + return self.bind_types( + base_type, concrete_type, ServiceLifeStyle.SCOPED, manage_context + ) def add_transient( self: _ContainerSelf, base_type: Type, concrete_type: Type | None = None, + *, + manage_context: bool = False, ) -> _ContainerSelf: """ Registers a type by base type, to be instantiated with transient lifetime. @@ -1378,12 +1658,16 @@ def add_transient( :param base_type: registered type. If a concrete type is provided, it must inherit the base type. :param concrete_type: concrete class + :param manage_context: when True, rodi enters/exits each resolved instance + as a context manager bound to the activation scope. :return: the service collection itself """ if concrete_type is None: - return self._add_exact_transient(base_type) + return self._add_exact_transient(base_type, manage_context=manage_context) - return self.bind_types(base_type, concrete_type, ServiceLifeStyle.TRANSIENT) + return self.bind_types( + base_type, concrete_type, ServiceLifeStyle.TRANSIENT, manage_context + ) def decorate( self: _ContainerSelf, @@ -1433,7 +1717,12 @@ def _add_exact_singleton( ) return self - def _add_exact_scoped(self: _ContainerSelf, concrete_type: Type) -> _ContainerSelf: + def _add_exact_scoped( + self: _ContainerSelf, + concrete_type: Type, + *, + manage_context: bool = False, + ) -> _ContainerSelf: """ Registers an exact type, to be instantiated with scoped lifetime. @@ -1442,12 +1731,18 @@ def _add_exact_scoped(self: _ContainerSelf, concrete_type: Type) -> _ContainerSe """ assert not isabstract(concrete_type) self._bind( - concrete_type, DynamicResolver(concrete_type, self, ServiceLifeStyle.SCOPED) + concrete_type, + DynamicResolver( + concrete_type, self, ServiceLifeStyle.SCOPED, manage_context + ), ) return self def _add_exact_transient( - self: _ContainerSelf, concrete_type: Type + self: _ContainerSelf, + concrete_type: Type, + *, + manage_context: bool = False, ) -> _ContainerSelf: """ Registers an exact type, to be instantiated with transient lifetime. @@ -1458,7 +1753,9 @@ def _add_exact_transient( assert not isabstract(concrete_type) self._bind( concrete_type, - DynamicResolver(concrete_type, self, ServiceLifeStyle.TRANSIENT), + DynamicResolver( + concrete_type, self, ServiceLifeStyle.TRANSIENT, manage_context + ), ) return self @@ -1474,16 +1771,24 @@ def add_transient_by_factory( self: _ContainerSelf, factory: FactoryCallableType, return_type: Type | None = None, + *, + manage_context: bool = False, ) -> _ContainerSelf: - self.register_factory(factory, return_type, ServiceLifeStyle.TRANSIENT) + self.register_factory( + factory, return_type, ServiceLifeStyle.TRANSIENT, manage_context + ) return self def add_scoped_by_factory( self: _ContainerSelf, factory: FactoryCallableType, return_type: Type | None = None, + *, + manage_context: bool = False, ) -> _ContainerSelf: - self.register_factory(factory, return_type, ServiceLifeStyle.SCOPED) + self.register_factory( + factory, return_type, ServiceLifeStyle.SCOPED, manage_context + ) return self @staticmethod @@ -1508,6 +1813,7 @@ def register_factory( factory: Callable, return_type: Type | None, life_style: ServiceLifeStyle, + manage_context: bool = False, ) -> None: if not callable(factory): raise InvalidFactory(return_type) @@ -1526,7 +1832,10 @@ def register_factory( self._bind( return_type, # type: ignore FactoryResolver( - return_type, self._check_factory(factory, sign, return_type), life_style + return_type, + self._check_factory(factory, sign, return_type), + life_style, + manage_context, ), ) diff --git a/tests/test_context_managers.py b/tests/test_context_managers.py new file mode 100644 index 0000000..45fae20 --- /dev/null +++ b/tests/test_context_managers.py @@ -0,0 +1,542 @@ +import pytest +from pytest import raises + +from rodi import ( + Container, + InvalidContextManagerRegistration, + ServiceLifeStyle, +) + + +class SyncCM: + def __init__(self) -> None: + self.entered = False + self.exited = False + self.exit_exc: BaseException | None = None + + def __enter__(self) -> "SyncCM": + self.entered = True + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.exited = True + self.exit_exc = exc_val + + +class SyncCMDep: + def __init__(self) -> None: + self.entered = False + self.exited = False + + def __enter__(self) -> "SyncCMDep": + self.entered = True + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.exited = True + + +class FailingEnter: + def __init__(self, dep: SyncCMDep) -> None: + self.dep = dep + + def __enter__(self): + raise RuntimeError("boom") + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + pass + + +class NotACM: + pass + + +class AsyncCM: + def __init__(self) -> None: + self.aentered = False + self.aexited = False + + async def __aenter__(self) -> "AsyncCM": + self.aentered = True + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + self.aexited = True + + +def test_default_is_unmanaged(): + container = Container() + container.add_scoped(SyncCM) + provider = container.build_provider() + + with provider.create_scope() as scope: + instance = scope.get(SyncCM) + assert instance.entered is False + + assert instance.exited is False + + +def test_scoped_managed_enter_and_exit(): + container = Container() + container.add_scoped(SyncCM, manage_context=True) + provider = container.build_provider() + + with provider.create_scope() as scope: + a = scope.get(SyncCM) + b = scope.get(SyncCM) + assert a is b + assert a.entered is True + assert a.exited is False + + assert a.exited is True + + +def test_transient_managed_enter_per_instance(): + container = Container() + container.add_transient(SyncCM, manage_context=True) + provider = container.build_provider() + + instances: list[SyncCM] = [] + with provider.create_scope() as scope: + for _ in range(3): + instances.append(scope.get(SyncCM)) + + assert len(set(id(i) for i in instances)) == 3 + assert all(i.entered for i in instances) + assert not any(i.exited for i in instances) + + assert all(i.exited for i in instances) + + +_lifo_order: list[str] = [] + + +class _Inner: + def __enter__(self): + _lifo_order.append("enter inner") + return self + + def __exit__(self, *a): + _lifo_order.append("exit inner") + + +class _Outer: + def __init__(self, inner: _Inner) -> None: + self.inner = inner + + def __enter__(self): + _lifo_order.append("enter outer") + return self + + def __exit__(self, *a): + _lifo_order.append("exit outer") + + +def test_lifo_exit_order(): + _lifo_order.clear() + + container = Container() + container.add_scoped(_Inner, manage_context=True) + container.add_scoped(_Outer, manage_context=True) + provider = container.build_provider() + + with provider.create_scope() as scope: + scope.get(_Outer) + + assert _lifo_order == ["enter inner", "enter outer", "exit outer", "exit inner"] + + +def test_exit_receives_exception_when_block_raises(): + container = Container() + container.add_scoped(SyncCM, manage_context=True) + provider = container.build_provider() + + instance: SyncCM | None = None + with raises(RuntimeError, match="oops"): + with provider.create_scope() as scope: + instance = scope.get(SyncCM) + raise RuntimeError("oops") + + assert instance is not None + assert instance.exited is True + assert isinstance(instance.exit_exc, RuntimeError) + + +def test_failure_in_enter_unwinds_already_entered(): + container = Container() + container.add_scoped(SyncCMDep, manage_context=True) + container.add_scoped(FailingEnter, manage_context=True) + provider = container.build_provider() + + dep: SyncCMDep | None = None + with raises(RuntimeError, match="boom"): + with provider.create_scope() as scope: + dep = scope.get(SyncCMDep) + scope.get(FailingEnter) + + assert dep is not None + assert dep.entered is True + assert dep.exited is True + + +def test_resolve_non_cm_with_manage_context_raises(): + container = Container() + container.add_scoped(NotACM, manage_context=True) + provider = container.build_provider() + + with provider.create_scope() as scope: + with raises(TypeError, match="context manager protocol"): + scope.get(NotACM) + + +def test_singleton_with_manage_context_rejected_in_add_singleton(): + container = Container() + with raises(InvalidContextManagerRegistration): + container.add_singleton(SyncCM, manage_context=True) + + +def test_singleton_with_manage_context_rejected_via_factory(): + def factory() -> SyncCM: + return SyncCM() + + container = Container() + with raises(InvalidContextManagerRegistration): + container.register_factory( + factory, SyncCM, ServiceLifeStyle.SINGLETON, manage_context=True + ) + + +def test_factory_registration_managed_scoped(): + container = Container() + container.add_scoped_by_factory( + lambda: SyncCMDep(), SyncCMDep, manage_context=True + ) + provider = container.build_provider() + + with provider.create_scope() as scope: + instance = scope.get(SyncCMDep) + assert instance.entered is True + + assert instance.exited is True + + +def test_register_kwarg_manage_context(): + container = Container() + container.register(SyncCM, manage_context=True) + provider = container.build_provider() + + with provider.create_scope() as scope: + a = scope.get(SyncCM) + b = scope.get(SyncCM) + assert a is not b + assert a.entered and b.entered + + assert a.exited and b.exited + + +@pytest.mark.asyncio +async def test_async_scope_manages_async_cm(): + container = Container() + container.add_scoped(AsyncCM, manage_context=True) + provider = container.build_provider() + + async with provider.create_async_scope() as scope: + instance = await scope.aget(AsyncCM) + assert instance.aentered is True + assert instance.aexited is False + + assert instance.aexited is True + + +@pytest.mark.asyncio +async def test_async_scope_managed_transient(): + container = Container() + container.add_transient(AsyncCM, manage_context=True) + provider = container.build_provider() + + instances: list[AsyncCM] = [] + async with provider.create_async_scope() as scope: + for _ in range(3): + instances.append(await scope.aget(AsyncCM)) + assert all(i.aentered for i in instances) + assert not any(i.aexited for i in instances) + + assert all(i.aexited for i in instances) + + +_async_lifo_order: list[str] = [] + + +class _AsyncInner: + async def __aenter__(self): + _async_lifo_order.append("enter inner") + return self + + async def __aexit__(self, *a): + _async_lifo_order.append("exit inner") + + +class _AsyncOuter: + def __init__(self, inner: _AsyncInner) -> None: + self.inner = inner + + async def __aenter__(self): + _async_lifo_order.append("enter outer") + return self + + async def __aexit__(self, *a): + _async_lifo_order.append("exit outer") + + +@pytest.mark.asyncio +async def test_async_scope_lifo_exit_order(): + _async_lifo_order.clear() + + container = Container() + container.add_scoped(_AsyncInner, manage_context=True) + container.add_scoped(_AsyncOuter, manage_context=True) + provider = container.build_provider() + + async with provider.create_async_scope() as scope: + await scope.aget(_AsyncOuter) + + assert _async_lifo_order == [ + "enter inner", + "enter outer", + "exit outer", + "exit inner", + ] + + +_mixed_order: list[str] = [] + + +class _MixedSync: + def __enter__(self): + _mixed_order.append("enter sync") + return self + + def __exit__(self, *a): + _mixed_order.append("exit sync") + + +class _MixedAsync: + def __init__(self, dep: _MixedSync) -> None: + self.dep = dep + + async def __aenter__(self): + _mixed_order.append("enter async") + return self + + async def __aexit__(self, *a): + _mixed_order.append("exit async") + + +@pytest.mark.asyncio +async def test_async_scope_mixed_sync_and_async_cms(): + _mixed_order.clear() + + container = Container() + container.add_scoped(_MixedSync, manage_context=True) + container.add_scoped(_MixedAsync, manage_context=True) + provider = container.build_provider() + + async with provider.create_async_scope() as scope: + await scope.aget(_MixedAsync) + + assert _mixed_order == ["enter sync", "enter async", "exit async", "exit sync"] + + +@pytest.mark.asyncio +async def test_async_scope_exit_receives_exception(): + container = Container() + container.add_scoped(AsyncCM, manage_context=True) + provider = container.build_provider() + + instance: AsyncCM | None = None + with raises(RuntimeError, match="aboom"): + async with provider.create_async_scope() as scope: + instance = await scope.aget(AsyncCM) + raise RuntimeError("aboom") + + assert instance is not None + assert instance.aexited is True + + +def test_sync_scope_rejects_async_only_cm(): + container = Container() + container.add_scoped(AsyncCM, manage_context=True) + provider = container.build_provider() + + with provider.create_scope() as scope: + with raises(TypeError, match="synchronous context manager protocol"): + scope.get(AsyncCM) + + +@pytest.mark.asyncio +async def test_existing_unmanaged_scope_unchanged_async(): + container = Container() + container.add_scoped(AsyncCM) + provider = container.build_provider() + + async with provider.create_async_scope() as scope: + instance = await scope.aget(AsyncCM) + assert instance.aentered is False + + assert instance.aexited is False + + +_cross_protocol_order: list[str] = [] + + +class _CrossSyncCM: + def __enter__(self): + _cross_protocol_order.append("enter sync") + return self + + def __exit__(self, *a): + _cross_protocol_order.append("exit sync") + + +class _CrossAsyncCM: + async def __aenter__(self): + _cross_protocol_order.append("enter async") + return self + + async def __aexit__(self, *a): + _cross_protocol_order.append("exit async") + + +@pytest.mark.asyncio +async def test_async_scope_cross_protocol_lifo(): + _cross_protocol_order.clear() + + container = Container() + container.add_scoped(_CrossAsyncCM, manage_context=True) + container.add_scoped(_CrossSyncCM, manage_context=True) + provider = container.build_provider() + + async with provider.create_async_scope() as scope: + await scope.aget(_CrossAsyncCM) + scope.get(_CrossSyncCM) + + assert _cross_protocol_order == [ + "enter async", + "enter sync", + "exit sync", + "exit async", + ] + + +@pytest.mark.asyncio +async def test_async_scope_cross_protocol_lifo_reverse(): + _cross_protocol_order.clear() + + container = Container() + container.add_scoped(_CrossAsyncCM, manage_context=True) + container.add_scoped(_CrossSyncCM, manage_context=True) + provider = container.build_provider() + + async with provider.create_async_scope() as scope: + scope.get(_CrossSyncCM) + await scope.aget(_CrossAsyncCM) + + assert _cross_protocol_order == [ + "enter sync", + "enter async", + "exit async", + "exit sync", + ] + + +class _DecoratedInner: + def __init__(self) -> None: + self.entered = False + self.exited = False + + def __enter__(self): + self.entered = True + return self + + def __exit__(self, *a): + self.exited = True + + +class _SyncDecorator: + def __init__(self, inner: _DecoratedInner) -> None: + self.inner = inner + + +def test_decorate_inner_managed_scoped(): + container = Container() + container.add_scoped(_DecoratedInner, manage_context=True) + container.decorate(_DecoratedInner, _SyncDecorator) + provider = container.build_provider() + + with provider.create_scope() as scope: + result = scope.get(_DecoratedInner) + assert isinstance(result, _SyncDecorator) + assert isinstance(result.inner, _DecoratedInner) + assert result.inner.entered is True + + assert result.inner.exited is True + + +@pytest.mark.asyncio +async def test_aresolve_async_cm_with_async_scope(): + container = Container() + container.add_scoped(AsyncCM, manage_context=True) + provider = container.build_provider() + + async with provider.create_async_scope() as scope: + instance = await container.aresolve(AsyncCM, scope) + assert instance.aentered is True + assert instance.aexited is False + + assert instance.aexited is True + + +@pytest.mark.asyncio +async def test_aresolve_managed_sync_cm_with_async_scope(): + container = Container() + container.add_scoped(SyncCM, manage_context=True) + provider = container.build_provider() + + async with provider.create_async_scope() as scope: + instance = await container.aresolve(SyncCM, scope) + assert instance.entered is True + assert instance.exited is False + + assert instance.exited is True + + +@pytest.mark.asyncio +async def test_aresolve_unmanaged_service(): + container = Container() + container.add_scoped(NotACM) + provider = container.build_provider() + + async with provider.create_async_scope() as scope: + instance = await container.aresolve(NotACM, scope) + assert isinstance(instance, NotACM) + + +@pytest.mark.asyncio +async def test_aresolve_with_sync_scope_raises(): + container = Container() + container.add_scoped(AsyncCM, manage_context=True) + provider = container.build_provider() + + with provider.create_scope() as sync_scope: + with raises(TypeError, match="aresolve requires an AsyncActivationScope"): + await container.aresolve(AsyncCM, sync_scope) # type: ignore[arg-type] + + +@pytest.mark.asyncio +async def test_aresolve_with_none_scope_raises(): + container = Container() + container.add_scoped(AsyncCM, manage_context=True) + container.build_provider() + + with raises(TypeError, match="aresolve requires an AsyncActivationScope"): + await container.aresolve(AsyncCM, None) # type: ignore[arg-type]