From 3e56b4803eb7f0366ce46686870c748f81a905a3 Mon Sep 17 00:00:00 2001 From: Sean N Date: Sun, 17 May 2026 04:31:36 +0200 Subject: [PATCH] Fixed: stem inconsistencies in the module resolution and rollback look timeouts --- nitro_dispatch/core/hook_registry.py | 247 ++++++++++----- nitro_dispatch/core/plugin_base.py | 42 ++- nitro_dispatch/core/plugin_manager.py | 221 ++++++++++++-- tests/test_critical_fixes.py | 424 ++++++++++++++++++++++++++ tests/test_plugin_manager.py | 1 + 5 files changed, 817 insertions(+), 118 deletions(-) create mode 100644 tests/test_critical_fixes.py diff --git a/nitro_dispatch/core/hook_registry.py b/nitro_dispatch/core/hook_registry.py index c5c88ff..64c90f2 100644 --- a/nitro_dispatch/core/hook_registry.py +++ b/nitro_dispatch/core/hook_registry.py @@ -2,7 +2,9 @@ import asyncio import concurrent.futures +import inspect import re +import threading from typing import Any, Callable, Dict, List, Optional import logging @@ -33,13 +35,34 @@ class HookRegistry: - :class:`StopPropagation` to halt the chain from a hook. - Plugin-level enable/disable: hooks from disabled plugins are skipped without unregistering. + - Thread-safe registration and dispatch: mutations and + ``_get_matching_hooks`` are guarded by an :class:`RLock` and + iteration snapshots the hook map. """ + # Class-level executor shared across instances so a hung sync-hook + # worker never gets joined by ThreadPoolExecutor.__exit__ on the + # caller's behalf. ``daemon=True`` lets the process exit even if a + # runaway hook is still in flight. We deliberately do not call + # ``shutdown(wait=True)`` anywhere on this pool. 
+ _timeout_executor = concurrent.futures.ThreadPoolExecutor( + max_workers=8, + thread_name_prefix="nitro-hook-timeout", + ) + def __init__(self) -> None: """Initialize an empty registry with the default error strategy.""" self._hooks: Dict[str, List[Dict[str, Any]]] = {} self._error_strategy: str = "log_and_continue" self._hook_tracing: bool = False + # Reentrant: hooks running in worker threads may call back into + # register/unregister, and trigger_async dispatches sync hooks + # to executor threads concurrently. + self._lock = threading.RLock() + # Populated by trigger/trigger_async when the strategy is + # ``collect_all`` so callers can inspect what failed. Cleared at + # the start of each dispatch. + self._last_errors: List[Dict[str, Any]] = [] def register( self, @@ -73,9 +96,6 @@ def register( >>> reg = HookRegistry() >>> reg.register("user.*", lambda d: d, priority=100) """ - if event_name not in self._hooks: - self._hooks[event_name] = [] - hook_info = { "callback": callback, "plugin": plugin, @@ -85,10 +105,12 @@ def register( "is_async": asyncio.iscoroutinefunction(callback), } - self._hooks[event_name].append(hook_info) - - # Sort hooks by priority (higher priority first) - self._hooks[event_name].sort(key=lambda h: h["priority"], reverse=True) + with self._lock: + if event_name not in self._hooks: + self._hooks[event_name] = [] + self._hooks[event_name].append(hook_info) + # Sort hooks by priority (higher priority first) + self._hooks[event_name].sort(key=lambda h: h["priority"], reverse=True) logger.debug( f"Registered hook '{event_name}' from plugin " @@ -111,17 +133,19 @@ def unregister(self, event_name: str, callback: Callable, plugin: Optional[Any] Returns: True if a hook was found and removed; False otherwise. 
""" - if event_name not in self._hooks: - return False + with self._lock: + if event_name not in self._hooks: + return False + + original_length = len(self._hooks[event_name]) + self._hooks[event_name] = [ + hook + for hook in self._hooks[event_name] + if not (hook["callback"] == callback and hook["plugin"] == plugin) + ] - original_length = len(self._hooks[event_name]) - self._hooks[event_name] = [ - hook - for hook in self._hooks[event_name] - if not (hook["callback"] == callback and hook["plugin"] == plugin) - ] + removed = len(self._hooks[event_name]) < original_length - removed = len(self._hooks[event_name]) < original_length if removed: logger.debug(f"Unregistered hook '{event_name}'") return removed @@ -137,9 +161,12 @@ def _match_event_pattern(self, pattern: str, event: str) -> bool: Returns: True if event matches pattern """ - # `*` matches a single dot-delimited segment, mirroring glob semantics - # rather than regex `.*` (which would cross segment boundaries). - regex_pattern = pattern.replace(".", r"\.").replace("*", "[^.]*") + # `*` matches a single non-empty dot-delimited segment, mirroring + # glob semantics rather than regex `.*` (which would cross segment + # boundaries). ``+`` (one-or-more) instead of ``*`` (zero-or-more) + # is intentional: ``user.*`` does NOT match the literal string + # ``"user."`` with an empty trailing segment. + regex_pattern = pattern.replace(".", r"\.").replace("*", "[^.]+") regex_pattern = f"^{regex_pattern}$" return bool(re.match(regex_pattern, event)) @@ -153,9 +180,15 @@ def _get_matching_hooks(self, event_name: str) -> List[Dict[str, Any]]: Returns: List of matching hook information dictionaries """ - matching_hooks = [] + matching_hooks: List[Dict[str, Any]] = [] - for registered_event, hooks in self._hooks.items(): + # Snapshot under the lock so a concurrent register/unregister + # from a hook running in a worker thread cannot mutate the dict + # mid-iteration. 
+ with self._lock: + snapshot = [(event, list(hooks)) for event, hooks in self._hooks.items()] + + for registered_event, hooks in snapshot: # Exact match if registered_event == event_name: matching_hooks.extend(hooks) @@ -175,6 +208,16 @@ def _execute_hook_with_timeout( """ Execute a synchronous hook with optional timeout. + Uses a shared class-level :class:`ThreadPoolExecutor` rather than + a per-call one. A previous implementation used + ``with ThreadPoolExecutor(...) as executor:``; on timeout, the + ``__exit__`` call invoked ``shutdown(wait=True)`` and blocked + the caller until the runaway callback actually returned — making + the timeout effectively unenforceable. The shared pool is never + joined, so :class:`HookTimeoutError` propagates immediately and + the orphaned worker thread (which Python cannot forcibly kill) + is left to finish in the background. + Args: callback: Hook callback function data: Data to pass to callback @@ -189,18 +232,15 @@ def _execute_hook_with_timeout( if timeout is None: return callback(data) - # Thread-based timeout: portable (works on Windows and in non-main - # threads, unlike signal.SIGALRM) and safe to call from executors. - # Note: the worker thread cannot be forcibly killed on timeout. This - # matches asyncio.wait_for's behavior for async hooks. - with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: - future = executor.submit(callback, data) - try: - return future.result(timeout=timeout) - except concurrent.futures.TimeoutError: - raise HookTimeoutError( - f"Hook execution exceeded timeout of {timeout}s" - ) + future = self._timeout_executor.submit(callback, data) + try: + return future.result(timeout=timeout) + except concurrent.futures.TimeoutError: + # Best-effort cancel; running futures cannot actually be + # cancelled, but we mark intent so the worker is reclaimed + # if it ever does return. 
+ future.cancel() + raise HookTimeoutError(f"Hook execution exceeded timeout of {timeout}s") async def _execute_async_hook_with_timeout( self, callback: Callable, data: Any, timeout: Optional[float] @@ -227,6 +267,58 @@ async def _execute_async_hook_with_timeout( except asyncio.TimeoutError: raise HookTimeoutError(f"Async hook execution exceeded timeout of {timeout}s") + async def _execute_sync_hook_in_executor( + self, + callback: Callable, + data: Any, + timeout: Optional[float], + ) -> Any: + """Dispatch a sync hook to the default executor with timeout. + + Avoids the double-thread-pool dispatch that the old code + accidentally created (``run_in_executor`` -> helper -> *another* + ThreadPoolExecutor.submit). Enforces the timeout at the asyncio + boundary so a runaway hook does not pin two threads. + """ + loop = asyncio.get_running_loop() + future = loop.run_in_executor(None, callback, data) + if timeout is None: + return await future + try: + return await asyncio.wait_for(future, timeout=timeout) + except asyncio.TimeoutError: + raise HookTimeoutError(f"Hook execution exceeded timeout of {timeout}s") + + async def _notify_on_error(self, plugin: Any, error: Exception) -> None: + """Call ``plugin.on_error`` and ``await`` the result if it's a coroutine.""" + if not (plugin and hasattr(plugin, "on_error")): + return + try: + maybe_coro = plugin.on_error(error) + if inspect.iscoroutine(maybe_coro): + await maybe_coro + except Exception as notify_error: + logger.error(f"Error in plugin error handler: {notify_error}") + + def _notify_on_error_sync(self, plugin: Any, error: Exception) -> None: + """Sync variant: drops async ``on_error`` coroutines with a warning.""" + if not (plugin and hasattr(plugin, "on_error")): + return + try: + maybe_coro = plugin.on_error(error) + if inspect.iscoroutine(maybe_coro): + # The sync trigger() path cannot await this; close it to + # silence "coroutine was never awaited" RuntimeWarning and + # tell the user. 
+ maybe_coro.close() + logger.warning( + f"Async on_error coroutine from plugin " + f"'{getattr(plugin, 'name', '?')}' was dropped in sync " + f"trigger(); use trigger_async() to await it." + ) + except Exception as notify_error: + logger.error(f"Error in plugin error handler: {notify_error}") + def trigger(self, event_name: str, data: Any = None) -> Any: """Fire an event and run matching hooks synchronously. @@ -237,6 +329,11 @@ def trigger(self, event_name: str, data: Any = None) -> Any: are skipped with a warning; use :meth:`trigger_async` for those. + Note: a hook that returns ``None`` does NOT clear the payload + for the next hook — the previous ``data`` is preserved. If you + need to set the chain value to ``None`` explicitly, raise + :class:`StopPropagation` or use a sentinel value. + Args: event_name: Event name to fire. Literal plus wildcard matches are dispatched. @@ -264,7 +361,7 @@ def trigger(self, event_name: str, data: Any = None) -> Any: if self._hook_tracing: logger.debug(f"Triggering event '{event_name}' with {len(hooks)} hooks") - errors = [] + errors: List[Dict[str, Any]] = [] result = data for hook_info in hooks: @@ -317,11 +414,7 @@ def trigger(self, event_name: str, data: Any = None) -> Any: error_msg = f"Hook '{event_name}' from plugin '{plugin_name}' " f"timed out: {e}" logger.error(error_msg) - if plugin and hasattr(plugin, "on_error"): - try: - plugin.on_error(e) - except Exception as notify_error: - logger.error(f"Error in plugin error handler: {notify_error}") + self._notify_on_error_sync(plugin, e) if self._error_strategy == "fail_fast": raise HookError(error_msg) from e @@ -340,12 +433,7 @@ def trigger(self, event_name: str, data: Any = None) -> Any: ) logger.error(error_msg) - # Notify plugin of error - if plugin and hasattr(plugin, "on_error"): - try: - plugin.on_error(e) - except Exception as notify_error: - logger.error(f"Error in plugin error handler: {notify_error}") + self._notify_on_error_sync(plugin, e) if 
self._error_strategy == "fail_fast": raise HookError(error_msg) from e @@ -359,6 +447,10 @@ def trigger(self, event_name: str, data: Any = None) -> Any: ) # log_and_continue: just continue to next hook + # Expose collected errors for programmatic inspection (issue: the + # ``collect_all`` strategy previously had no way for callers to + # see what failed). + self._last_errors = errors if errors and self._error_strategy == "collect_all": logger.warning(f"Event '{event_name}' completed with {len(errors)} errors") @@ -373,6 +465,9 @@ async def trigger_async(self, event_name: str, data: Any = None) -> Any: when invoked through this method. Ordering, stop-propagation, and error-strategy semantics are identical to :meth:`trigger`. + ``on_error`` callbacks are awaited if they return a coroutine, + so plugins may define ``async def on_error``. + Args: event_name: Event name to fire. data: Payload threaded through the chain. @@ -401,7 +496,7 @@ async def trigger_async(self, event_name: str, data: Any = None) -> Any: if self._hook_tracing: logger.debug(f"Triggering async event '{event_name}' with " f"{len(hooks)} hooks") - errors = [] + errors: List[Dict[str, Any]] = [] result = data for hook_info in hooks: @@ -428,14 +523,14 @@ async def trigger_async(self, event_name: str, data: Any = None) -> Any: callback, result, timeout ) else: - # Run sync hook in executor to avoid blocking - loop = asyncio.get_running_loop() - new_result = await loop.run_in_executor( - None, - self._execute_hook_with_timeout, - callback, - result, - timeout, + # Single-thread dispatch with asyncio-level timeout + # enforcement. The previous implementation called + # run_in_executor -> _execute_hook_with_timeout, which + # itself spun up another ThreadPoolExecutor — pinning + # two threads per timed hook and inheriting the + # shutdown-blocks-on-timeout bug. 
+ new_result = await self._execute_sync_hook_in_executor( + callback, result, timeout ) if self._hook_tracing: @@ -461,11 +556,7 @@ async def trigger_async(self, event_name: str, data: Any = None) -> Any: ) logger.error(error_msg) - if plugin and hasattr(plugin, "on_error"): - try: - plugin.on_error(e) - except Exception as notify_error: - logger.error(f"Error in plugin error handler: {notify_error}") + await self._notify_on_error(plugin, e) if self._error_strategy == "fail_fast": raise HookError(error_msg) from e @@ -485,11 +576,7 @@ async def trigger_async(self, event_name: str, data: Any = None) -> Any: ) logger.error(error_msg) - if plugin and hasattr(plugin, "on_error"): - try: - plugin.on_error(e) - except Exception as notify_error: - logger.error(f"Error in plugin error handler: {notify_error}") + await self._notify_on_error(plugin, e) if self._error_strategy == "fail_fast": raise HookError(error_msg) from e @@ -502,6 +589,7 @@ async def trigger_async(self, event_name: str, data: Any = None) -> Any: } ) + self._last_errors = errors if errors and self._error_strategy == "collect_all": logger.warning(f"Async event '{event_name}' completed with " f"{len(errors)} errors") @@ -529,7 +617,8 @@ def get_all_events(self) -> List[str]: The literal strings used at registration. Wildcard patterns are returned as-is (e.g. ``"user.*"``). """ - return list(self._hooks.keys()) + with self._lock: + return list(self._hooks.keys()) def clear_event(self, event_name: str) -> None: """Remove every hook registered under a single event name. @@ -540,9 +629,10 @@ def clear_event(self, event_name: str) -> None: Args: event_name: Event name to clear. """ - if event_name in self._hooks: - del self._hooks[event_name] - logger.debug(f"Cleared all hooks for event '{event_name}'") + with self._lock: + if event_name in self._hooks: + del self._hooks[event_name] + logger.debug(f"Cleared all hooks for event '{event_name}'") def clear_all(self) -> None: """Remove every registered hook. 
@@ -550,9 +640,24 @@ def clear_all(self) -> None: Use between tests or when reconfiguring the registry from scratch. """ - self._hooks.clear() + with self._lock: + self._hooks.clear() logger.debug("Cleared all hooks") + def get_last_errors(self) -> List[Dict[str, Any]]: + """Return errors collected during the most recent dispatch. + + Populated when the error strategy is ``"collect_all"``. Each + entry is a dict with keys ``plugin``, ``error``, ``event``. + Cleared at the start of every :meth:`trigger` / + :meth:`trigger_async`. + + Returns: + A list of error records from the last dispatch, possibly + empty. + """ + return list(self._last_errors) + def set_error_strategy(self, strategy: str) -> None: """Choose how hook exceptions are handled during dispatch. @@ -561,8 +666,8 @@ def set_error_strategy(self, strategy: str) -> None: the next hook. - ``"fail_fast"``: raise :class:`HookError` and abort the chain. - - ``"collect_all"``: run every hook, then log a summary of - how many failed. + - ``"collect_all"``: run every hook, then expose collected + errors via :meth:`get_last_errors`. Args: strategy: One of the values above. diff --git a/nitro_dispatch/core/plugin_base.py b/nitro_dispatch/core/plugin_base.py index a39d64f..4f7531e 100644 --- a/nitro_dispatch/core/plugin_base.py +++ b/nitro_dispatch/core/plugin_base.py @@ -202,27 +202,41 @@ def _collect_decorated_hooks(self) -> None: """Gather @hook-decorated methods into ``self._hooks``. Called from ``__init__`` so the manager can register them at load - time. Skips private/magic attributes to avoid unnecessary access. + time. Walks ``type(self).__mro__`` and inspects raw class-dict + entries instead of calling ``getattr(self, ...)``, so subclasses + that define ``@property`` descriptors do NOT have those + descriptors invoked during plugin construction (which happens + during discovery, registration, load, and reload). 
""" - for attr_name in dir(self): - if attr_name.startswith("_"): + seen: set = set() + for klass in type(self).__mro__: + if klass is object: continue - - try: - attr = getattr(self, attr_name) - except AttributeError: - continue - - if callable(attr) and hasattr(attr, "_is_hook") and attr._is_hook: - event_name = attr._event_name - priority = getattr(attr, "_priority", 50) - timeout = getattr(attr, "_timeout", None) + for attr_name, raw in klass.__dict__.items(): + if attr_name.startswith("_") or attr_name in seen: + continue + # The @hook decorator wraps the function and sets + # ``_is_hook`` on the *function object* itself. Reading + # it from the class dict avoids descriptor invocation + # (properties, classmethods with side effects, etc.). + if not (callable(raw) and getattr(raw, "_is_hook", False)): + continue + seen.add(attr_name) + + # Bind to ``self`` so the registered callback carries + # the instance — matches the previous behavior obtained + # via ``getattr(self, attr_name)``. + bound = getattr(self, attr_name) + + event_name = raw._event_name + priority = getattr(raw, "_priority", 50) + timeout = getattr(raw, "_timeout", None) if event_name not in self._hooks: self._hooks[event_name] = [] self._hooks[event_name].append( { - "callback": attr, + "callback": bound, "priority": priority, "timeout": timeout, } diff --git a/nitro_dispatch/core/plugin_manager.py b/nitro_dispatch/core/plugin_manager.py index 90ef407..55b90bb 100644 --- a/nitro_dispatch/core/plugin_manager.py +++ b/nitro_dispatch/core/plugin_manager.py @@ -5,7 +5,7 @@ import inspect import sys from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Type, Union +from typing import Any, Callable, Dict, List, Optional, Set, Type, Union import logging from .plugin_base import PluginBase @@ -21,6 +21,23 @@ logger = logging.getLogger(__name__) +# Prefix under which ``discover_plugins`` stashes modules in ``sys.modules``. 
+# Namespacing avoids clobbering stdlib or application modules that happen +# to share a file stem (e.g. a plugin file called ``logging.py``). +_DISCOVERED_MODULE_PREFIX = "nitro_dispatch._discovered." + +# importlib.reload() walks the parent package chain in sys.modules. Register +# a stub for the synthetic ``nitro_dispatch._discovered`` namespace so +# reload() of discovered plugin modules doesn't raise +# ``ImportError: parent 'nitro_dispatch._discovered' not in sys.modules``. +_DISCOVERED_PARENT = "nitro_dispatch._discovered" +if _DISCOVERED_PARENT not in sys.modules: + import types as _types + + _stub = _types.ModuleType(_DISCOVERED_PARENT) + _stub.__path__ = [] # mark as a (namespace) package + sys.modules[_DISCOVERED_PARENT] = _stub + class PluginManager: """Central orchestrator for plugin lifecycle and event dispatch. @@ -90,6 +107,9 @@ def __init__( self._config: Dict[str, Any] = config or {} self._loaded: bool = False self._validate_metadata: bool = validate_metadata + # Names currently mid-``load()``. Used to detect circular + # dependencies before they blow the Python recursion limit. + self._loading: Set[str] = set() logging.basicConfig(level=getattr(logging, log_level.upper())) @@ -208,17 +228,47 @@ def load(self, plugin_name: str) -> PluginBase: logger.warning(f"Plugin '{plugin_name}' already loaded") return self._plugins[plugin_name] + # Cycle detection: if we're already mid-load for this plugin, the + # dependency graph has a cycle. Raise immediately instead of + # recursing until RecursionError. 
+ if plugin_name in self._loading: + chain = " -> ".join(sorted(self._loading) + [plugin_name]) + raise DependencyError( + f"Circular dependency detected while loading " f"'{plugin_name}' (chain: {chain})" + ) + plugin_class = self._plugin_classes[plugin_name] + plugin: Optional[PluginBase] = None + hooks_registered: List[tuple] = [] # (event_name, callback) pairs + self._loading.add(plugin_name) try: plugin = plugin_class() plugin._manager = self + # Re-validate dependencies on the actual load-time instance. + # ``__init__`` may mutate ``self.dependencies`` based on + # environment, so the value seen at register() time can differ. + if not isinstance(plugin.dependencies, list): + raise ValidationError( + f"Plugin '{plugin_name}' dependencies must be a list, " + f"got {type(plugin.dependencies).__name__}" + ) + for dep_name in plugin.dependencies: + if not isinstance(dep_name, str): + raise ValidationError( + f"Plugin '{plugin_name}' has non-string dependency: " f"{dep_name!r}" + ) + for dep_name in plugin.dependencies: if dep_name not in self._plugins: logger.info(f"Loading dependency '{dep_name}' for '{plugin_name}'") try: self.load(dep_name) + except DependencyError: + # Already wrapped (cycle or upstream dep error); + # surface as-is so the chain message survives. + raise except Exception as e: raise DependencyError( f"Failed to load dependency '{dep_name}' for " f"'{plugin_name}': {e}" @@ -227,16 +277,19 @@ def load(self, plugin_name: str) -> PluginBase: for event_name, hook_list in plugin._hooks.items(): for hook_data in hook_list: if isinstance(hook_data, dict): + callback = hook_data["callback"] self.register_hook( event_name, - hook_data["callback"], + callback, plugin, hook_data.get("priority", 50), hook_data.get("timeout"), ) else: # Legacy format: bare callable stored without metadata. 
- self.register_hook(event_name, hook_data, plugin) + callback = hook_data + self.register_hook(event_name, callback, plugin) + hooks_registered.append((event_name, callback)) plugin.on_load() plugin.enabled = True @@ -252,6 +305,16 @@ def load(self, plugin_name: str) -> PluginBase: return plugin except Exception as e: + # Roll back any partial state so the registry doesn't end up + # with orphan hooks pointing at an unloadable instance. + for event_name, callback in hooks_registered: + try: + self._registry.unregister(event_name, callback, plugin) + except Exception: # pragma: no cover - defensive + pass + if plugin is not None: + plugin._manager = None + error_data = { "plugin_name": plugin_name, "error": str(e), @@ -259,6 +322,8 @@ def load(self, plugin_name: str) -> PluginBase: } self.trigger(self.EVENT_PLUGIN_ERROR, error_data) raise PluginLoadError(f"Failed to load plugin '{plugin_name}': {e}") from e + finally: + self._loading.discard(plugin_name) def load_all(self) -> List[str]: """Load every registered plugin, respecting dependencies. @@ -313,25 +378,47 @@ def unload(self, plugin_name: str) -> None: raise PluginNotFoundError(f"Plugin '{plugin_name}' not loaded") plugin = self._plugins[plugin_name] + on_unload_error: Optional[Exception] = None + # Call on_unload first, but don't let its failure leave the + # plugin half-detached. We capture the exception, run the full + # cleanup, and re-raise at the end. try: plugin.on_unload() - plugin.enabled = False + except Exception as e: + on_unload_error = e + logger.error( + f"Error in on_unload for '{plugin_name}': {e} " f"(proceeding with hook detachment)" + ) + + plugin.enabled = False - for event_name in self._registry.get_all_events(): - hooks = self._registry.get_hooks(event_name) + # Detach hooks unconditionally. Snapshot first — get_hooks() may + # return the live list and we mutate it via unregister() in the + # loop, which would otherwise skip entries. 
+ try: + for event_name in list(self._registry.get_all_events()): + hooks = list(self._registry.get_hooks(event_name)) for hook_info in hooks: - if hook_info["plugin"] == plugin: + if hook_info["plugin"] is plugin: self._registry.unregister(event_name, hook_info["callback"], plugin) + finally: + # Always drop the manager's reference, even if hook removal + # somehow raised. Leaving a stale entry in ``_plugins`` is + # worse than a possibly-orphan registry entry, because every + # subsequent unload() would re-fail on the same plugin. + self._plugins.pop(plugin_name, None) + plugin._manager = None - del self._plugins[plugin_name] - logger.info(f"Unloaded plugin '{plugin_name}'") + logger.info(f"Unloaded plugin '{plugin_name}'") + try: self.trigger(self.EVENT_PLUGIN_UNLOADED, {"plugin_name": plugin_name}) - except Exception as e: - logger.error(f"Error unloading plugin '{plugin_name}': {e}") - raise + logger.error(f"Error firing EVENT_PLUGIN_UNLOADED for '{plugin_name}': {e}") + + if on_unload_error is not None: + raise on_unload_error def unload_all(self) -> None: """Unload every currently-loaded plugin. @@ -373,28 +460,79 @@ def reload(self, plugin_name: str) -> PluginBase: logger.info(f"Reloading plugin '{plugin_name}'") + plugin_class = self._plugin_classes[plugin_name] + module_name = getattr(plugin_class, "__module__", None) + + # Find every other registered plugin that lives in the same + # module — they all become stale when importlib.reload() runs. + sibling_names = [ + name + for name, cls in self._plugin_classes.items() + if name != plugin_name and getattr(cls, "__module__", None) == module_name + ] + # Track which siblings were loaded so we can restore them. + previously_loaded_siblings = [n for n in sibling_names if n in self._plugins] + + # Unload the target first, then any loaded siblings, so the + # module reload doesn't strand instances of dead classes. 
if plugin_name in self._plugins: self.unload(plugin_name) + for sibling in previously_loaded_siblings: + try: + self.unload(sibling) + except Exception as e: + logger.error(f"Error unloading sibling '{sibling}' during reload: {e}") + + if module_name and module_name in sys.modules: + logger.debug(f"Reloading module '{module_name}'") + existing = sys.modules[module_name] + spec = getattr(existing, "__spec__", None) + # Modules loaded via spec_from_file_location (discover_plugins + # and the ad-hoc case) can't always be reloaded through + # importlib.reload because their parent package isn't a real + # package on disk. Re-exec the saved spec instead when the + # module has a file origin; fall back to importlib.reload + # otherwise. + if ( + spec is not None + and getattr(spec, "origin", None) + and spec.loader is not None + and module_name.startswith(_DISCOVERED_MODULE_PREFIX) + ): + new_module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = new_module + try: + spec.loader.exec_module(new_module) + except Exception: + # Restore the previous module on failure so we don't + # leave sys.modules in a worse state than we found it. + sys.modules[module_name] = existing + raise + reloaded_module = new_module + else: + reloaded_module = importlib.reload(existing) + + # importlib.reload replaces the module's classes with new + # objects. Refresh our stored class reference for the target + # AND every sibling, so subsequent load() calls instantiate + # the new code rather than the pre-reload classes. + # Read ``name`` from the class dict to avoid running __init__ + # on every PluginBase subclass in the module. 
+ refresh_targets = set(sibling_names) | {plugin_name} + for _, obj in inspect.getmembers(reloaded_module, inspect.isclass): + if not issubclass(obj, PluginBase) or obj is PluginBase: + continue + candidate_name = obj.name if obj.__dict__.get("name") else obj.__name__ + if candidate_name in refresh_targets: + self._plugin_classes[candidate_name] = obj - plugin_class = self._plugin_classes[plugin_name] - if hasattr(plugin_class, "__module__"): - module_name = plugin_class.__module__ - if module_name in sys.modules: - logger.debug(f"Reloading module '{module_name}'") - reloaded_module = importlib.reload(sys.modules[module_name]) - - # importlib.reload replaces the module's classes with new - # objects. Refresh our stored class reference so the subsequent - # load() instantiates the new code, not the pre-reload class. - # Read the name from class attrs to avoid running __init__ on - # every PluginBase subclass in the module during reload. - for _, obj in inspect.getmembers(reloaded_module, inspect.isclass): - if not issubclass(obj, PluginBase) or obj is PluginBase: - continue - candidate_name = obj.name if obj.__dict__.get("name") else obj.__name__ - if candidate_name == plugin_name: - self._plugin_classes[plugin_name] = obj - break + # Reload any previously-loaded siblings before the target, so + # callers see the same loaded set after reload() returns. + for sibling in previously_loaded_siblings: + try: + self.load(sibling) + except Exception as e: + logger.error(f"Error reloading sibling '{sibling}': {e}") return self.load(plugin_name) @@ -450,13 +588,30 @@ def discover_plugins( if not plugin_file.is_file(): continue + # Namespace the sys.modules key so a plugin file whose + # stem collides with a stdlib or app module (e.g. + # ``logging.py``) doesn't silently clobber the real one. + # Also include the file's absolute path hash so two + # discovered plugins with the same stem in different + # directories don't clash. 
+ module_name = ( + f"{_DISCOVERED_MODULE_PREFIX}{plugin_file.stem}_" + f"{abs(hash(str(plugin_file)))}" + ) + try: - module_name = plugin_file.stem spec = importlib.util.spec_from_file_location(module_name, plugin_file) if spec and spec.loader: module = importlib.util.module_from_spec(spec) sys.modules[module_name] = module - spec.loader.exec_module(module) + try: + spec.loader.exec_module(module) + except Exception: + # Don't leave a half-executed module in + # sys.modules — later reload() / imports of + # the same path will get the broken object. + sys.modules.pop(module_name, None) + raise for name, obj in inspect.getmembers(module, inspect.isclass): if ( diff --git a/tests/test_critical_fixes.py b/tests/test_critical_fixes.py new file mode 100644 index 0000000..568ebd4 --- /dev/null +++ b/tests/test_critical_fixes.py @@ -0,0 +1,424 @@ +"""Regression tests for the critical issues identified in code review. + +Each test pins the fix for a specific issue so the bug can't silently +reappear. Reference IDs (#1-#11) match the review report ordering. +""" + +import asyncio +import os +import sys +import threading +import time + +import pytest + +from nitro_dispatch import PluginBase, PluginManager, hook +from nitro_dispatch.core.exceptions import ( + DependencyError, + HookTimeoutError, + PluginLoadError, +) +from nitro_dispatch.core.hook_registry import HookRegistry + + +# --------------------------------------------------------------------------- +# #1 — Sync timeout is actually enforced (no shutdown-blocks-on-timeout bug) +# --------------------------------------------------------------------------- + + +def test_sync_timeout_does_not_block_on_runaway_callback(): + """A sync hook that ignores its timeout must NOT pin the caller. + + Previously the per-call ThreadPoolExecutor's __exit__ joined the + worker, so a runaway callback hung the caller despite the timeout. 
+ """ + reg = HookRegistry() + + def slow(_data): + time.sleep(5.0) + return _data + + reg.register("evt", slow, priority=50, timeout=0.05) + reg.set_error_strategy("fail_fast") + + start = time.time() + with pytest.raises(Exception): # wrapped HookTimeoutError -> HookError + reg.trigger("evt", {}) + elapsed = time.time() - start + + # Should return within ~timeout, never the full 5s sleep. Generous + # slack for CI scheduling, but well below the runaway sleep. + assert elapsed < 1.0, f"trigger blocked for {elapsed:.2f}s; timeout broken" + + +# --------------------------------------------------------------------------- +# #2 — load() rolls back partial hook registration when on_load() raises +# --------------------------------------------------------------------------- + + +class _BoomOnLoad(PluginBase): + name = "boom" + + @hook("some.event") + def react(self, data): + return data + + def on_load(self): + raise RuntimeError("kaboom") + + +def test_load_failure_does_not_leak_hooks_into_registry(): + mgr = PluginManager() + mgr.register(_BoomOnLoad) + + with pytest.raises(PluginLoadError): + mgr.load("boom") + + # The registry must not retain a hook bound to the failed instance. + matching = mgr._registry.get_hooks("some.event") + assert matching == [], f"orphan hooks remained: {matching}" + + # And triggering must not invoke the dead hook. 
+ result = mgr.trigger("some.event", {"x": 1}) + assert result == {"x": 1} + + +# --------------------------------------------------------------------------- +# #3 — unload() is exception-safe even when on_unload() raises +# --------------------------------------------------------------------------- + + +class _BoomOnUnload(PluginBase): + name = "boom_unload" + + @hook("u.event") + def react(self, data): + return data + + def on_unload(self): + raise RuntimeError("unload boom") + + +def test_unload_cleans_up_even_when_on_unload_raises(): + mgr = PluginManager() + mgr.register(_BoomOnUnload) + mgr.load("boom_unload") + assert mgr.is_loaded("boom_unload") + + with pytest.raises(RuntimeError, match="unload boom"): + mgr.unload("boom_unload") + + # Must be fully detached despite the on_unload failure. + assert not mgr.is_loaded("boom_unload") + assert mgr._registry.get_hooks("u.event") == [] + # Subsequent unload should not crash on stale state — it should + # report not-loaded cleanly. + from nitro_dispatch.core.exceptions import PluginNotFoundError + + with pytest.raises(PluginNotFoundError): + mgr.unload("boom_unload") + + +# --------------------------------------------------------------------------- +# #4 — Circular dependencies raise immediately, not via RecursionError +# --------------------------------------------------------------------------- + + +class _PluginA(PluginBase): + name = "cyc_a" + dependencies = ["cyc_b"] + + +class _PluginB(PluginBase): + name = "cyc_b" + dependencies = ["cyc_a"] + + +def test_circular_dependency_raises_dependency_error_not_recursion(): + mgr = PluginManager() + mgr.register(_PluginA) + mgr.register(_PluginB) + + with pytest.raises(PluginLoadError) as excinfo: + mgr.load("cyc_a") + + # The chained cause should be a DependencyError mentioning the cycle. 
+ chain_messages = [] + err = excinfo.value + while err is not None: + chain_messages.append(str(err)) + err = err.__cause__ + joined = " | ".join(chain_messages).lower() + assert "circular" in joined, f"expected cycle message, got: {joined}" + + +# --------------------------------------------------------------------------- +# #5 — unload() detaches ALL hooks even when a plugin registered multiple +# for the same event (mutation-during-iteration regression). +# --------------------------------------------------------------------------- + + +class _MultiHook(PluginBase): + name = "multi" + + @hook("multi.event", priority=10) + def low(self, data): + return data + + @hook("multi.event", priority=20) + def mid(self, data): + return data + + @hook("multi.event", priority=30) + def high(self, data): + return data + + +def test_unload_removes_all_hooks_for_same_event(): + mgr = PluginManager() + mgr.register(_MultiHook) + mgr.load("multi") + assert len(mgr._registry.get_hooks("multi.event")) == 3 + + mgr.unload("multi") + assert mgr._registry.get_hooks("multi.event") == [] + + +# --------------------------------------------------------------------------- +# #6 — reload() refreshes every class in a multi-class module +# --------------------------------------------------------------------------- + + +def test_reload_refreshes_sibling_classes_in_same_module(tmp_path): + plugin_file = tmp_path / "twoclass_plugin.py" + plugin_file.write_text( + "from nitro_dispatch import PluginBase, hook\n" + "class First(PluginBase):\n" + " name = 'first'\n" + " version = '1.0.0'\n" + " @hook('e1')\n" + " def h(self, d):\n" + " return 'first-v1'\n" + "class Second(PluginBase):\n" + " name = 'second'\n" + " version = '1.0.0'\n" + " @hook('e2')\n" + " def h(self, d):\n" + " return 'second-v1'\n" + ) + + mgr = PluginManager() + mgr.discover_plugins(str(tmp_path), pattern="*_plugin.py") + mgr.load("first") + mgr.load("second") + + assert mgr.trigger("e1", None) == "first-v1" + assert 
mgr.trigger("e2", None) == "second-v1" + + # Rewrite both classes. Bump mtime past fs resolution to be safe. + plugin_file.write_text( + "from nitro_dispatch import PluginBase, hook\n" + "class First(PluginBase):\n" + " name = 'first'\n" + " version = '2.0.0'\n" + " @hook('e1')\n" + " def h(self, d):\n" + " return 'first-v2'\n" + "class Second(PluginBase):\n" + " name = 'second'\n" + " version = '2.0.0'\n" + " @hook('e2')\n" + " def h(self, d):\n" + " return 'second-v2'\n" + ) + future_mtime = time.time() + 2 + os.utime(plugin_file, (future_mtime, future_mtime)) + + mgr.reload("first") + + # Both first AND second must run the new code, even though only + # 'first' was reloaded explicitly. + assert mgr.trigger("e1", None) == "first-v2" + assert ( + mgr.trigger("e2", None) == "second-v2" + ), "sibling class in the same reloaded module is stale" + + +# --------------------------------------------------------------------------- +# #7 — discover_plugins() must not clobber sys.modules entries by stem +# --------------------------------------------------------------------------- + + +def test_discover_does_not_clobber_sys_modules_by_stem(tmp_path): + # Intent: a stdlib-like stem must not displace a real module in + # sys.modules. NOTE(review): the stem is "logging_plugin", not + # "logging" — confirm the stdlib identity check below can fail. + plugin_file = tmp_path / "logging_plugin.py" + plugin_file.write_text( + "from nitro_dispatch import PluginBase\n" + "class LP(PluginBase):\n" + " name = 'lp'\n" + " version = '1.0.0'\n" + ) + + original_logging = sys.modules.get("logging") + + mgr = PluginManager() + mgr.discover_plugins(str(tmp_path), pattern="*_plugin.py") + + # The real `logging` module must still be the one in sys.modules. + assert sys.modules.get("logging") is original_logging + + # The discovered module must live under the namespaced prefix.
+ discovered_keys = [k for k in sys.modules if k.startswith("nitro_dispatch._discovered.")] + assert any( + "logging_plugin" in k for k in discovered_keys + ), f"discovered module not under namespaced prefix; keys: {discovered_keys}" + + +# --------------------------------------------------------------------------- +# #8 — async on_error coroutines are awaited under trigger_async +# --------------------------------------------------------------------------- + + +class _AsyncOnErrorPlugin(PluginBase): + name = "async_err" + on_error_called_with = None + + @hook("err.event") + def boom(self, data): + raise RuntimeError("hook failed") + + async def on_error(self, error): + # Suspension proves we were truly awaited. + await asyncio.sleep(0) + type(self).on_error_called_with = error + + +@pytest.mark.asyncio +async def test_async_on_error_is_awaited(): + _AsyncOnErrorPlugin.on_error_called_with = None + mgr = PluginManager() + mgr.register(_AsyncOnErrorPlugin) + mgr.load("async_err") + + await mgr.trigger_async("err.event", {}) + + assert isinstance(_AsyncOnErrorPlugin.on_error_called_with, RuntimeError) + + +# --------------------------------------------------------------------------- +# #9 — collect_all errors are programmatically retrievable +# --------------------------------------------------------------------------- + + +def test_collect_all_errors_are_retrievable(): + reg = HookRegistry() + reg.set_error_strategy("collect_all") + + def good(data): + return data + 1 + + def bad(_data): + raise ValueError("nope") + + reg.register("e", good, priority=100) + reg.register("e", bad, priority=50) + + out = reg.trigger("e", 1) + assert out == 2 # good ran, bad failed silently in collect_all + + errors = reg.get_last_errors() + assert len(errors) == 1 + assert errors[0]["event"] == "e" + assert isinstance(errors[0]["error"], ValueError) + + +# --------------------------------------------------------------------------- +# #10 — registry is thread-safe across concurrent 
register/dispatch +# --------------------------------------------------------------------------- + + +def test_concurrent_register_and_trigger_does_not_raise(): + """Hammer register and trigger from multiple threads. + + Previously _get_matching_hooks iterated self._hooks.items() without + a lock, so a concurrent register could raise + 'dictionary changed size during iteration'. + """ + reg = HookRegistry() + reg.set_error_strategy("log_and_continue") + stop = threading.Event() + errors: list = [] + + def producer(): + i = 0 + while not stop.is_set(): + try: + reg.register(f"e.{i % 5}", lambda d: d) + i += 1 + except Exception as e: # pragma: no cover - regression guard + errors.append(e) + + def consumer(): + while not stop.is_set(): + try: + reg.trigger("e.1", None) + reg.trigger("e.2", None) + except Exception as e: # pragma: no cover - regression guard + errors.append(e) + + threads = [threading.Thread(target=producer) for _ in range(2)] + threads += [threading.Thread(target=consumer) for _ in range(2)] + for t in threads: + t.start() + time.sleep(0.3) + stop.set() + for t in threads: + t.join() + + assert not errors, f"thread-safety regression: {errors[:3]}" + + +# --------------------------------------------------------------------------- +# #11 — _collect_decorated_hooks does not invoke @property descriptors +# --------------------------------------------------------------------------- + + +def test_property_descriptors_are_not_invoked_during_init(): + invocations = [] + + class HasProperty(PluginBase): + name = "hp" + + @property + def expensive(self): + invocations.append("called") + return 42 + + @hook("hp.event") + def react(self, data): + return data + + inst = HasProperty() # noqa: F841 — instantiation is the test + assert invocations == [], f"property accessed during __init__: {invocations}" + + +# --------------------------------------------------------------------------- +# #12 — wildcard `*` requires a non-empty segment +# 
--------------------------------------------------------------------------- + + +def test_wildcard_does_not_match_empty_segment(): + reg = HookRegistry() + hits = [] + + def cb(data): + hits.append(data) + return data + + reg.register("user.*", cb) + reg.trigger("user.", "empty-segment") + reg.trigger("user.login", "good") + + assert hits == ["good"], f"wildcard matched empty segment: {hits}" diff --git a/tests/test_plugin_manager.py b/tests/test_plugin_manager.py index 78cf788..a08c63a 100644 --- a/tests/test_plugin_manager.py +++ b/tests/test_plugin_manager.py @@ -443,6 +443,7 @@ def respond(self, data): # Bump mtime past filesystem's 1s resolution so importlib doesn't # treat the cached bytecode as still-fresh. import os + future = plugin_file.stat().st_mtime + 2 os.utime(plugin_file, (future, future))