diff --git a/openfeature/_api.py b/openfeature/_api.py new file mode 100644 index 00000000..c9736350 --- /dev/null +++ b/openfeature/_api.py @@ -0,0 +1,135 @@ +from __future__ import annotations + +from openfeature._event_support import EventSupport +from openfeature.client import OpenFeatureClient +from openfeature.evaluation_context import EvaluationContext +from openfeature.event import EventHandler, ProviderEvent +from openfeature.exception import GeneralError +from openfeature.hook import Hook +from openfeature.provider import FeatureProvider, ProviderStatus +from openfeature.provider._registry import ProviderRegistry +from openfeature.provider.metadata import Metadata +from openfeature.transaction_context import TransactionContextPropagator +from openfeature.transaction_context.no_op_transaction_context_propagator import ( + NoOpTransactionContextPropagator, +) + + +class OpenFeatureAPI: + """An independent OpenFeature API instance with its own isolated state. + + Each instance maintains its own providers, evaluation context, hooks, + event handlers, and transaction context propagator; fully separate from + the global singleton and from other instances. + """ + + def __init__(self) -> None: + self._hooks: list[Hook] = [] + self._evaluation_context = EvaluationContext() + self._transaction_context_propagator: TransactionContextPropagator = ( + NoOpTransactionContextPropagator() + ) + self._event_support = EventSupport() + self._provider_registry = ProviderRegistry( + event_support=self._event_support, + evaluation_context_getter=self.get_evaluation_context, + ) + + # --- Client creation --- + + def get_client( + self, domain: str | None = None, version: str | None = None + ) -> OpenFeatureClient: + return OpenFeatureClient(domain=domain, version=version, api=self) + + # --- Provider management --- + + def set_provider( + self, provider: FeatureProvider, domain: str | None = None + ) -> None: + if domain is None: + self._provider_registry.set_default_provider(provider) + else: + self._provider_registry.set_provider(domain, provider) + + def set_provider_and_wait( + self, provider: FeatureProvider, domain: str | None = None + ) -> None: + if domain is None: + self._provider_registry.set_default_provider(provider, wait_for_init=True) + else: + self._provider_registry.set_provider(domain, provider, wait_for_init=True) + + def get_provider_metadata(self, domain: str | None = None) -> Metadata: + return self._provider_registry.get_provider(domain).get_metadata() + + def get_provider(self, domain: str | None = None) -> FeatureProvider: + return self._provider_registry.get_provider(domain) + + def get_provider_status(self, provider: FeatureProvider) -> ProviderStatus: + return self._provider_registry.get_provider_status(provider) + + def clear_providers(self) -> None: + self._provider_registry.clear_providers() + self._event_support.clear() + + def shutdown(self) -> None: + # shutdown -> remove providers -> set default provider to NoOp -> remove event handlers + self.clear_providers() + # remove hooks + self.clear_hooks() + # set evaluation context to default + self._evaluation_context = EvaluationContext() + # set propagator to NoOp + self._transaction_context_propagator = NoOpTransactionContextPropagator() + + # --- Hooks --- + + def add_hooks(self, hooks: list[Hook]) -> None: + self._hooks = self._hooks + hooks + + def clear_hooks(self) -> None: + self._hooks = [] + + def get_hooks(self) -> list[Hook]: + return self._hooks + + # --- Evaluation context --- + + def get_evaluation_context(self) -> EvaluationContext: + return self._evaluation_context + + def set_evaluation_context(self, evaluation_context: EvaluationContext) -> None: + if evaluation_context is None: + raise GeneralError(error_message="No api level evaluation context") + self._evaluation_context = evaluation_context + + def clear_evaluation_context(self) -> None: + self.set_evaluation_context(EvaluationContext()) + + # --- Transaction context --- + + def set_transaction_context_propagator( + self, transaction_context_propagator: TransactionContextPropagator + ) -> None: + self._transaction_context_propagator = transaction_context_propagator + + def clear_transaction_context_propagator(self) -> None: + self.set_transaction_context_propagator(NoOpTransactionContextPropagator()) + + def get_transaction_context(self) -> EvaluationContext: + return self._transaction_context_propagator.get_transaction_context() + + def set_transaction_context(self, evaluation_context: EvaluationContext) -> None: + self._transaction_context_propagator.set_transaction_context(evaluation_context) + + # --- Event handlers --- + + def add_handler(self, event: ProviderEvent, handler: EventHandler) -> None: + self._event_support.add_global_handler(event, handler, self.get_client) + + def remove_handler(self, event: ProviderEvent, handler: EventHandler) -> None: + self._event_support.remove_global_handler(event, handler) + + +_default_api = OpenFeatureAPI() diff --git a/openfeature/_event_support.py b/openfeature/_event_support.py index 3928be3e..6aa772b9 100644 --- a/openfeature/_event_support.py +++ b/openfeature/_event_support.py @@ -4,6 +4,7 @@ import threading import typing from collections import defaultdict +from collections.abc import Callable from concurrent.futures import ThreadPoolExecutor from logging import getLogger @@ -23,103 +24,6 @@ _event_executor = ThreadPoolExecutor(thread_name_prefix="openfeature-event-handler") atexit.register(_event_executor.shutdown, wait=True) -_global_lock = threading.RLock() -_global_handlers: dict[ProviderEvent, list[EventHandler]] = defaultdict(list) - -_client_lock = threading.RLock() -_client_handlers: dict[OpenFeatureClient, dict[ProviderEvent, list[EventHandler]]] = ( - defaultdict(lambda: defaultdict(list)) -) - - -def run_client_handlers( - client: OpenFeatureClient, event: ProviderEvent, details: EventDetails -) -> None: - with _client_lock: - handlers_by_event = _client_handlers.get(client) - if handlers_by_event is None: - return - - handlers = tuple(handlers_by_event.get(event, ())) - - for handler in handlers: - _submit_handler(handler, details) - - -def run_global_handlers(event: ProviderEvent, details: EventDetails) -> None: - with _global_lock: - handlers = tuple(_global_handlers.get(event, ())) - - for handler in handlers: - _submit_handler(handler, details) - - -def add_client_handler( - client: OpenFeatureClient, event: ProviderEvent, handler: EventHandler -) -> None: - with _client_lock: - handlers = _client_handlers[client][event] - handlers.append(handler) - - _run_immediate_handler(client, event, handler) - - -def remove_client_handler( - client: OpenFeatureClient, event: ProviderEvent, handler: EventHandler -) -> None: - with _client_lock: - handlers = _client_handlers[client][event] - handlers.remove(handler) - - -def add_global_handler(event: ProviderEvent, handler: EventHandler) -> None: - with _global_lock: - _global_handlers[event].append(handler) - - from openfeature.api import get_client # noqa: PLC0415 - - _run_immediate_handler(get_client(), event, handler) - - -def remove_global_handler(event: ProviderEvent, handler: EventHandler) -> None: - with _global_lock: - _global_handlers[event].remove(handler) - - -def run_handlers_for_provider( - provider: FeatureProvider, - event: ProviderEvent, - provider_details: ProviderEventDetails, -) -> None: - details = EventDetails.from_provider_event_details( - provider.get_metadata().name, provider_details - ) - # run the global handlers - run_global_handlers(event, details) - # run the handlers for clients associated to this provider - with _client_lock: - clients = tuple( - client for client in _client_handlers if client.provider == provider - ) - - for client in clients: - run_client_handlers(client, event, details) - - -def _run_immediate_handler( - client: OpenFeatureClient, event: ProviderEvent, handler: EventHandler -) -> None: - status_to_event = { - ProviderStatus.READY: ProviderEvent.PROVIDER_READY, - ProviderStatus.ERROR: ProviderEvent.PROVIDER_ERROR, - ProviderStatus.FATAL: ProviderEvent.PROVIDER_ERROR, - ProviderStatus.STALE: ProviderEvent.PROVIDER_STALE, - } - if event == status_to_event.get(client.get_provider_status()): - _submit_handler( - handler, EventDetails(provider_name=client.provider.get_metadata().name) - ) - def _submit_handler(handler: EventHandler, details: EventDetails) -> None: _event_executor.submit(_run_handler, handler, details) @@ -132,8 +36,112 @@ def _run_handler(handler: EventHandler, details: EventDetails) -> None: logger.exception("Unhandled exception in OpenFeature event handler") -def clear() -> None: - with _global_lock: - _global_handlers.clear() - with _client_lock: - _client_handlers.clear() +class EventSupport: + """Per-API-instance event handler storage and dispatch.""" + + def __init__(self) -> None: + self._global_lock = threading.RLock() + self._global_handlers: dict[ProviderEvent, list[EventHandler]] = defaultdict( + list + ) + + self._client_lock = threading.RLock() + self._client_handlers: dict[ + OpenFeatureClient, dict[ProviderEvent, list[EventHandler]] + ] = defaultdict(lambda: defaultdict(list)) + + def run_client_handlers( + self, client: OpenFeatureClient, event: ProviderEvent, details: EventDetails + ) -> None: + with self._client_lock: + handlers_by_event = self._client_handlers.get(client) + if handlers_by_event is None: + return + + handlers = tuple(handlers_by_event.get(event, ())) + + for handler in handlers: + _submit_handler(handler, details) + + def run_global_handlers(self, event: ProviderEvent, details: EventDetails) -> None: + with self._global_lock: + handlers = tuple(self._global_handlers.get(event, ())) + + for handler in handlers: + _submit_handler(handler, details) + + def add_client_handler( + self, client: OpenFeatureClient, event: ProviderEvent, handler: EventHandler + ) -> None: + with self._client_lock: + handlers = self._client_handlers[client][event] + handlers.append(handler) + + self._run_immediate_handler(client, event, handler) + + def remove_client_handler( + self, client: OpenFeatureClient, event: ProviderEvent, handler: EventHandler + ) -> None: + with self._client_lock: + handlers = self._client_handlers[client][event] + handlers.remove(handler) + + def add_global_handler( + self, + event: ProviderEvent, + handler: EventHandler, + get_client: Callable[[], OpenFeatureClient], + ) -> None: + with self._global_lock: + self._global_handlers[event].append(handler) + + self._run_immediate_handler(get_client(), event, handler) + + def remove_global_handler( + self, event: ProviderEvent, handler: EventHandler + ) -> None: + with self._global_lock: + self._global_handlers[event].remove(handler) + + def run_handlers_for_provider( + self, + provider: FeatureProvider, + event: ProviderEvent, + provider_details: ProviderEventDetails, + ) -> None: + details = EventDetails.from_provider_event_details( + provider.get_metadata().name, provider_details + ) + # run the global handlers + self.run_global_handlers(event, details) + # run the handlers for clients associated to this provider + with self._client_lock: + clients = tuple( + client + for client in self._client_handlers + if client.provider == provider + ) + + for client in clients: + self.run_client_handlers(client, event, details) + + def _run_immediate_handler( + self, client: OpenFeatureClient, event: ProviderEvent, handler: EventHandler + ) -> None: + status_to_event = { + ProviderStatus.READY: ProviderEvent.PROVIDER_READY, + ProviderStatus.ERROR: ProviderEvent.PROVIDER_ERROR, + ProviderStatus.FATAL: ProviderEvent.PROVIDER_ERROR, + ProviderStatus.STALE: ProviderEvent.PROVIDER_STALE, + } + if event == status_to_event.get(client.get_provider_status()): + _submit_handler( + handler, + EventDetails(provider_name=client.provider.get_metadata().name), + ) + + def clear(self) -> None: + with self._global_lock: + self._global_handlers.clear() + with self._client_lock: + self._client_handlers.clear() diff --git a/openfeature/api.py b/openfeature/api.py index 4585e50e..c4ff833f 100644 --- a/openfeature/api.py +++ b/openfeature/api.py @@ -1,30 +1,22 @@ -from openfeature import _event_support +from openfeature._api import _default_api from openfeature.client import OpenFeatureClient -from openfeature.evaluation_context import ( - clear_evaluation_context, - get_evaluation_context, - set_evaluation_context, -) +from openfeature.evaluation_context import EvaluationContext from openfeature.event import ( EventHandler, ProviderEvent, ) -from openfeature.hook import add_hooks, clear_hooks, get_hooks +from openfeature.hook import Hook from openfeature.provider import FeatureProvider -from openfeature.provider._registry import provider_registry from openfeature.provider.metadata import Metadata -from openfeature.transaction_context import ( - clear_transaction_context_propagator, - get_transaction_context, - set_transaction_context, - set_transaction_context_propagator, -) +from openfeature.transaction_context import TransactionContextPropagator __all__ = [ "add_handler", "add_hooks", + "clear_evaluation_context", "clear_hooks", "clear_providers", + "clear_transaction_context_propagator", "get_client", "get_evaluation_context", "get_hooks", @@ -43,46 +35,74 @@ def get_client( domain: str | None = None, version: str | None = None ) -> OpenFeatureClient: - return OpenFeatureClient(domain=domain, version=version) + return _default_api.get_client(domain=domain, version=version) def set_provider(provider: FeatureProvider, domain: str | None = None) -> None: - if domain is None: - provider_registry.set_default_provider(provider) - else: - provider_registry.set_provider(domain, provider) + _default_api.set_provider(provider, domain) def set_provider_and_wait(provider: FeatureProvider, domain: str | None = None) -> None: - if domain is None: - provider_registry.set_default_provider(provider, wait_for_init=True) - else: - provider_registry.set_provider(domain, provider, wait_for_init=True) + _default_api.set_provider_and_wait(provider, domain) def clear_providers() -> None: - provider_registry.clear_providers() - _event_support.clear() + _default_api.clear_providers() def get_provider_metadata(domain: str | None = None) -> Metadata: - return provider_registry.get_provider(domain).get_metadata() + return _default_api.get_provider_metadata(domain) def shutdown() -> None: - # shutdown -> remove providers -> set default provider to NoOp -> remove event handlers - clear_providers() - # remove hooks - clear_hooks() - # set evaluation context to default - clear_evaluation_context() - # set propagator to NoOp - clear_transaction_context_propagator() + _default_api.shutdown() def add_handler(event: ProviderEvent, handler: EventHandler) -> None: - _event_support.add_global_handler(event, handler) + _default_api.add_handler(event, handler) def remove_handler(event: ProviderEvent, handler: EventHandler) -> None: - _event_support.remove_global_handler(event, handler) + _default_api.remove_handler(event, handler) + + +def add_hooks(hooks: list[Hook]) -> None: + _default_api.add_hooks(hooks) + + +def clear_hooks() -> None: + _default_api.clear_hooks() + + +def get_hooks() -> list[Hook]: + return _default_api.get_hooks() + + +def get_evaluation_context() -> EvaluationContext: + return _default_api.get_evaluation_context() + + +def set_evaluation_context(evaluation_context: EvaluationContext) -> None: + _default_api.set_evaluation_context(evaluation_context) + + +def clear_evaluation_context() -> None: + _default_api.clear_evaluation_context() + + +def set_transaction_context_propagator( + transaction_context_propagator: TransactionContextPropagator, +) -> None: + _default_api.set_transaction_context_propagator(transaction_context_propagator) + + +def clear_transaction_context_propagator() -> None: + _default_api.clear_transaction_context_propagator() + + +def get_transaction_context() -> EvaluationContext: + return _default_api.get_transaction_context() + + +def set_transaction_context(evaluation_context: EvaluationContext) -> None: + _default_api.set_transaction_context(evaluation_context) diff --git a/openfeature/client.py b/openfeature/client.py index 95dc5b6d..f9440580 100644 --- a/openfeature/client.py +++ b/openfeature/client.py @@ -4,8 +4,7 @@ from dataclasses import dataclass from itertools import chain -from openfeature import _event_support -from openfeature.evaluation_context import EvaluationContext, get_evaluation_context +from openfeature.evaluation_context import EvaluationContext from openfeature.event import EventHandler, ProviderEvent from openfeature.exception import ( ErrorCode, @@ -23,7 +22,7 @@ FlagValueType, Reason, ) -from openfeature.hook import Hook, HookContext, HookHints, get_hooks +from openfeature.hook import Hook, HookContext, HookHints from openfeature.hook._hook_support import ( after_all_hooks, after_hooks, @@ -31,13 +30,13 @@ error_hooks, ) from openfeature.provider import FeatureProvider, ProviderStatus -from openfeature.provider._registry import provider_registry from openfeature.track import TrackingEventDetails -from openfeature.transaction_context import get_transaction_context + +if typing.TYPE_CHECKING: + from openfeature._api import OpenFeatureAPI __all__ = [ "ClientMetadata", - "OpenFeatureClient", ] logger = logging.getLogger("openfeature") @@ -75,10 +74,19 @@ class ClientMetadata: class OpenFeatureClient: + """Client for evaluating feature flags against a specific OpenFeatureAPI. + + Clients should be obtained via ``OpenFeatureAPI.get_client()`` (or the + module-level ``openfeature.api.get_client()`` for the default API); + direct construction is supported only for advanced use cases and requires + passing the owning ``OpenFeatureAPI`` instance. + """ + def __init__( self, domain: str | None, version: str | None, + api: "OpenFeatureAPI", context: EvaluationContext | None = None, hooks: list[Hook] | None = None, ) -> None: @@ -86,13 +94,14 @@ def __init__( self.version = version self.context = context or EvaluationContext() self.hooks = hooks or [] + self._api = api @property def provider(self) -> FeatureProvider: - return provider_registry.get_provider(self.domain) + return self._api.get_provider(self.domain) def get_provider_status(self) -> ProviderStatus: - return provider_registry.get_provider_status(self.provider) + return self._api.get_provider_status(self.provider) def get_metadata(self) -> ClientMetadata: return ClientMetadata(domain=self.domain) @@ -422,8 +431,8 @@ def _establish_hooks_and_provider( # Merge transaction context into evaluation context before creating hook_context # This ensures hooks have access to the complete context including transaction context merged_eval_context = ( - get_evaluation_context() - .merge(get_transaction_context()) + self._api.get_evaluation_context() + .merge(self._api.get_transaction_context()) .merge(self.context) .merge(evaluation_context) ) @@ -448,7 +457,7 @@ def _establish_hooks_and_provider( ), ) for hook in chain( - get_hooks(), + self._api.get_hooks(), self.hooks, evaluation_hooks, provider.get_provider_hooks(), @@ -951,10 +960,10 @@ def _create_provider_evaluation( return resolution.to_flag_evaluation_details(flag_key) def add_handler(self, event: ProviderEvent, handler: EventHandler) -> None: - _event_support.add_client_handler(self, event, handler) + self._api._event_support.add_client_handler(self, event, handler) def remove_handler(self, event: ProviderEvent, handler: EventHandler) -> None: - _event_support.remove_client_handler(self, event, handler) + self._api._event_support.remove_client_handler(self, event, handler) def track( self, @@ -974,8 +983,8 @@ def track( evaluation_context = EvaluationContext() merged_eval_context = ( - get_evaluation_context() - .merge(get_transaction_context()) + self._api.get_evaluation_context() + .merge(self._api.get_transaction_context()) .merge(self.context) .merge(evaluation_context) ) diff --git a/openfeature/evaluation_context/__init__.py b/openfeature/evaluation_context/__init__.py index 690c63be..d36c577e 100644 --- a/openfeature/evaluation_context/__init__.py +++ b/openfeature/evaluation_context/__init__.py @@ -5,14 +5,7 @@ from dataclasses import dataclass, field from datetime import datetime -from openfeature.exception import GeneralError - -__all__ = [ - "EvaluationContext", - "clear_evaluation_context", - "get_evaluation_context", - "set_evaluation_context", -] +__all__ = ["EvaluationContext"] # https://openfeature.dev/specification/sections/evaluation-context#requirement-312 EvaluationContextAttribute: typing.TypeAlias = ( @@ -39,22 +32,3 @@ def merge(self, ctx2: EvaluationContext) -> EvaluationContext: targeting_key = ctx2.targeting_key or self.targeting_key return EvaluationContext(targeting_key=targeting_key, attributes=attributes) - - -def get_evaluation_context() -> EvaluationContext: - return _evaluation_context - - -def set_evaluation_context(evaluation_context: EvaluationContext) -> None: - global _evaluation_context - if evaluation_context is None: - raise GeneralError(error_message="No api level evaluation context") - _evaluation_context = evaluation_context - - -def clear_evaluation_context() -> None: - set_evaluation_context(EvaluationContext()) - - -# need to be at the bottom, because of the definition order -_evaluation_context = EvaluationContext() diff --git a/openfeature/hook/__init__.py b/openfeature/hook/__init__.py index 247d316b..a9f10976 100644 --- a/openfeature/hook/__init__.py +++ b/openfeature/hook/__init__.py @@ -18,13 +18,8 @@ "HookData", "HookHints", "HookType", - "add_hooks", - "clear_hooks", - "get_hooks", ] -_hooks: list[Hook] = [] - # https://openfeature.dev/specification/sections/hooks/#requirement-461 HookData = MutableMapping[str, typing.Any] @@ -149,17 +144,3 @@ def supports_flag_value_type(self, flag_type: FlagType) -> bool: or not (False) """ return True - - -def add_hooks(hooks: list[Hook]) -> None: - global _hooks - _hooks = _hooks + hooks - - -def clear_hooks() -> None: - global _hooks - _hooks = [] - - -def get_hooks() -> list[Hook]: - return _hooks diff --git a/openfeature/isolated.py b/openfeature/isolated.py new file mode 100644 index 00000000..5ab7ea6d --- /dev/null +++ b/openfeature/isolated.py @@ -0,0 +1,37 @@ +"""Factory for creating isolated OpenFeature API instances. + +Per specification requirement 1.8.3, this module is intentionally separate +from the global singleton ``openfeature.api`` to reduce the risk of +accidentally creating isolated instances when the singleton is appropriate. + +Usage:: + + from openfeature.isolated import create_api + + api = create_api() + api.set_provider(MyProvider()) + client = api.get_client() + +Each instance returned by :func:`create_api` maintains its own providers, +evaluation context, hooks, event handlers, and transaction context propagator +— fully independent from the global singleton and from other instances. + +A single provider instance should not be registered with more than one API +instance simultaneously (spec requirement 1.8.4). +""" + +from openfeature._api import OpenFeatureAPI + +__all__ = ["OpenFeatureAPI", "create_api"] + + +def create_api() -> OpenFeatureAPI: + """Create a new, independent OpenFeature API instance. + + The returned instance is functionally equivalent to the global singleton + (``openfeature.api``), but with completely isolated state. + + Returns: + A new :class:`OpenFeatureAPI` instance. + """ + return OpenFeatureAPI() diff --git a/openfeature/provider/_registry.py b/openfeature/provider/_registry.py index e46caadd..71adb8d4 100644 --- a/openfeature/provider/_registry.py +++ b/openfeature/provider/_registry.py @@ -1,7 +1,11 @@ +from __future__ import annotations + import threading +import typing +import weakref +from collections.abc import Callable -from openfeature._event_support import run_handlers_for_provider -from openfeature.evaluation_context import EvaluationContext, get_evaluation_context +from openfeature.evaluation_context import EvaluationContext from openfeature.event import ( ProviderEvent, ProviderEventDetails, @@ -10,6 +14,39 @@ from openfeature.provider import FeatureProvider, ProviderStatus from openfeature.provider.no_op_provider import NoOpProvider +if typing.TYPE_CHECKING: + from openfeature._event_support import EventSupport + +# spec 1.8.4: provider must not bind to more than one API; we track owning registry per provider, rebinding raises. WeakKeyDictionary lets providers be GC'd +_binding_lock = threading.Lock() +_provider_bindings: weakref.WeakKeyDictionary[FeatureProvider, ProviderRegistry] = ( + weakref.WeakKeyDictionary() +) + + +def _register_binding(provider: FeatureProvider, owner: ProviderRegistry) -> None: + try: + weakref.ref(provider) + except TypeError as exc: + raise TypeError( + f"Provider {type(provider).__name__!r} cannot be tracked because " + "it is not weak-referenceable. If your provider class uses " + "__slots__, add '__weakref__' to the slots list." + ) from exc + with _binding_lock: + existing = _provider_bindings.get(provider) + if existing is not None and existing is not owner: + raise RuntimeError( + "Provider is already bound to another OpenFeature API instance." + ) + _provider_bindings[provider] = owner + + +def _unregister_binding(provider: FeatureProvider, owner: ProviderRegistry) -> None: + with _binding_lock: + if _provider_bindings.get(provider) is owner: + del _provider_bindings[provider] + class ProviderRegistry: _default_provider: FeatureProvider @@ -17,13 +54,19 @@ class ProviderRegistry: _provider_status: dict[FeatureProvider, ProviderStatus] _lock: threading.RLock - def __init__(self) -> None: + def __init__( + self, + event_support: EventSupport, + evaluation_context_getter: Callable[[], EvaluationContext], + ) -> None: self._lock = threading.RLock() self._default_provider = NoOpProvider() self._providers = {} self._provider_status = { self._default_provider: ProviderStatus.READY, } + self._event_support = event_support + self._evaluation_context_getter = evaluation_context_getter def set_provider( self, domain: str, provider: FeatureProvider, wait_for_init: bool = False @@ -33,6 +76,8 @@ def set_provider( if domain is None: raise GeneralError(error_message="No domain") + _register_binding(provider, self) + old_provider: FeatureProvider | None = None needs_init = False with self._lock: @@ -64,6 +109,8 @@ def set_default_provider( if provider is None: raise GeneralError(error_message="No provider") + _register_binding(provider, self) + old_provider: FeatureProvider | None = None needs_init = False with self._lock: @@ -102,7 +149,7 @@ def shutdown(self) -> None: self._shutdown_provider(provider) def _get_evaluation_context(self) -> EvaluationContext: - return get_evaluation_context() + return self._evaluation_context_getter() def _initialize_provider( self, provider: FeatureProvider, wait_for_init: bool @@ -172,11 +219,8 @@ def _run_initialize( def _shutdown_if_unused(self, provider: FeatureProvider) -> None: # only shut down if no longer referenced. shutdown runs on a daemon # thread so a hanging shutdown() cannot block the caller. - with self._lock: - if provider is self._default_provider: - return - if provider in self._providers.values(): - return + if self._is_active(provider): + return thread = threading.Thread( target=self._shutdown_provider, @@ -186,20 +230,25 @@ def _shutdown_if_unused(self, provider: FeatureProvider) -> None: ) thread.start() + def _is_active(self, provider: FeatureProvider) -> bool: + with self._lock: + return ( + provider is self._default_provider + or provider in self._providers.values() + ) + def _shutdown_provider( self, provider: FeatureProvider, abort_if_re_registered: bool = False ) -> None: try: + # abort if re-registered before shutdown() to avoid tearing down the freshly-registered instance + if abort_if_re_registered and self._is_active(provider): + return if hasattr(provider, "shutdown"): provider.shutdown() - # if provider is being re-registered, leave its status and event wiring intact - if abort_if_re_registered: - with self._lock: - if ( - provider is self._default_provider - or provider in self._providers.values() - ): - return + # abort if re-registered during shutdown(); leave status and event wiring intact + if abort_if_re_registered and self._is_active(provider): + return with self._lock: self._provider_status.pop(provider, None) except Exception as err: @@ -212,6 +261,7 @@ def _shutdown_provider( ), ) provider.detach() + _unregister_binding(provider, self) def get_provider_status(self, provider: FeatureProvider) -> ProviderStatus: return self._provider_status.get(provider, ProviderStatus.NOT_READY) @@ -223,7 +273,7 @@ def dispatch_event( details: ProviderEventDetails, ) -> None: self._update_provider_status(provider, event, details) - run_handlers_for_provider(provider, event, details) + self._event_support.run_handlers_for_provider(provider, event, details) def _update_provider_status( self, @@ -243,6 +293,3 @@ def _update_provider_status( else ProviderStatus.ERROR ) self._provider_status[provider] = status - - -provider_registry = ProviderRegistry() diff --git a/openfeature/transaction_context/__init__.py b/openfeature/transaction_context/__init__.py index 15ac7e01..ca711cbf 100644 --- a/openfeature/transaction_context/__init__.py +++ b/openfeature/transaction_context/__init__.py @@ -1,10 +1,6 @@ -from openfeature.evaluation_context import EvaluationContext from openfeature.transaction_context.context_var_transaction_context_propagator import ( ContextVarsTransactionContextPropagator, ) -from openfeature.transaction_context.no_op_transaction_context_propagator import ( - NoOpTransactionContextPropagator, -) from openfeature.transaction_context.transaction_context_propagator import ( TransactionContextPropagator, ) @@ -12,34 +8,4 @@ __all__ = [ "ContextVarsTransactionContextPropagator", "TransactionContextPropagator", - "clear_transaction_context_propagator", - "get_transaction_context", - "set_transaction_context", - "set_transaction_context_propagator", ] - -_evaluation_transaction_context_propagator: TransactionContextPropagator = ( - NoOpTransactionContextPropagator() -) - - -def set_transaction_context_propagator( - transaction_context_propagator: TransactionContextPropagator, -) -> None: - global _evaluation_transaction_context_propagator - _evaluation_transaction_context_propagator = transaction_context_propagator - - -def clear_transaction_context_propagator() -> None: - set_transaction_context_propagator(NoOpTransactionContextPropagator()) - - -def get_transaction_context() -> EvaluationContext: - return _evaluation_transaction_context_propagator.get_transaction_context() - - -def set_transaction_context(evaluation_context: EvaluationContext) -> None: - global _evaluation_transaction_context_propagator - _evaluation_transaction_context_propagator.set_transaction_context( - evaluation_context - ) diff --git a/openfeature/transaction_context/context_var_transaction_context_propagator.py b/openfeature/transaction_context/context_var_transaction_context_propagator.py index 449c67a1..4fb27888 100644 --- a/openfeature/transaction_context/context_var_transaction_context_propagator.py +++ b/openfeature/transaction_context/context_var_transaction_context_propagator.py @@ -7,9 +7,10 @@ class ContextVarsTransactionContextPropagator(TransactionContextPropagator): - _transaction_context_var: ContextVar[EvaluationContext | None] = ContextVar( - "transaction_context", default=None - ) + def __init__(self) -> None: + self._transaction_context_var: ContextVar[EvaluationContext | None] = ( + ContextVar(f"transaction_context_{id(self)}", default=None) + ) def get_transaction_context(self) -> EvaluationContext: context = self._transaction_context_var.get() diff --git a/tests/conftest.py b/tests/conftest.py index 495634c1..1f013650 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,11 +6,8 @@ @pytest.fixture(autouse=True) def clear_providers(): - """ - For tests that use set_provider(), we need to clear the provider to avoid issues - in other tests. - """ - api.clear_providers() + """Fully reset the global default API between tests to avoid cross-test pollution.""" + api.shutdown() @pytest.fixture() diff --git a/tests/features/environment.py b/tests/features/environment.py index 4350ddca..a70bea5e 100644 --- a/tests/features/environment.py +++ b/tests/features/environment.py @@ -1,10 +1,7 @@ from openfeature import api +from openfeature.api import set_transaction_context, set_transaction_context_propagator from openfeature.evaluation_context import EvaluationContext -from openfeature.transaction_context import ( - ContextVarsTransactionContextPropagator, - set_transaction_context, - set_transaction_context_propagator, -) +from openfeature.transaction_context import ContextVarsTransactionContextPropagator def before_scenario(context, scenario): diff --git a/tests/features/steps/context_merging_steps.py b/tests/features/steps/context_merging_steps.py index afd75eb4..c5d488be 100644 --- a/tests/features/steps/context_merging_steps.py +++ b/tests/features/steps/context_merging_steps.py @@ -6,13 +6,11 @@ from behave import given, then, when from openfeature import api +from openfeature.api import set_transaction_context from openfeature.evaluation_context import EvaluationContext from openfeature.flag_evaluation import FlagResolutionDetails, FlagValueType, Reason from openfeature.hook import Hook, HookContext, HookHints from openfeature.provider import AbstractProvider, Metadata -from openfeature.transaction_context import ( - set_transaction_context, -) class RetrievableContextProvider(AbstractProvider): diff --git a/tests/hook/test_hook_data.py b/tests/hook/test_hook_data.py index a7e3e0d3..33327fbb 100644 --- a/tests/hook/test_hook_data.py +++ b/tests/hook/test_hook_data.py @@ -1,7 +1,6 @@ import typing -from openfeature.api import set_provider -from openfeature.client import OpenFeatureClient +from openfeature.api import get_client, set_provider from openfeature.evaluation_context import EvaluationContext from openfeature.flag_evaluation import FlagEvaluationDetails, FlagValueType from openfeature.hook import Hook, HookContext, HookHints @@ -46,7 +45,7 @@ def test_hook_data_is_not_shared_between_hooks(): provider = NoOpProvider() set_provider(provider) - client = OpenFeatureClient(domain=None, version=None) + client = get_client() hook_1 = HookWithData({"key": "value"}) hook_2 = HookWithData({"key": Example()}) diff --git a/tests/provider/test_registry.py b/tests/provider/test_registry.py index f7c55712..7f0328d4 100644 --- a/tests/provider/test_registry.py +++ b/tests/provider/test_registry.py @@ -4,21 +4,30 @@ import pytest +from openfeature._event_support import EventSupport +from openfeature.evaluation_context import EvaluationContext from openfeature.exception import GeneralError, ProviderFatalError from openfeature.provider import ProviderStatus from openfeature.provider._registry import ProviderRegistry from openfeature.provider.no_op_provider import NoOpProvider +def make_registry() -> ProviderRegistry: + return ProviderRegistry( + event_support=EventSupport(), + evaluation_context_getter=lambda: EvaluationContext(), + ) + + def test_registry_serves_noop_as_default(): - registry = ProviderRegistry() + registry = make_registry() assert isinstance(registry.get_default_provider(), NoOpProvider) assert isinstance(registry.get_provider("unknown domain"), NoOpProvider) def test_setting_provider_requires_domain(): - registry = ProviderRegistry() + registry = make_registry() with pytest.raises(GeneralError) as exc_info: registry.set_provider(None, NoOpProvider()) # type: ignore[reportArgumentType] @@ -27,7 +36,7 @@ def test_setting_provider_requires_domain(): def test_setting_provider_requires_provider(): - registry = ProviderRegistry() + registry = make_registry() with pytest.raises(GeneralError) as exc_info: registry.set_provider("domain", None) # type: ignore[reportArgumentType] @@ -36,7 +45,7 @@ def test_setting_provider_requires_provider(): def test_can_register_provider_to_multiple_domains(): - registry = ProviderRegistry() + registry = make_registry() provider = NoOpProvider() registry.set_provider("domain1", provider) @@ -49,7 +58,7 @@ def test_can_register_provider_to_multiple_domains(): def test_registering_provider_replaces_previous_provider(): """Test that registering a provider replaces the previous provider and calls shutdown on the old one.""" - registry = ProviderRegistry() + registry = make_registry() provider1 = Mock() provider2 = Mock() @@ -66,7 +75,7 @@ def test_registering_provider_replaces_previous_provider(): def test_registering_provider_for_first_time_initializes_it(): """Test that registering a provider for the first time calls its initialize method.""" - registry = ProviderRegistry() + registry = make_registry() provider = Mock() registry.set_provider("domain1", provider, wait_for_init=True) @@ -76,7 +85,7 @@ def test_registering_provider_for_first_time_initializes_it(): def test_setting_default_provider_requires_provider(): - registry = ProviderRegistry() + registry = make_registry() with pytest.raises(GeneralError) as exc_info: registry.set_default_provider(None) # type: ignore[reportArgumentType] @@ -87,7 +96,7 @@ def test_setting_default_provider_requires_provider(): def test_replacing_default_provider_shuts_down_old_one(): """Test that replacing the default provider shuts down the old default provider.""" - registry = ProviderRegistry() + registry = make_registry() default_provider1 = Mock() default_provider2 = Mock() @@ -102,7 +111,7 @@ def test_replacing_default_provider_shuts_down_old_one(): def test_setting_default_provider_initializes_it(): - registry = ProviderRegistry() + registry = make_registry() provider = Mock() registry.set_default_provider(provider, wait_for_init=True) @@ -113,7 +122,7 @@ def test_setting_default_provider_initializes_it(): def test_registering_provider_as_default_then_domain_only_initializes_once(): """Test that registering the same provider as default and for a domain only initializes it once.""" - registry = ProviderRegistry() + registry = make_registry() provider = Mock() registry.set_default_provider(provider, wait_for_init=True) @@ -125,7 +134,7 @@ def test_registering_provider_as_default_then_domain_only_initializes_once(): def test_registering_provider_as_domain_then_default_only_initializes_once(): """Test that registering the same provider as default and for a domain only initializes it once.""" - registry = ProviderRegistry() + registry = make_registry() provider = Mock() registry.set_provider("domain", provider, wait_for_init=True) @@ -137,7 +146,7 @@ def test_registering_provider_as_domain_then_default_only_initializes_once(): def test_replacing_provider_used_as_default_does_not_shutdown(): """Test that replacing a provider that is also the default does not shut it down twice.""" - registry = ProviderRegistry() + registry = make_registry() provider1 = Mock() provider2 = Mock() @@ -153,7 +162,7 @@ def test_replacing_provider_used_as_default_does_not_shutdown(): def test_replacing_default_provider_used_as_domain_does_not_shutdown(): """Test that replacing a default provider that is also used for a domain does not shut it down twice.""" - registry = ProviderRegistry() + registry = make_registry() provider1 = Mock() provider2 = Mock() @@ -169,7 +178,7 @@ def test_replacing_default_provider_used_as_domain_does_not_shutdown(): def test_shutting_down_registry_shuts_down_providers_once(): """Test that shutting down the registry shuts down each provider only once.""" - registry = ProviderRegistry() + registry = make_registry() provider1 = Mock() provider2 = Mock() @@ -188,7 +197,7 @@ def test_shutting_down_registry_shuts_down_providers_once(): def test_initializing_provider_sets_status_ready(): """Test that initializing a provider sets its status to READY.""" - registry = ProviderRegistry() + registry = make_registry() provider = Mock() assert registry.get_provider_status(provider) == ProviderStatus.NOT_READY @@ -202,7 +211,7 @@ def test_initializing_provider_sets_status_ready(): def test_shutting_down_provider_sets_status_not_ready(): """Test that shutting down a provider sets its status to NOT_READY.""" - registry = ProviderRegistry() + registry = make_registry() provider = Mock() registry.set_provider("domain", provider, wait_for_init=True) @@ -215,7 +224,7 @@ def test_shutting_down_provider_sets_status_not_ready(): def test_clearing_registry_resets_providers_and_default(): """Test that clearing the registry resets all providers and the default provider.""" - registry = ProviderRegistry() + registry = make_registry() provider = Mock() registry.set_provider("domain", provider, wait_for_init=True) @@ -235,7 +244,7 @@ def test_clearing_registry_resets_providers_and_default(): def test_set_provider_returns_before_initialization_completes(): """Test that set_provider (non-blocking) returns before initialize finishes.""" - registry = ProviderRegistry() + registry = make_registry() init_started = threading.Event() init_may_proceed = threading.Event() provider = Mock() @@ -257,7 +266,7 @@ def slow_initialize(ctx): def test_set_provider_and_wait_blocks_until_ready(): """Test that set_provider with wait_for_init=True blocks until READY.""" - registry = ProviderRegistry() + registry = make_registry() initialized = threading.Event() provider = Mock() @@ -274,7 +283,7 @@ def tracking_initialize(ctx): def test_set_provider_and_wait_reraises_on_error(): """Test that set_provider with wait_for_init=True re-raises initialization errors.""" - registry = ProviderRegistry() + registry = make_registry() provider = Mock() provider.initialize.side_effect = ProviderFatalError() @@ -286,7 +295,7 @@ def test_concurrent_set_provider_for_same_provider_initializes_once(): """Concurrent set_provider calls for different domains using the same provider instance must only initialize the provider once.""" - registry = ProviderRegistry() + registry = make_registry() init_count = 0 start_gate = threading.Event() @@ -317,7 +326,7 @@ def test_provider_replaced_during_async_init_does_not_set_ready_status(): """If a provider is replaced while its async initialize is still running, the late PROVIDER_READY event must not resurrect its status.""" - registry = ProviderRegistry() + registry = make_registry() init_started = threading.Event() init_may_proceed = threading.Event() @@ -350,7 +359,7 @@ def test_set_provider_does_not_block_on_hanging_old_shutdown(): """If the previously-registered provider's shutdown() hangs, a subsequent set_provider call must not be blocked by it.""" - registry = ProviderRegistry() + registry = make_registry() hanging = Mock() hang = threading.Event() @@ -383,7 +392,7 @@ def test_stale_shutdown_does_not_clobber_re_registered_provider(): (background) shutdown is still finishing, the stale shutdown must not pop its status or detach() the freshly-registered instance.""" - registry = ProviderRegistry() + registry = make_registry() shutdown_started = threading.Event() shutdown_may_proceed = threading.Event() @@ -424,3 +433,30 @@ def slow_shutdown(): "stale shutdown of A clobbered the fresh registration's status" ) provider_a.detach.assert_not_called() + + +def test_stale_shutdown_skips_shutdown_if_re_registered_first(): + """If a provider is re-registered before its background shutdown gets to + call shutdown() at all, shutdown() must not be invoked on the active + provider.""" + + registry = make_registry() + + provider_a = Mock() + provider_b = Mock() + + # step 1: register A, replace with B, then re-register A. queued background shutdown of A from the A->B swap is racing + registry.set_provider("domain", provider_a, wait_for_init=True) + registry.set_provider("domain", provider_b, wait_for_init=True) + registry.set_provider("domain", provider_a, wait_for_init=True) + # let the natural A->B background shutdown complete before we assert + time.sleep(0.2) + provider_a.shutdown.reset_mock() + provider_a.detach.reset_mock() + + # step 2: simulate the late-arriving stale shutdown; abort check must short-circuit before shutdown() is called + registry._shutdown_provider(provider_a, abort_if_re_registered=True) + + provider_a.shutdown.assert_not_called() + provider_a.detach.assert_not_called() + assert registry.get_provider_status(provider_a) == ProviderStatus.READY diff --git a/tests/test_api.py b/tests/test_api.py index cdb077fe..2f37e731 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -13,10 +13,12 @@ get_evaluation_context, get_hooks, get_provider_metadata, + get_transaction_context, remove_handler, set_evaluation_context, set_provider, set_provider_and_wait, + set_transaction_context_propagator, shutdown, ) from openfeature.evaluation_context import EvaluationContext @@ -24,13 +26,8 @@ from openfeature.exception import ErrorCode, GeneralError, ProviderFatalError from openfeature.hook import Hook from openfeature.provider import FeatureProvider, Metadata, ProviderStatus -from openfeature.provider._registry import provider_registry from openfeature.provider.no_op_provider import NoOpProvider -from openfeature.transaction_context import ( - ContextVarsTransactionContextPropagator, - get_transaction_context, - set_transaction_context_propagator, -) +from openfeature.transaction_context import ContextVarsTransactionContextPropagator def wait_for_mock_call(mock: MagicMock, timeout: float = 1.0) -> None: @@ -93,8 +90,9 @@ def test_should_invoke_provider_shutdown_function_once_provider_is_no_longer_in_ provider_2 = MagicMock(spec=FeatureProvider) # When - set_provider(provider_1) - set_provider(provider_2) + set_provider_and_wait(provider_1) + set_provider_and_wait(provider_2) + wait_for_mock_call(provider_1.shutdown) # Then assert provider_1.shutdown.called @@ -246,7 +244,7 @@ def test_shutdown_should_reset_api_state(): shutdown() # Then - provider = provider_registry.get_default_provider() + provider = get_client().provider assert isinstance(provider, NoOpProvider) hooks = get_hooks() diff --git a/tests/test_client.py b/tests/test_client.py index 44b49e5f..6125d4c6 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -8,7 +8,8 @@ import pytest -from openfeature import _event_support, api +from openfeature import api +from openfeature._api import _default_api from openfeature.api import ( add_hooks, clear_hooks, @@ -24,7 +25,6 @@ from openfeature.flag_evaluation import FlagResolutionDetails, FlagType, Reason from openfeature.hook import Hook from openfeature.provider import FeatureProvider, ProviderStatus -from openfeature.provider._registry import provider_registry from openfeature.provider.in_memory_provider import InMemoryFlag, InMemoryProvider from openfeature.provider.no_op_provider import NoOpProvider from openfeature.transaction_context import ContextVarsTransactionContextPropagator @@ -188,7 +188,7 @@ def test_should_pass_flag_metadata_from_resolution_to_evaluation_details(): ) set_provider(provider, "my-client") - client = OpenFeatureClient("my-client", None) + client = get_client("my-client") # When details = client.get_boolean_details(flag_key="Key", default_value=False) @@ -239,7 +239,7 @@ def test_should_handle_an_open_feature_exception_thrown_by_a_provider( def test_should_return_client_metadata_with_domain(): # Given - client = OpenFeatureClient("my-client", None, NoOpProvider()) + client = get_client("my-client") # When metadata = client.get_metadata() # Then @@ -359,7 +359,7 @@ def _shutdown(self) -> None: monkeypatch.setattr(provider, "shutdown", types.MethodType(_shutdown, provider)) # When - provider_registry.shutdown() + _default_api._provider_registry.shutdown() status = client.get_provider_status() @@ -549,13 +549,13 @@ def test_run_client_handlers_without_registered_handlers_is_noop(): client = get_client("client-without-handlers") details = EventDetails(provider_name=provider.get_metadata().name) - assert client not in _event_support._client_handlers + assert client not in _default_api._event_support._client_handlers - _event_support.run_client_handlers( + _default_api._event_support.run_client_handlers( client, ProviderEvent.PROVIDER_CONFIGURATION_CHANGED, details ) - assert client not in _event_support._client_handlers + assert client not in _default_api._event_support._client_handlers # Requirement 5.1.4, Requirement 5.1.5 @@ -707,7 +707,9 @@ def test_client_should_merge_contexts(): client_context = EvaluationContext( targeting_key="client", attributes={"client_attr": "client_value"} ) - client = OpenFeatureClient(domain=None, version=None, context=client_context) + client = OpenFeatureClient( + domain=None, version=None, api=_default_api, context=client_context + ) # Invocation-specific context invocation_context = EvaluationContext( @@ -743,6 +745,7 @@ def test_client_should_track_event(): def test_tracking_merges_evaluation_contexts(): + api.set_transaction_context_propagator(ContextVarsTransactionContextPropagator()) spy_provider = MagicMock(spec=NoOpProvider) api.set_provider(spy_provider) client = get_client() diff --git a/tests/test_isolated_api.py b/tests/test_isolated_api.py new file mode 100644 index 00000000..977cf525 --- /dev/null +++ b/tests/test_isolated_api.py @@ -0,0 +1,464 @@ +"""Tests for isolated OpenFeature API instances (spec section 1.8).""" + +import inspect +import time +from unittest.mock import MagicMock + +import pytest + +from openfeature import api +from openfeature._api import _default_api +from openfeature.evaluation_context import EvaluationContext +from openfeature.event import ProviderEvent, ProviderEventDetails +from openfeature.hook import Hook +from openfeature.isolated import OpenFeatureAPI, create_api +from openfeature.provider import FeatureProvider, Metadata, ProviderStatus +from openfeature.provider.no_op_provider import NoOpProvider +from openfeature.transaction_context import ContextVarsTransactionContextPropagator + + +def wait_for_mock_call(mock: MagicMock, timeout: float = 1.0) -> None: + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + if mock.call_count: + return + + time.sleep(0.01) + + +# --- Spec 1.8.1: Factory returns independent instances --- + + +def test_create_api_returns_new_instance(): + api1 = create_api() + api2 = create_api() + assert api1 is not api2 + + +def test_isolated_instance_is_openfeature_api(): + api_instance = create_api() + assert isinstance(api_instance, OpenFeatureAPI) + + +# --- Spec 1.8.2: Same API contract --- + + +_ISOLATED_API_PUBLIC_METHODS = ( + "add_handler", + "add_hooks", + "clear_evaluation_context", + "clear_hooks", + "clear_providers", + "clear_transaction_context_propagator", + "get_client", + "get_evaluation_context", + "get_hooks", + "get_provider", + "get_provider_metadata", + "get_provider_status", + "get_transaction_context", + "remove_handler", + "set_evaluation_context", + "set_provider", + "set_provider_and_wait", + "set_transaction_context", + "set_transaction_context_propagator", + "shutdown", +) + + +def test_isolated_api_provides_full_api_contract(): + """Spec 1.8.2: factory result MUST expose the same contract as the global API.""" + api_instance = create_api() + reference = OpenFeatureAPI() + + for name in _ISOLATED_API_PUBLIC_METHODS: + assert hasattr(api_instance, name), f"isolated API missing method: {name}" + attr = getattr(api_instance, name) + assert callable(attr), f"isolated API attribute is not callable: {name}" + actual = inspect.signature(attr) + expected = inspect.signature(getattr(reference, name)) + assert actual == expected, ( + f"signature mismatch for {name}: {actual} != {expected}" + ) + + +def test_isolated_api_get_client_returns_working_client(): + provider = MagicMock(spec=FeatureProvider) + provider.get_metadata.return_value = MagicMock(name="test-provider") + + api_instance = create_api() + api_instance.set_provider(provider) + + client = api_instance.get_client() + assert client is not None + assert client.provider is provider + + +def test_isolated_api_get_client_with_domain(): + provider = MagicMock(spec=FeatureProvider) + provider.get_metadata.return_value = MagicMock(name="domain-provider") + + api_instance = create_api() + api_instance.set_provider(provider, domain="my-domain") + + client = api_instance.get_client(domain="my-domain") + assert client.provider is provider + + +# --- Isolated state: providers --- + + +def test_isolated_providers_are_independent(): + provider_a = MagicMock(spec=FeatureProvider) + provider_a.get_metadata.return_value = MagicMock(name="provider-a") + provider_b = MagicMock(spec=FeatureProvider) + provider_b.get_metadata.return_value = MagicMock(name="provider-b") + + api1 = create_api() + api2 = create_api() + + api1.set_provider(provider_a) + api2.set_provider(provider_b) + + client1 = api1.get_client() + client2 = api2.get_client() + + assert client1.provider is provider_a + assert client2.provider is provider_b + + +def test_isolated_provider_does_not_affect_global(): + provider = MagicMock(spec=FeatureProvider) + provider.get_metadata.return_value = MagicMock(name="isolated-provider") + + api_instance = create_api() + api_instance.set_provider(provider) + + # Global singleton should still have NoOpProvider + global_client = api.get_client() + assert isinstance(global_client.provider, NoOpProvider) + + +# --- Spec 1.8.4: Provider should not be bound to multiple APIs --- + + +def test_binding_provider_to_multiple_apis_raises(): + """Spec 1.8.4: provider must not be bound to more than one OpenFeature API. + + Uses a Protocol-only provider (no AbstractProvider subclass) to ensure + detection works regardless of provider implementation strategy. + """ + provider = MagicMock(spec=FeatureProvider) + provider.get_metadata.return_value = MagicMock(name="protocol-provider") + + api1 = create_api() + api2 = create_api() + + api1.set_provider(provider) + + with pytest.raises(RuntimeError, match="already bound"): + api2.set_provider(provider) + + +def test_rebinding_provider_to_same_api_does_not_raise(): + """Re-binding the same provider to the same API (e.g., on a different domain) + must not trigger the spec 1.8.4 error.""" + provider = NoOpProvider() + api_instance = create_api() + + api_instance.set_provider(provider, domain="domain-a") + # second call must not raise + api_instance.set_provider(provider, domain="domain-b") + + assert api_instance.get_provider("domain-a") is provider + assert api_instance.get_provider("domain-b") is provider + + +def test_provider_can_be_rebound_after_being_released(): + """After a provider is released from one API (via clear_providers/shutdown), + binding it to another API must not raise.""" + provider = NoOpProvider() + + api1 = create_api() + api1.set_provider(provider) + api1.shutdown() + + # provider is released; binding to a different API now succeeds + api2 = create_api() + api2.set_provider(provider) + + assert api2.get_provider() is provider + + +def test_set_provider_rejects_non_weak_referenceable_provider(): + """Providers must be weak-referenceable so the SDK can track bindings + without leaking memory; surfacing this requirement up front (rather than + silently skipping the spec 1.8.4 check) avoids hard-to-diagnose bugs.""" + + # A direct ``object`` subclass with ``__slots__`` and no ``__weakref__`` + # entry; instances are not weak-referenceable. Implements the + # ``FeatureProvider`` protocol structurally rather than via inheritance + # (which would inherit ``__weakref__`` from the parent class). + class NotWeakReferenceable: + __slots__ = () + + def attach(self, on_emit): + pass + + def detach(self): + pass + + def get_metadata(self): + return Metadata(name="not-weak-referenceable") + + def get_provider_hooks(self): + return [] + + provider = NotWeakReferenceable() + api_instance = create_api() + + with pytest.raises(TypeError, match="weak-referenceable"): + api_instance.set_provider(provider) # type: ignore[arg-type] + + +# --- Isolated state: hooks --- + + +def test_isolated_hooks_are_independent(): + hook_a = MagicMock(spec=Hook) + hook_b = MagicMock(spec=Hook) + + api1 = create_api() + api2 = create_api() + + api1.add_hooks([hook_a]) + api2.add_hooks([hook_b]) + + assert hook_a in api1.get_hooks() + assert hook_b not in api1.get_hooks() + assert hook_b in api2.get_hooks() + assert hook_a not in api2.get_hooks() + + +def test_isolated_hooks_do_not_affect_global(): + global_hook = MagicMock(spec=Hook) + isolated_hook = MagicMock(spec=Hook) + + api.add_hooks([global_hook]) + + api_instance = create_api() + api_instance.add_hooks([isolated_hook]) + + assert api.get_hooks() == [global_hook] + assert api_instance.get_hooks() == [isolated_hook] + + +def test_clear_hooks_on_isolated_api(): + hook = MagicMock(spec=Hook) + + api_instance = create_api() + api_instance.add_hooks([hook]) + assert len(api_instance.get_hooks()) == 1 + + api_instance.clear_hooks() + assert len(api_instance.get_hooks()) == 0 + + +# --- Isolated state: evaluation context --- + + +def test_isolated_evaluation_context_is_independent(): + ctx_a = EvaluationContext(targeting_key="user-a") + ctx_b = EvaluationContext(targeting_key="user-b") + + api1 = create_api() + api2 = create_api() + + api1.set_evaluation_context(ctx_a) + api2.set_evaluation_context(ctx_b) + + assert api1.get_evaluation_context().targeting_key == "user-a" + assert api2.get_evaluation_context().targeting_key == "user-b" + + +def test_isolated_evaluation_context_does_not_affect_global(): + api.set_evaluation_context(EvaluationContext(targeting_key="global-user")) + + api_instance = create_api() + api_instance.set_evaluation_context( + EvaluationContext(targeting_key="isolated-user") + ) + + assert api.get_evaluation_context().targeting_key == "global-user" + assert api_instance.get_evaluation_context().targeting_key == "isolated-user" + + +# --- Isolated state: events --- + + +def test_isolated_event_handlers_are_independent(): + handler_a = MagicMock() + handler_b = MagicMock() + + api1 = create_api() + api2 = create_api() + + provider1 = MagicMock(spec=FeatureProvider) + provider1.get_metadata.return_value = MagicMock(name="p1") + provider2 = MagicMock(spec=FeatureProvider) + provider2.get_metadata.return_value = MagicMock(name="p2") + + api1.set_provider(provider1) + api2.set_provider(provider2) + + # Register handlers for CONFIGURATION_CHANGED to test dispatch isolation + api1.add_handler(ProviderEvent.PROVIDER_CONFIGURATION_CHANGED, handler_a) + api2.add_handler(ProviderEvent.PROVIDER_CONFIGURATION_CHANGED, handler_b) + + # Dispatch event on api1's registry — only handler_a should fire + api1._provider_registry.dispatch_event( + provider1, + ProviderEvent.PROVIDER_CONFIGURATION_CHANGED, + ProviderEventDetails(), + ) + + wait_for_mock_call(handler_a) + assert handler_a.call_count == 1 + assert handler_b.call_count == 0 + + +def test_isolated_event_handlers_do_not_affect_global(): + handler = MagicMock() + + api_instance = create_api() + provider = MagicMock(spec=FeatureProvider) + provider.get_metadata.return_value = MagicMock(name="p") + api_instance.set_provider(provider) + api_instance.add_handler(ProviderEvent.PROVIDER_CONFIGURATION_CHANGED, handler) + + # Dispatch on global — isolated handler should NOT fire + global_provider = MagicMock(spec=FeatureProvider) + global_provider.get_metadata.return_value = MagicMock(name="gp") + api.set_provider(global_provider) + + handler.reset_mock() + + _default_api._event_support.run_handlers_for_provider( + global_provider, + ProviderEvent.PROVIDER_CONFIGURATION_CHANGED, + ProviderEventDetails(), + ) + + assert handler.call_count == 0 + + +# --- Provider lifecycle on isolated instances --- + + +def test_isolated_api_initializes_provider(): + provider = MagicMock(spec=FeatureProvider) + provider.get_metadata.return_value = MagicMock(name="init-provider") + + api_instance = create_api() + api_instance.set_provider(provider) + + provider.initialize.assert_called_once() + + +def test_isolated_api_shuts_down_provider(): + provider = MagicMock(spec=FeatureProvider) + provider.get_metadata.return_value = MagicMock(name="shutdown-provider") + + api_instance = create_api() + api_instance.set_provider(provider) + api_instance.shutdown() + + provider.shutdown.assert_called_once() + + +def test_isolated_api_clear_providers(): + provider = MagicMock(spec=FeatureProvider) + provider.get_metadata.return_value = MagicMock(name="clear-provider") + + api_instance = create_api() + api_instance.set_provider(provider) + api_instance.clear_providers() + + client = api_instance.get_client() + assert isinstance(client.provider, NoOpProvider) + + +# --- Provider status on isolated instances --- + + +def test_isolated_client_provider_status(): + provider = MagicMock(spec=FeatureProvider) + provider.get_metadata.return_value = MagicMock(name="status-provider") + + api_instance = create_api() + api_instance.set_provider(provider) + + client = api_instance.get_client() + assert client.get_provider_status() == ProviderStatus.READY + + +# --- Transaction context on isolated instances --- + + +def test_isolated_transaction_context_propagator(): + api1 = create_api() + api2 = create_api() + + api1.set_transaction_context_propagator(ContextVarsTransactionContextPropagator()) + + ctx = EvaluationContext(targeting_key="tx-user") + api1.set_transaction_context(ctx) + + assert api1.get_transaction_context().targeting_key == "tx-user" + # api2 still uses NoOpTransactionContextPropagator → empty context + assert api2.get_transaction_context().targeting_key is None + + +def test_isolated_transaction_context_with_both_using_contextvars(): + """Two APIs with ContextVars propagators must not share state.""" + api1 = create_api() + api2 = create_api() + + api1.set_transaction_context_propagator(ContextVarsTransactionContextPropagator()) + api2.set_transaction_context_propagator(ContextVarsTransactionContextPropagator()) + + api1.set_transaction_context(EvaluationContext(targeting_key="api1-user")) + + assert api1.get_transaction_context().targeting_key == "api1-user" + assert api2.get_transaction_context().targeting_key is None + + +# --- Global singleton backward compatibility --- + + +def test_global_api_still_works(): + provider = MagicMock(spec=FeatureProvider) + provider.get_metadata.return_value = MagicMock(name="global-provider") + + api.set_provider(provider) + client = api.get_client() + + assert client.provider is provider + provider.initialize.assert_called_once() + + +def test_global_hooks_still_work(): + hook = MagicMock(spec=Hook) + + api.add_hooks([hook]) + assert hook in api.get_hooks() + + api.clear_hooks() + assert len(api.get_hooks()) == 0 + + +def test_global_evaluation_context_still_works(): + ctx = EvaluationContext(targeting_key="global-user") + api.set_evaluation_context(ctx) + assert api.get_evaluation_context().targeting_key == "global-user" diff --git a/tests/test_transaction_context_in_hooks.py b/tests/test_transaction_context_in_hooks.py index 61a5b5cf..2e317f46 100644 --- a/tests/test_transaction_context_in_hooks.py +++ b/tests/test_transaction_context_in_hooks.py @@ -1,9 +1,9 @@ from openfeature.api import ( + get_client, set_provider, set_transaction_context, set_transaction_context_propagator, ) -from openfeature.client import OpenFeatureClient from openfeature.evaluation_context import EvaluationContext from openfeature.hook import Hook from openfeature.provider.no_op_provider import NoOpProvider @@ -32,7 +32,7 @@ def test_transaction_context_merged_into_hook_context(): provider = NoOpProvider() set_provider(provider) - client = OpenFeatureClient(domain=None, version=None) + client = get_client() hook = TransactionContextHook() client.add_hooks([hook]) diff --git a/uv.lock b/uv.lock index fd4c1c92..ad2a07af 100644 --- a/uv.lock +++ b/uv.lock @@ -197,7 +197,7 @@ wheels = [ [[package]] name = "openfeature-sdk" -version = "0.9.0" +version = "0.10.0" source = { editable = "." } [package.dev-dependencies]