Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions PySpotObserver/pyspotobserver/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from dataclasses import dataclass, field
from enum import IntFlag
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from typing import Any

import yaml


Expand Down Expand Up @@ -73,12 +74,16 @@ class SpotConfig:
request_timeout_seconds: float = 10.0
"""Timeout for image requests"""

vision_model_path: Optional[str] = None
vision_model_path: str | None = None
"""Optional ONNX model path for run_pipeline=True"""

vision_providers: Optional[List[str]] = None
vision_providers: list[str] | None = None
"""Optional ONNX Runtime provider preference order"""

ema_beta: float | None = None
"""Optional decay factor beta in (0, 1) for EMA smoothing over output depth maps.
EMA_t = beta * EMA_{t-1} + (1 - beta) * x_t. Higher beta retains more history."""

# Advanced settings
sdk_name: str = "PySpotObserver"
"""Name to identify this SDK client"""
Expand All @@ -89,15 +94,15 @@ class SpotConfig:
connection_retry_delay_ms: int = 100
"""Delay between connection retry attempts (milliseconds)"""

extra_params: Dict[str, Any] = field(default_factory=dict)
extra_params: dict[str, Any] = field(default_factory=dict)
"""Additional user-defined parameters"""

def __post_init__(self) -> None:
if self.request_timeout_seconds <= 0:
raise ValueError("request_timeout_seconds must be positive")

@classmethod
def from_yaml(cls, yaml_path: Union[Path, str]) -> "SpotConfig":
def from_yaml(cls, yaml_path: Path | str) -> "SpotConfig":
"""
Load configuration from a YAML file.

Expand All @@ -115,15 +120,15 @@ def from_yaml(cls, yaml_path: Union[Path, str]) -> "SpotConfig":
if not path.exists():
raise FileNotFoundError(f"Config file not found: {path}")

with open(path, "r") as f:
with open(path) as f:
data = yaml.safe_load(f)

if data is None:
data = {}

return cls(**data)

def to_yaml(self, yaml_path: Union[Path, str]) -> None:
def to_yaml(self, yaml_path: Path | str) -> None:
"""
Save configuration to a YAML file.

Expand Down
89 changes: 59 additions & 30 deletions PySpotObserver/pyspotobserver/vision_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,14 @@

import os
import threading
from collections.abc import Sequence
from pathlib import Path
from typing import List, Optional, Sequence, Tuple, Union

import cv2
import numpy as np

from .config import SpotConfig


DEFAULT_MODEL_ENV_VAR = "PYSPOTOBSERVER_VISION_MODEL"
DEFAULT_PROVIDERS = ("CUDAExecutionProvider", "CPUExecutionProvider")
DEFAULT_DEPTH_SIZE = (120, 160) # (height, width)
Expand All @@ -24,7 +23,7 @@ class VisionPipelineError(Exception):
"""Raised when the optional vision pipeline cannot run."""


def _normalize_providers(providers: Optional[Union[Sequence[str], str]]) -> Tuple[str, ...]:
def _normalize_providers(providers: Sequence[str] | str | None) -> tuple[str, ...]:
if providers is None:
return DEFAULT_PROVIDERS
if isinstance(providers, str):
Expand All @@ -41,7 +40,7 @@ def _dtype_for_onnx_type(type_name: str) -> np.dtype:
return np.dtype(np.float32)


def _depth_list_from_output(output: np.ndarray, batch_size: int) -> List[np.ndarray]:
def _depth_list_from_output(output: np.ndarray, batch_size: int) -> list[np.ndarray]:
output = np.asarray(output)

if batch_size == 1 and output.ndim == 2:
Expand Down Expand Up @@ -71,27 +70,30 @@ class VisionPipeline:

def __init__(
self,
model_path: Union[str, os.PathLike[str]],
providers: Optional[Union[Sequence[str], str]] = None,
depth_size: Tuple[int, int] = DEFAULT_DEPTH_SIZE,
model_path: str | os.PathLike[str],
providers: Sequence[str] | str | None = None,
depth_size: tuple[int, int] = DEFAULT_DEPTH_SIZE,
ema_beta: float | None = None,
):
self.model_path = Path(model_path).expanduser()
self.providers = _normalize_providers(providers)
self.depth_size = depth_size
self.ema_beta = ema_beta
self._lock = threading.Lock()
self._session = None
self._input_names: List[str] = []
self._input_names: list[str] = []
self._rgb_dtype = np.dtype(np.float32)
self._depth_dtype = np.dtype(np.float32)
self._rgb_buffer: Optional[np.ndarray] = None
self._depth_buffer: Optional[np.ndarray] = None
self._depth_resize_buffer: Optional[np.ndarray] = None
self._rgb_buffer: np.ndarray | None = None
self._depth_buffer: np.ndarray | None = None
self._depth_resize_buffer: np.ndarray | None = None
self._ema_state: list[np.ndarray | None] | None = None

if not self.model_path.exists():
raise VisionPipelineError(f"Vision model not found: {self.model_path}")

@classmethod
def from_config(cls, config: SpotConfig) -> "VisionPipeline":
def from_config(cls, config: SpotConfig) -> VisionPipeline:
extra_params = config.extra_params or {}
model_path = (
config.vision_model_path
Expand All @@ -114,17 +116,21 @@ def from_config(cls, config: SpotConfig) -> "VisionPipeline":
if len(depth_size) != 2:
raise VisionPipelineError("vision_depth_size must contain height and width")

if config.ema_beta is not None and not (0.0 < config.ema_beta < 1.0):
raise ValueError(f"ema_beta must be in (0, 1), got: {config.ema_beta}")

return cls(
model_path=model_path,
providers=providers,
depth_size=(int(depth_size[0]), int(depth_size[1])),
ema_beta=config.ema_beta,
)

def run(
self,
rgb_images: List[np.ndarray],
depth_images: List[np.ndarray],
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
rgb_images: list[np.ndarray],
depth_images: list[np.ndarray],
) -> tuple[list[np.ndarray], list[np.ndarray]]:
if not rgb_images or not depth_images:
raise VisionPipelineError("Vision pipeline requires at least one RGB/depth pair")
if len(rgb_images) != len(depth_images):
Expand All @@ -145,7 +151,34 @@ def run(
},
)[0]

return rgb_images, _depth_list_from_output(output, len(rgb_images))
depth_maps = _depth_list_from_output(output, len(rgb_images))
if self.ema_beta is not None:
depth_maps = self._apply_ema(depth_maps)

return rgb_images, depth_maps

def _apply_ema(self, depth_maps: list[np.ndarray]) -> list[np.ndarray]:
"""Apply EMA smoothing per camera slot in the batch.

Each index in the batch corresponds to a fixed camera, so history is
tracked independently per slot. If the batch size changes the state
is reset.
"""
if self._ema_state is None or len(self._ema_state) != len(depth_maps):
self._ema_state = [None] * len(depth_maps)

result = []
for i, depth in enumerate(depth_maps):
depth_f32 = depth.astype(np.float32)
if self._ema_state[i] is None:
self._ema_state[i] = depth_f32
else:
self._ema_state[i] = (
self.ema_beta * self._ema_state[i] + (1.0 - self.ema_beta) * depth_f32
)
result.append(self._ema_state[i].copy())

return result

def _init_session(self) -> None:
if self._session is not None:
Expand Down Expand Up @@ -200,14 +233,10 @@ def _ensure_buffers(
rgb_shape = (batch_size, 3, h_rgb, w_rgb)
depth_shape = (batch_size, 1, depth_h, depth_w)

if self._rgb_buffer is None or self._rgb_buffer.shape != rgb_shape:
self._rgb_buffer = np.empty(rgb_shape, dtype=self._rgb_dtype)
elif self._rgb_buffer.dtype != self._rgb_dtype:
if self._rgb_buffer is None or self._rgb_buffer.shape != rgb_shape or self._rgb_buffer.dtype != self._rgb_dtype:
self._rgb_buffer = np.empty(rgb_shape, dtype=self._rgb_dtype)

if self._depth_buffer is None or self._depth_buffer.shape != depth_shape:
self._depth_buffer = np.empty(depth_shape, dtype=self._depth_dtype)
elif self._depth_buffer.dtype != self._depth_dtype:
if self._depth_buffer is None or self._depth_buffer.shape != depth_shape or self._depth_buffer.dtype != self._depth_dtype:
self._depth_buffer = np.empty(depth_shape, dtype=self._depth_dtype)

if self._depth_dtype == np.dtype(np.float16):
Expand Down Expand Up @@ -257,19 +286,19 @@ def _fill_buffers(
depth_dst[...] = resize_dst


_default_pipeline: Optional[VisionPipeline] = None
_default_pipeline_key: Optional[Tuple[str, Tuple[str, ...], Tuple[int, int]]] = None
_default_pipeline: VisionPipeline | None = None
_default_pipeline_key: tuple[str, tuple[str, ...], tuple[int, int]] | None = None
_default_pipeline_lock = threading.Lock()


def run_vision_pipeline(
rgb_images: List[np.ndarray],
depth_images: List[np.ndarray],
rgb_images: list[np.ndarray],
depth_images: list[np.ndarray],
*,
model_path: Optional[str] = None,
providers: Optional[Union[Sequence[str], str]] = None,
depth_size: Tuple[int, int] = DEFAULT_DEPTH_SIZE,
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
model_path: str | None = None,
providers: Sequence[str] | str | None = None,
depth_size: tuple[int, int] = DEFAULT_DEPTH_SIZE,
) -> tuple[list[np.ndarray], list[np.ndarray]]:
"""
Backwards-compatible functional entry point using a cached default pipeline.
"""
Expand Down
Loading