From 421a94126e6b100b4bfb72291f248a9bd0281429 Mon Sep 17 00:00:00 2001 From: Seth R Johnson Date: Thu, 7 May 2026 07:07:32 -0400 Subject: [PATCH] Improve type hints and process return type --- celerpy/visualize.py | 35 +++++++++++++++++++++++++---------- 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/celerpy/visualize.py b/celerpy/visualize.py index 5aa0faa..216899b 100644 --- a/celerpy/visualize.py +++ b/celerpy/visualize.py @@ -7,7 +7,7 @@ import json import re import warnings -from collections.abc import Iterable, Mapping, MutableSequence +from collections.abc import Iterable, MutableSequence from importlib.resources import files from pathlib import Path from subprocess import TimeoutExpired @@ -174,7 +174,7 @@ def trace( *, geometry: GeometryEngine | None = None, **kwargs, - ): + ) -> tuple[TraceOutput, np.ndarray]: """Trace with a geometry, memspace, etc.""" if image is None and not self.image: raise RuntimeError( @@ -216,7 +216,7 @@ def trace( return (result, npimg) - def close(self, *, timeout: float = 0.25) -> dict[str, dict] | str: + def close(self, *, timeout: float = 0.25) -> dict[str, Any]: """Cleanly exit the ray trace loop, returning run statistics if possible. """ @@ -225,9 +225,21 @@ def close(self, *, timeout: float = 0.25) -> dict[str, dict] | str: self.process.wait(timeout=timeout) result = result or process.close(self.process, timeout=timeout) - with contextlib.suppress(json.JSONDecodeError): - result = json.loads(result) - return result + if result is None: + return { + "result": None, + "returncode": self.process.returncode, + } + + try: + result_dict = json.loads(result) + except json.JSONDecodeError as e: + result_dict = { + "result": result, + "decode_error": str(e), + } + + return result_dict class ReverseIndexDict(collections.defaultdict): @@ -313,6 +325,7 @@ def __call__( ax.set_ylabel(y.label) ax.set_ylim((y.lo, y.hi)) tr = trace_output.trace + assert tr.geometry is not None and tr.memspace is not None ax.set_title(f"{tr.geometry.name} ({tr.memspace.name})") # Remap volume IDs and volume names into a persistent 0-based list @@ -355,8 +368,8 @@ def plot_all_geometry( *, colorbar: bool = True, figsize: tuple | None = None, - engines: Iterable | None = None, -) -> Mapping[GeometryEngine, Any]: + engines: Iterable[GeometryEngine] | None = None, +) -> dict: """Convenience function for plotting all available geometry types.""" if engines is None: engines = GeometryEngine @@ -371,8 +384,8 @@ def plot_all_geometry( figsize=figsize, gridspec_kw=dict(width_ratios=width_ratios), ) - result = {} - all_cbar: list[Any] = [False] * len(engines) + result: dict = {"fig": fig} + all_cbar: list = [False] * len(engines) if colorbar: all_cbar[:0] = [all_ax[-1]] @@ -381,6 +394,8 @@ def plot_all_geometry( result[g] = trace_image(ax, geometry=g, colorbar=cb) except Exception as e: warnings.warn(f"Failed to trace {g} geometry: {e!s}", stacklevel=1) + if colorbar: + result["colorbar"] = result.get(engines[-1], {}).get("colorbar") return result