diff --git a/packages/device-connect-agent-tools/pyproject.toml b/packages/device-connect-agent-tools/pyproject.toml index da69ac7..9b603d0 100644 --- a/packages/device-connect-agent-tools/pyproject.toml +++ b/packages/device-connect-agent-tools/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "setuptools.build_meta" [project] name = "device-connect-agent-tools" -version = "0.2.3" +version = "0.2.4" description = "Framework-agnostic tools for Device Connect — discover and invoke IoT devices over NATS/Zenoh" readme = "README.md" requires-python = ">=3.11" diff --git a/packages/device-connect-edge/device_connect_edge/device.py b/packages/device-connect-edge/device_connect_edge/device.py index c776d1a..b4d52d9 100644 --- a/packages/device-connect-edge/device_connect_edge/device.py +++ b/packages/device-connect-edge/device_connect_edge/device.py @@ -61,6 +61,7 @@ async def capture_image(self, resolution: str = "1080p") -> dict: import json import logging import os +import random import re import time import uuid @@ -90,6 +91,44 @@ async def capture_image(self, resolution: str = "1080p") -> dict: logger = logging.getLogger(__name__) +def _env_float(name: str, default: float) -> float: + """Best-effort float env-var parser; falls back to default on garbage.""" + raw = os.getenv(name) + if raw is None or raw == "": + return default + try: + return float(raw) + except ValueError: + return default + + +# Registration knobs. At fleet scale a 2s request timeout combined with N +# phones starting in lockstep produces congestion collapse on the +# registry: every queued-but-late reply triggers a retry that re-enters +# the queue. A larger timeout lets the registry catch up before any +# retry fires, and an up-front jitter spreads the initial herd so the +# registry never sees a synchronized burst in the first place. Both are +# env-tunable (jitter can be disabled by setting it to 0). The 2s jitter +# default is a compromise: it decorrelates ~1000 devices into ~500/sec +# (much better than lockstep) while staying tolerable for single-device +# development. Operators at fleet scale should bump this via +# DEVICE_CONNECT_REGISTER_JITTER=10 (or higher) to spread the herd +# further. +# +# Lease-TTL interaction: the registry creates the etcd lease at the +# moment _do_register runs, so if a slow registry takes ~timeout +# seconds to reply, the lease can be near-expired before the heartbeat +# loop emits its first beat (`run()` awaits _register before starting +# the heartbeat task). With the 15s timeout default and the 15s `ttl` +# default that race is real; it self-heals — the next heartbeat fires +# `has_lease()=False` on the registry and triggers a requestRegistration +# round-trip — but operators raising DEVICE_CONNECT_REGISTER_TIMEOUT +# (or running a stressed registry where requests routinely take >ttl/3) +# should raise `ttl` in lockstep or shorten `heartbeat_interval` so +# the first beat lands inside the lease window. +_REGISTER_REQUEST_TIMEOUT = _env_float("DEVICE_CONNECT_REGISTER_TIMEOUT", 15.0) +_REGISTER_STARTUP_JITTER = _env_float("DEVICE_CONNECT_REGISTER_JITTER", 2.0) + def build_rpc_response(id_: str, result: Any) -> bytes: return json.dumps({"jsonrpc": "2.0", "id": id_, "result": result}).encode() @@ -974,6 +1013,19 @@ async def _register(self, force: bool = False) -> None: self._logger.debug("Registration completed by another task, skipping") return + # Spread the herd. With 1000+ phones spinning up in lockstep + # the registry sees a single synchronized burst that times + # out most callers and amplifies into a retry storm. A small + # randomized delay before the first request decorrelates the + # arrivals; subsequent retries already have exponential + # backoff so we only jitter once per _register call. + if _REGISTER_STARTUP_JITTER > 0: + jitter = random.uniform(0, _REGISTER_STARTUP_JITTER) + self._logger.debug( + "Pre-registration jitter: sleeping %.2fs before first request", jitter, + ) + await asyncio.sleep(jitter) + delay = 1 # initial retry delay in seconds while True: req_id = f"{self.device_id}-{int(time.time()*1000)}" @@ -983,7 +1035,7 @@ async def _register(self, force: bool = False) -> None: response_data = await self.messaging.request( f"device-connect.{self.tenant}.registry", json.dumps({"jsonrpc": "2.0", "id": req_id, "method": "registerDevice", "params": params}).encode(), - timeout=2, + timeout=_REGISTER_REQUEST_TIMEOUT, ) self._handle_registration_reply(response_data) # Note: device/online event is published by the registry service @@ -1762,7 +1814,10 @@ async def _setup_agentic_driver(self) -> None: if not isinstance(self._driver, DeviceDriver): return - self._logger.info("Setting up DeviceDriver D2D capabilities") + self._logger.info( + "Setting up DeviceDriver inter-device messaging " + "(router, registry, @on subscriptions)" + ) # Create and set D2D router (inline — no orchestration dependency). router = _RemoteInvoker( @@ -1796,7 +1851,11 @@ async def _setup_agentic_driver(self) -> None: # Set up event subscriptions await self._driver.setup_subscriptions() - self._logger.info("DeviceDriver D2D setup complete") + registry_kind = "D2DRegistry" if self._d2d_mode else "RegistryClient" + self._logger.info( + "DeviceDriver inter-device messaging ready (registry=%s)", + registry_kind, + ) async def _teardown_agentic_driver(self) -> None: """Teardown DeviceDriver subscriptions if applicable.""" @@ -1825,10 +1884,30 @@ async def _resubscribe_after_reconnect(self) -> None: Uses ``_subscription_lock`` to prevent concurrent invocations from rapid reconnects. """ - if not self._subscription_lock.acquire_nowait(): + # Review notes (do not re-litigate without reading these): + # + # 1. ``asyncio.Lock`` does NOT have ``acquire_nowait()``. That + # was a latent bug in the original implementation — the + # method only exists on ``threading.Lock``. At fleet scale + # during a reconnect storm it raised ``AttributeError`` on + # every reconnect and silently killed @on resubscription. + # See commit 1716f8d. + # + # 2. The ``locked() then await acquire()`` pattern below looks + # like a TOCTOU race but is safe under single-loop asyncio: + # ``Lock.locked()`` is synchronous and ``Lock.acquire()`` + # has a fast path that returns without yielding when the + # lock is free. Two concurrent callers cannot both observe + # ``locked() is False`` between the check and the take + # because there is no event-loop yield in that window. + # If you switch to a multi-loop primitive (anyio, trio, + # threading) this assumption breaks — use ``wait_for(..., + # timeout=0)`` over ``acquire()`` instead. + if self._subscription_lock.locked(): self._logger.debug("Subscription re-establishment already in progress, skipping") return + await self._subscription_lock.acquire() try: delay = 1 while True: diff --git a/packages/device-connect-edge/device_connect_edge/drivers/base.py b/packages/device-connect-edge/device_connect_edge/drivers/base.py index 73a596b..93d2a00 100644 --- a/packages/device-connect-edge/device_connect_edge/drivers/base.py +++ b/packages/device-connect-edge/device_connect_edge/drivers/base.py @@ -1049,16 +1049,47 @@ async def wait_for_device( def _collect_event_subscriptions(self) -> List[Dict[str, Any]]: """Collect all @on decorated methods. + Scans single-underscore-prefixed methods as well as public ones so + drivers can keep ``@on`` handlers conventionally private without + them silently becoming no-ops. Dunders are still skipped. + Returns: List of subscription definitions + + Review notes (do not re-litigate without reading): + - Skipping all ``_``-prefixed attrs (the original behavior) + silently dropped ``@on async def _on_foo`` handlers — Python + convention puts callbacks behind ``_`` and drivers expected + that to work. Fixed in 0673652. + - The ``_is_event_subscription`` marker check below is the + authoritative filter; the name prefix is *only* used to skip + dunders so we don't resolve descriptors like ``__class__``. """ subscriptions = [] + # We iterate ``dir(self)`` rather than ``__dict__`` so handlers + # inherited from a base class are still picked up. The trade-off + # is that ``getattr`` here will invoke ``@property`` descriptors, + # which may have side effects on driver subclasses (the @on + # decorator only marks methods, but properties live in the same + # namespace). We swallow exceptions from the resolve step so a + # broken / lazy property never breaks subscription setup for an + # unrelated handler. ``inspect.getattr_static`` would avoid this + # entirely but also bypasses descriptors we *do* want resolved + # (classmethod / staticmethod) -- so dynamic ``getattr`` plus a + # narrow try/except is the right balance here. for attr_name in dir(self): - if attr_name.startswith("_"): + if attr_name.startswith("__"): + continue + + try: + attr = getattr(self, attr_name, None) + except Exception: + # A property raised. Not a subscription candidate (the + # @on decorator marks methods, not descriptors) so skip + # silently rather than failing the whole driver. continue - attr = getattr(self, attr_name, None) if attr is None or not callable(attr): continue @@ -1225,6 +1256,24 @@ async def _setup_subscription(self, sub: Dict[str, Any]) -> None: logger.info("[%s] Subscribing to: %s", self_id, subject) + # device_type filtering relies on the D2D peer cache to resolve the + # source device's type. In portal/registry mode there is no peer + # cache, so the cache miss path passes the event through unfiltered. + # Warn once at setup so subscribers don't silently see events from + # other device types. Strict filtering can be added in-handler. + if ( + device_type + and not is_lifecycle + and getattr(self._device, "_d2d_collector", None) is None + ): + logger.warning( + "[%s] @on(device_type=%r) on %s: device_type filtering is " + "best-effort in registry/portal mode. The wildcard broker " + "subject delivers every device's matching event; add an " + "in-handler type check if you need strict filtering.", + self_id, device_type, subject, + ) + # Use subscribe_with_subject to get the matched subject in callback # This allows extracting device_id from wildcard subscriptions messaging_client = self._router._messaging diff --git a/packages/device-connect-edge/device_connect_edge/registry_client.py b/packages/device-connect-edge/device_connect_edge/registry_client.py index 91dad47..cabf789 100644 --- a/packages/device-connect-edge/device_connect_edge/registry_client.py +++ b/packages/device-connect-edge/device_connect_edge/registry_client.py @@ -30,15 +30,24 @@ import asyncio import json import logging +import os import time import uuid -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple from device_connect_edge.messaging.base import MessagingClient from device_connect_edge.messaging.exceptions import RequestTimeoutError logger = logging.getLogger(__name__) +# Per-page chunk size when the client transparently iterates the full fleet. +# Sized to keep one JSON-RPC reply well under the default NATS max_payload +# of 1 MB even when device records carry rich function schemas (~10 KB each +# in the worst case observed): 100 * ~10 KB = ~1 MB, with the actual upper +# bound for typical records (~6 KB) landing at ~600 KB. Operators on +# unusually rich schemas can drop this via DEVICE_CONNECT_LIST_PAGE_SIZE. +_DEFAULT_LIST_PAGE_SIZE = int(os.getenv("DEVICE_CONNECT_LIST_PAGE_SIZE", "100")) + class RegistryClient: """JSON-RPC client for the device registry service. @@ -157,8 +166,103 @@ async def list_devices( self._cache, device_type, location, capabilities, ) + # Page through the registry transparently so the wire never carries + # a fleet-sized reply (NATS default max_payload is 1 MB and was + # being exceeded at ~1400 devices). Older servers that don't + # understand ``limit`` just return everything in one reply with + # ``next_offset`` absent, so the loop exits after a single + # iteration — fully backward compatible. + devices: List[Dict[str, Any]] = [] + offset = 0 + while True: + page, next_offset, _total = await self._list_devices_page( + device_type=device_type, + location=location, + capabilities=capabilities, + offset=offset, + limit=_DEFAULT_LIST_PAGE_SIZE, + timeout=timeout, + ) + devices.extend(page) + if next_offset is None: + break + # Defense-in-depth: a buggy or future server returning a + # non-advancing cursor would loop forever otherwise. Break + # with a warning so a fleet-scale incident becomes a + # recoverable log line. + if next_offset <= offset: + logger.warning( + "Registry returned non-advancing next_offset=%s (current offset=%s); " + "stopping page walk to avoid infinite loop", + next_offset, offset, + ) + break + offset = next_offset + logger.debug("Discovered %d devices from registry", len(devices)) + + # Update cache (store unfiltered if we fetched without filters) + if ( + self._cache_ttl > 0 + and device_type is None + and location is None + and not capabilities + ): + self._cache = devices + self._cache_time = time.time() + + return devices + + async def list_devices_page( + self, + *, + offset: int = 0, + limit: int = _DEFAULT_LIST_PAGE_SIZE, + device_type: Optional[str] = None, + location: Optional[str] = None, + capabilities: Optional[List[str]] = None, + timeout: Optional[float] = None, + ) -> Tuple[List[Dict[str, Any]], Optional[int], int]: + """Fetch a single page of devices with pagination metadata. + + Use this when you want to display a paged UI or stream results; + most callers should stick with :meth:`list_devices`, which loops + internally and returns the full fleet. + + Returns: + ``(devices, next_offset, total_matched)`` where ``next_offset`` + is ``None`` on the final page. + + ACL caveat: + When the registry has ACLs enabled, server-side filtering + runs *after* slicing. As a result ``len(devices)`` for a + given page may be smaller than ``limit`` even when more + pages follow, and ``total_matched`` is the unfiltered total + (before the caller's ACL applies). UIs should treat + ``total_matched`` as an upper bound on what the caller will + ever see, and must not assume ``len(devices) == limit`` + implies a full page. + """ + return await self._list_devices_page( + device_type=device_type, + location=location, + capabilities=capabilities, + offset=offset, + limit=limit, + timeout=timeout, + ) + + async def _list_devices_page( + self, + *, + device_type: Optional[str], + location: Optional[str], + capabilities: Optional[List[str]], + offset: int, + limit: int, + timeout: Optional[float], + ) -> Tuple[List[Dict[str, Any]], Optional[int], int]: subject = f"device-connect.{self._tenant}.discovery" - params: Dict[str, Any] = {} + params: Dict[str, Any] = {"offset": int(offset), "limit": int(limit)} if device_type: params["device_type"] = device_type if location: @@ -167,20 +271,12 @@ async def list_devices( params["capabilities"] = capabilities result = await self._request( - subject, - "discovery/listDevices", - params if params else None, - timeout, + subject, "discovery/listDevices", params, timeout, ) devices = result.get("devices", []) - logger.debug("Discovered %d devices from registry", len(devices)) - - # Update cache (store unfiltered if we fetched without filters) - if self._cache_ttl > 0 and not params: - self._cache = devices - self._cache_time = time.time() - - return devices + next_offset = result.get("next_offset") + total = result.get("total_matched", len(devices)) + return devices, next_offset, total async def get_device( self, diff --git a/packages/device-connect-edge/pyproject.toml b/packages/device-connect-edge/pyproject.toml index 8a7a733..2b4a056 100644 --- a/packages/device-connect-edge/pyproject.toml +++ b/packages/device-connect-edge/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "setuptools.build_meta" [project] name = "device-connect-edge" -version = "0.2.3" +version = "0.2.4" description = "Device Connect Edge — lightweight edge device runtime with Zenoh/NATS messaging and D2D communication" readme = "README.md" requires-python = ">=3.11" diff --git a/packages/device-connect-edge/tests/test_device.py b/packages/device-connect-edge/tests/test_device.py index 14d03fa..a0fb1b8 100644 --- a/packages/device-connect-edge/tests/test_device.py +++ b/packages/device-connect-edge/tests/test_device.py @@ -621,3 +621,57 @@ async def test_request_registration_returns_payload(self): assert "identity" in result assert "status" in result assert "ts" in result["status"] + + +# ── Registration startup jitter ─────────────────────────────────── + +class TestRegisterStartupJitter: + """The pre-registration jitter exists to decorrelate ~1000 phones + that boot in lockstep. ``DEVICE_CONNECT_REGISTER_JITTER=0`` is the + documented escape hatch for single-device dev (no sleep, no random + call); the tests pin both branches of that gate.""" + + def _make_runtime(self): + rt = DeviceRuntime( + driver=StubDriver(), + device_id="cam-jit-1", + messaging_urls=["nats://localhost:4222"], + ) + rt.messaging = AsyncMock() + # _handle_registration_reply expects a valid reply; short-circuit. + rt._handle_registration_reply = lambda _data: None + rt.messaging.request = AsyncMock( + return_value=json.dumps({ + "jsonrpc": "2.0", "id": "x", + "result": {"registration_id": "r1", "device_ttl": 30}, + }).encode(), + ) + return rt + + @pytest.mark.asyncio + async def test_jitter_zero_skips_sleep_and_random(self): + rt = self._make_runtime() + with patch("device_connect_edge.device._REGISTER_STARTUP_JITTER", 0), \ + patch("device_connect_edge.device.random.uniform") as mock_uniform, \ + patch("device_connect_edge.device.asyncio.sleep", new_callable=AsyncMock) as mock_sleep: + await rt._register(force=True) + + # JITTER=0 must not call random.uniform at all (this is the + # contract for single-device dev / deterministic tests). + mock_uniform.assert_not_called() + # asyncio.sleep is only called from the retry path; since the + # registry replied OK on the first try, sleep must not fire. + mock_sleep.assert_not_called() + + @pytest.mark.asyncio + async def test_jitter_positive_sleeps_once_before_first_request(self): + rt = self._make_runtime() + with patch("device_connect_edge.device._REGISTER_STARTUP_JITTER", 4.0), \ + patch("device_connect_edge.device.random.uniform", return_value=1.23) as mock_uniform, \ + patch("device_connect_edge.device.asyncio.sleep", new_callable=AsyncMock) as mock_sleep: + await rt._register(force=True) + + mock_uniform.assert_called_once_with(0, 4.0) + # Exactly one sleep (the jitter) — registry reply succeeded on + # the first try so the retry-backoff sleep doesn't fire. + mock_sleep.assert_awaited_once_with(1.23) diff --git a/packages/device-connect-edge/tests/test_drivers.py b/packages/device-connect-edge/tests/test_drivers.py index 9099930..b17b4df 100644 --- a/packages/device-connect-edge/tests/test_drivers.py +++ b/packages/device-connect-edge/tests/test_drivers.py @@ -392,6 +392,114 @@ async def disconnect(self): subs = driver._collect_event_subscriptions() assert len(subs) == 2 + def test_underscore_prefixed_handler_is_still_collected(self): + """Single-underscore @on handlers must not silently become no-ops.""" + class MyDriver(DeviceDriver): + device_type = "test" + + @on(device_type="phone", event_name="state_changed") + async def _on_phone_state(self, device_id, event_name, payload): + pass + + async def connect(self): + pass + + async def disconnect(self): + pass + + driver = MyDriver() + subs = driver._collect_event_subscriptions() + assert len(subs) == 1 + assert subs[0]["device_type"] == "phone" + assert subs[0]["event_name"] == "state_changed" + + def test_collector_survives_raising_property(self): + """A driver subclass with a @property that raises must not break + subscription collection. ``dir()`` surfaces every attribute on + the class, and ``getattr`` will invoke descriptors — a buggy or + lazy-init property would otherwise crash setup_subscriptions for + unrelated handlers.""" + class MyDriver(DeviceDriver): + device_type = "test" + + @property + def _not_ready_yet(self): + # Simulates a property that depends on connect() having + # run, or a hardware probe that fails until init. + raise RuntimeError("not ready") + + @on(device_type="phone", event_name="state_changed") + async def on_phone_state(self, device_id, event_name, payload): + pass + + async def connect(self): + pass + + async def disconnect(self): + pass + + driver = MyDriver() + subs = driver._collect_event_subscriptions() + # Property side-effect was tolerated; the real handler still + # registered. + assert len(subs) == 1 + assert subs[0]["event_name"] == "state_changed" + + @pytest.mark.asyncio + async def test_portal_mode_device_type_filter_warns_at_setup(self, caplog): + """In portal/registry mode there is no D2D peer cache to resolve + a source device's type, so ``@on(device_type=...)`` filtering + silently passes events from other device types through. The + driver must emit a single setup-time WARNING so the subscriber + sees the gotcha once, not on every event.""" + + class MyDriver(DeviceDriver): + device_type = "test" + + @on(device_type="camera", event_name="motion") + async def on_motion(self, device_id, event_name, payload): + pass + + async def connect(self): + pass + + async def disconnect(self): + pass + + driver = MyDriver() + + mock_messaging = AsyncMock() + mock_messaging.subscribe_with_subject = AsyncMock(return_value=MagicMock()) + + class FakeRouter: + def __init__(self): + self._messaging = mock_messaging + self._tenant = "default" + + # Portal/registry mode: _device is set, but _d2d_collector is None. + # Use a plain object — MagicMock would auto-generate + # ``_is_event_subscription`` truthy values and leak phantom + # subscriptions into _collect_event_subscriptions. + class FakeDevice: + _d2d_collector = None + driver._device = FakeDevice() + driver._device_id = "watcher-1" + driver._router = FakeRouter() + + with caplog.at_level("WARNING", logger="device_connect_edge.drivers.base"): + await driver.setup_subscriptions() + + warnings = [r for r in caplog.records if r.levelname == "WARNING"] + matching = [ + r for r in warnings + if "device_type filtering is" in r.message + and "best-effort" in r.message + ] + assert len(matching) == 1, ( + f"expected exactly one portal-mode warning, got " + f"{[r.message for r in warnings]}" + ) + # ── setup_subscriptions error isolation ─────────────────────────── diff --git a/packages/device-connect-edge/tests/test_registry_client.py b/packages/device-connect-edge/tests/test_registry_client.py index b295924..66dec01 100644 --- a/packages/device-connect-edge/tests/test_registry_client.py +++ b/packages/device-connect-edge/tests/test_registry_client.py @@ -99,3 +99,180 @@ async def test_request_raises_after_all_retries_exhausted(self, mock_sleep): assert messaging.request.call_count == 3 + +class TestListDevicesPagination: + """Verify list_devices transparently pages through the registry.""" + + @staticmethod + def _paged_responses(total: int, page_size: int): + """Build the sequence of NATS reply bytes the server would emit.""" + devices = [{"device_id": f"dev-{i:04d}"} for i in range(total)] + responses = [] + for start in range(0, total, page_size): + page = devices[start:start + page_size] + end = start + page_size + next_offset = end if end < total else None + responses.append(json.dumps({ + "jsonrpc": "2.0", + "id": "rpc-test", + "result": { + "devices": page, + "next_offset": next_offset, + "total_matched": total, + }, + }).encode()) + if not responses: + # Empty fleet: still need one round-trip + responses.append(json.dumps({ + "jsonrpc": "2.0", + "id": "rpc-test", + "result": {"devices": [], "next_offset": None, "total_matched": 0}, + }).encode()) + return responses + + @pytest.mark.asyncio + async def test_list_devices_pages_through_full_fleet(self): + """1400 devices should arrive across multiple round-trips.""" + client, messaging = _make_client() + messaging.request = AsyncMock(side_effect=self._paged_responses(1400, 100)) + + devices = await client.list_devices() + + assert len(devices) == 1400 + assert [d["device_id"] for d in devices] == [ + f"dev-{i:04d}" for i in range(1400) + ] + # 1400 / 100 = 14 round-trips + assert messaging.request.call_count == 14 + + @pytest.mark.asyncio + async def test_list_devices_passes_offset_and_limit_in_params(self): + """Each request must carry the pagination params on the wire.""" + client, messaging = _make_client() + messaging.request = AsyncMock(side_effect=self._paged_responses(250, 100)) + + await client.list_devices() + + offsets = [] + limits = [] + for call_args in messaging.request.call_args_list: + payload = json.loads(call_args.args[1]) + offsets.append(payload["params"]["offset"]) + limits.append(payload["params"]["limit"]) + + assert offsets == [0, 100, 200] + assert all(lim == 100 for lim in limits) + + @pytest.mark.asyncio + async def test_list_devices_legacy_server_single_reply(self): + """Server without pagination (no next_offset) terminates after 1 call.""" + client, messaging = _make_client() + # Legacy reply shape: devices only, no pagination metadata. + legacy = json.dumps({ + "jsonrpc": "2.0", + "id": "rpc-test", + "result": {"devices": [{"device_id": "a"}, {"device_id": "b"}]}, + }).encode() + messaging.request = AsyncMock(return_value=legacy) + + devices = await client.list_devices() + + assert len(devices) == 2 + # next_offset absent => loop exits after one request + assert messaging.request.call_count == 1 + + @pytest.mark.asyncio + async def test_list_devices_page_returns_metadata(self): + """list_devices_page exposes next_offset and total_matched to caller.""" + client, messaging = _make_client() + reply = json.dumps({ + "jsonrpc": "2.0", + "id": "rpc-test", + "result": { + "devices": [{"device_id": "a"}, {"device_id": "b"}], + "next_offset": 2, + "total_matched": 10, + }, + }).encode() + messaging.request = AsyncMock(return_value=reply) + + page, next_offset, total = await client.list_devices_page( + offset=0, limit=2, + ) + + assert len(page) == 2 + assert next_offset == 2 + assert total == 10 + + @pytest.mark.asyncio + async def test_list_devices_forwards_filters(self): + """device_type / location filters must accompany pagination params.""" + client, messaging = _make_client() + messaging.request = AsyncMock(side_effect=self._paged_responses(0, 100)) + + await client.list_devices(device_type="camera", location="lab-A") + + payload = json.loads(messaging.request.call_args.args[1]) + assert payload["params"]["device_type"] == "camera" + assert payload["params"]["location"] == "lab-A" + assert payload["params"]["offset"] == 0 + assert payload["params"]["limit"] == 100 + + @pytest.mark.asyncio + async def test_list_devices_handles_empty_page_with_next_offset(self): + """ACL filtering can yield an empty page mid-walk with next_offset + still pointing forward; the loop must advance, not stall.""" + client, messaging = _make_client() + responses = [ + # Page 0: ACL filtered everything out, but more pages follow. + json.dumps({ + "jsonrpc": "2.0", + "id": "rpc-test", + "result": {"devices": [], "next_offset": 100, "total_matched": 200}, + }).encode(), + # Page 1: some visible devices, final page. + json.dumps({ + "jsonrpc": "2.0", + "id": "rpc-test", + "result": { + "devices": [{"device_id": "visible-1"}], + "next_offset": None, + "total_matched": 200, + }, + }).encode(), + ] + messaging.request = AsyncMock(side_effect=responses) + + devices = await client.list_devices() + + assert [d["device_id"] for d in devices] == ["visible-1"] + assert messaging.request.call_count == 2 + # Second request must use next_offset from the first reply. + second_payload = json.loads(messaging.request.call_args_list[1].args[1]) + assert second_payload["params"]["offset"] == 100 + + @pytest.mark.asyncio + async def test_list_devices_breaks_on_non_advancing_next_offset(self, caplog): + """A buggy server returning next_offset <= current offset must not + spin the client forever — the page loop bails with a warning.""" + client, messaging = _make_client() + # Server bug: keeps returning the same offset. + repeating = json.dumps({ + "jsonrpc": "2.0", + "id": "rpc-test", + "result": { + "devices": [{"device_id": "a"}], + "next_offset": 0, + "total_matched": 100, + }, + }).encode() + messaging.request = AsyncMock(return_value=repeating) + + with caplog.at_level("WARNING"): + devices = await client.list_devices() + + assert len(devices) == 1 + assert messaging.request.call_count == 1 + assert any( + "non-advancing next_offset" in rec.message for rec in caplog.records + ) diff --git a/packages/device-connect-server/device_connect_server/portal/app.py b/packages/device-connect-server/device_connect_server/portal/app.py index 5625140..d68cacb 100644 --- a/packages/device-connect-server/device_connect_server/portal/app.py +++ b/packages/device-connect-server/device_connect_server/portal/app.py @@ -73,13 +73,26 @@ async def auth_middleware(request: web.Request, handler): session = await _get_session(request) if not session.get("username"): # Preserve the requested URL so post-login redirect lands the user - # back on (e.g.) the CLI approval page. - next_url = path - if request.query_string: - next_url = f"{path}?{request.query_string}" - from urllib.parse import quote - login_url = "/login?next=" + quote(next_url, safe="") if path != "/login" else "/login" - if request.headers.get("HX-Request"): + # back on (e.g.) the CLI approval page — but only for top-level + # HTML navigations. Background htmx polls and JSON fetches under + # /api/ return HTML fragments or JSON, not full pages, so using + # them as the post-login destination dumps the user onto a + # chrome-less fragment. The dashboard's 10s poll on + # /api/devices/live was the original repro: portal restart -> + # session lost -> next poll redirected to /login with the poll + # URL as ``next`` -> after login the user landed on the raw + # fragment instead of the dashboard. + is_htmx = request.headers.get("HX-Request") == "true" + is_api = path.startswith("/api/") + if is_htmx or is_api: + login_url = "/login" + else: + next_url = path + if request.query_string: + next_url = f"{path}?{request.query_string}" + from urllib.parse import quote + login_url = "/login?next=" + quote(next_url, safe="") if path != "/login" else "/login" + if is_htmx: resp = web.Response(status=200) resp.headers["HX-Redirect"] = login_url return resp @@ -179,6 +192,10 @@ def create_app() -> web.Application: # Seed admin on startup app.on_startup.append(_on_startup) + # Close the cached NATS invoke client on shutdown. Without this the + # socket leaks at graceful exit because the client is module-level + # state in nats_rpc, not tied to the aiohttp Application lifecycle. + app.on_cleanup.append(_on_cleanup) return app @@ -190,3 +207,12 @@ async def _on_startup(app: web.Application): ensure_admin() except Exception as e: logger.warning("Could not seed admin account (etcd may not be ready): %s", e) + + +async def _on_cleanup(app: web.Application): + """Release long-lived resources held at module scope.""" + try: + from .services.nats_rpc import close_invoke_client + await close_invoke_client() + except Exception as e: + logger.warning("Error closing cached NATS invoke client: %s", e) diff --git a/packages/device-connect-server/device_connect_server/portal/services/credentials.py b/packages/device-connect-server/device_connect_server/portal/services/credentials.py index 58a7687..6fb118e 100644 --- a/packages/device-connect-server/device_connect_server/portal/services/credentials.py +++ b/packages/device-connect-server/device_connect_server/portal/services/credentials.py @@ -64,6 +64,24 @@ def get_credential_data(filename: str) -> dict | None: return None +def delete_credential(filename: str) -> bool: + """Remove a credential file from disk. + + Returns True if a file was deleted, False if no such file existed. + Uses the same path-traversal guard as :func:`get_credential`, so a + crafted ``filename`` that resolves outside ``CREDS_DIR`` is rejected. + """ + path = get_credential(filename) + if not path: + return False + try: + path.unlink() + return True + except OSError: + logger.exception("failed to remove credential %s", filename) + return False + + def get_tenants_summary() -> dict[str, dict]: """Get a summary of all tenants and their device counts. diff --git a/packages/device-connect-server/device_connect_server/portal/services/nats_rpc.py b/packages/device-connect-server/device_connect_server/portal/services/nats_rpc.py index 3c03ce0..82a61b3 100644 --- a/packages/device-connect-server/device_connect_server/portal/services/nats_rpc.py +++ b/packages/device-connect-server/device_connect_server/portal/services/nats_rpc.py @@ -4,8 +4,10 @@ """NATS helpers: RPC invocation and event streaming.""" +import asyncio import json import logging +import time import uuid from pathlib import Path @@ -18,6 +20,56 @@ # Registry credentials (privileged, can reach all tenants) _REGISTRY_CREDS = Path(config.CREDS_DIR) / "registry.creds.json" +# Long-lived client reused across all invoke() calls. The portal used to +# open and close a fresh NATS connection per RPC, which added a TCP + +# JWT-auth handshake to every dashboard "Run" click. The connection is +# concurrent-safe (each nc.request creates its own inbox subscription) +# so a single cached client serves the whole portal. +_invoke_client: "nats.aio.client.Client | None" = None + +# Exception types that mean "the cached NATS client is no longer usable" +# — i.e. the next request must reconnect. We deliberately do NOT include +# every nats.errors.Error subclass: BadSubjectError, MaxPayloadError, +# AuthorizationError etc. are caller / payload bugs that don't kill the +# connection, so dropping the client on them would churn the socket on +# every malformed request. Native OSError / ConnectionError covers +# socket-level failures the NATS client may not have wrapped yet. +# +# Review notes (do not re-litigate without reading these): +# - ``ConnectionReconnectingError`` is intentionally absent: it means the +# client is *already* reconnecting itself. Dropping + close()-ing in +# that state preempts the nats-py reconnect machinery, forces a fresh +# handshake on every queued request, and amplifies broker flaps. Let +# the existing client recover; the next ``nc.request`` either succeeds +# post-reconnect or raises something more terminal that *is* in this +# set. Past review round suggested adding it -- don't. +# - ``ProtocolError`` and ``NoRespondersError`` are payload-level signals +# over a healthy socket; covered by their own branches / left to the +# default handler without dropping the client. See ``test_nats_rpc``. +_TRANSPORT_FATAL_ERRORS: tuple = ( + nats.errors.ConnectionClosedError, + nats.errors.ConnectionDrainingError, + nats.errors.StaleConnectionError, + nats.errors.NoServersError, + nats.errors.OutboundBufferLimitError, + nats.errors.SecureConnFailedError, + ConnectionError, + OSError, +) +# Lock is created lazily inside _get_invoke_lock() rather than at import +# time. asyncio.Lock() binds to whatever event loop is current when it's +# constructed; constructing it here would break tests (and any future +# code) that runs this module under a fresh loop. +_invoke_client_lock: "asyncio.Lock | None" = None + + +def _get_invoke_lock() -> asyncio.Lock: + """Return the module-level invoke lock, creating it on first use.""" + global _invoke_client_lock + if _invoke_client_lock is None: + _invoke_client_lock = asyncio.Lock() + return _invoke_client_lock + def _load_creds() -> dict: """Load registry credentials for NATS auth.""" @@ -27,6 +79,47 @@ def _load_creds() -> dict: return {} +async def _get_invoke_client(): + """Lazily open and cache a single NATS client for RPC invocations.""" + global _invoke_client + async with _get_invoke_lock(): + if _invoke_client is None or _invoke_client.is_closed: + _invoke_client = await connect() + logger.info("invoke client connected; will be reused across requests") + return _invoke_client + + +async def _drop_invoke_client() -> None: + """Discard the cached client, best-effort closing whatever's there. + + Called after a hard transport failure so the next invoke() reconnects + rather than reusing a half-dead client. The ``close()`` is wrapped in + a broad try/except because the connection is already known to be in + a bad state — we just want to release sockets if we can. + """ + global _invoke_client + async with _get_invoke_lock(): + stale = _invoke_client + _invoke_client = None + if stale is not None: + try: + await stale.close() + except Exception: + logger.debug("ignored error closing stale invoke client", exc_info=True) + + +async def close_invoke_client() -> None: + """Close the cached invoke client at app shutdown. + + Wire this into ``aiohttp.web.Application.on_cleanup``: without it the + long-lived socket leaks on graceful shutdown (the cached client is + module-level state, not tied to the app's lifecycle). Idempotent — + calling twice is a no-op because ``_drop_invoke_client`` nils the + global first. + """ + await _drop_invoke_client() + + async def connect(): """Return a connected NATS client using registry credentials.""" creds = _load_creds() @@ -55,23 +148,42 @@ def _sign(nonce): async def invoke(tenant: str, device_id: str, function: str, params: dict, timeout: float = 5.0) -> dict: """Send a JSON-RPC request to a device and return the response.""" - nc = await connect() + t0 = time.monotonic() + subject = f"device-connect.{tenant}.{device_id}.cmd" + payload = { + "jsonrpc": "2.0", + "id": str(uuid.uuid4()), + "method": function, + "params": params, + } try: - subject = f"device-connect.{tenant}.{device_id}.cmd" - payload = { - "jsonrpc": "2.0", - "id": str(uuid.uuid4()), - "method": function, - "params": params, - } - + nc = await _get_invoke_client() msg = await nc.request(subject, json.dumps(payload).encode(), timeout=timeout) + logger.info( + "invoke %s/%s.%s ok in %.1fms", + tenant, device_id, function, (time.monotonic() - t0) * 1000, + ) return json.loads(msg.data) except nats.errors.NoRespondersError: + logger.warning( + "invoke %s/%s.%s no-responders in %.1fms", + tenant, device_id, function, (time.monotonic() - t0) * 1000, + ) return {"error": {"code": -1, "message": f"Device {device_id} is not responding"}} except nats.errors.TimeoutError: + logger.warning( + "invoke %s/%s.%s timeout in %.1fms", + tenant, device_id, function, (time.monotonic() - t0) * 1000, + ) return {"error": {"code": -2, "message": f"Request timed out after {timeout}s"}} except Exception as e: + # Only drop the cached client on transport-level failures so a + # payload / programmer bug (BadSubject, MaxPayload, KeyError in + # our own code, ...) doesn't churn the connection on every call. + if isinstance(e, _TRANSPORT_FATAL_ERRORS): + await _drop_invoke_client() + logger.exception( + "invoke %s/%s.%s error in %.1fms: %s", + tenant, device_id, function, (time.monotonic() - t0) * 1000, e, + ) return {"error": {"code": -3, "message": str(e)}} - finally: - await nc.close() diff --git a/packages/device-connect-server/device_connect_server/portal/services/registry_client.py b/packages/device-connect-server/device_connect_server/portal/services/registry_client.py index 4599420..2789567 100644 --- a/packages/device-connect-server/device_connect_server/portal/services/registry_client.py +++ b/packages/device-connect-server/device_connect_server/portal/services/registry_client.py @@ -37,6 +37,28 @@ def _etcd_client(): return Etcd3Client(host=config.ETCD_HOST, port=config.ETCD_PORT) +def format_live_device(data: dict) -> dict: + """Shape a raw etcd device record into the dashboard's row dict. + + Shared between list_live_devices (table render) and the per-device + row-html endpoint (used by the dashboard JSON poll when a brand new + device appears mid-session). Keeping the formatting in one place + means the appended row matches the initial server-rendered rows. + """ + status = data.get("status") or {} + identity = data.get("identity") or {} + reg = data.get("registry") or {} + return { + "device_id": data.get("device_id", "unknown"), + "device_type": identity.get("device_type", "unknown"), + "status": status.get("availability", "unknown"), + "location": status.get("location", ""), + "last_seen": _format_ts(status.get("ts")) or reg.get("registered_at", ""), + "capabilities": data.get("capabilities", {}), + "_raw": data, + } + + def list_live_devices(tenant: str) -> list[dict]: """Query etcd for all registered devices in a tenant namespace. @@ -47,25 +69,14 @@ def list_live_devices(tenant: str) -> list[dict]: results = client.get_prefix(prefix) devices = [] - for raw, meta in results: + for raw, _meta in results: try: if isinstance(raw, bytes): raw = raw.decode() data = json.loads(raw) - status = data.get("status") or {} - identity = data.get("identity") or {} - reg = data.get("registry") or {} - devices.append({ - "device_id": data.get("device_id", "unknown"), - "device_type": identity.get("device_type", "unknown"), - "status": status.get("availability", "unknown"), - "location": status.get("location", ""), - "last_seen": _format_ts(status.get("ts")) or reg.get("registered_at", ""), - "capabilities": data.get("capabilities", {}), - "_raw": data, - }) except (json.JSONDecodeError, TypeError): continue + devices.append(format_live_device(data)) return devices diff --git a/packages/device-connect-server/device_connect_server/portal/templates/admin/tenant_detail.html b/packages/device-connect-server/device_connect_server/portal/templates/admin/tenant_detail.html index eb71d7d..878d80d 100644 --- a/packages/device-connect-server/device_connect_server/portal/templates/admin/tenant_detail.html +++ b/packages/device-connect-server/device_connect_server/portal/templates/admin/tenant_detail.html @@ -11,15 +11,15 @@

{{ viewing_as }}'s Dashboard

Credentials Created
-
{{ creds_count }}
+
{{ creds_count }}
Devices Online
-
{{ online_count }}
+
{{ online_count }}
Devices Registered
-
{{ registered_count }}
+
{{ registered_count }}
@@ -51,9 +51,12 @@

Live Devices

Auto-refreshing + {# Initial table render only — htmx fires `load` once and the + fragment lands here. From then on, /api/devices/live.json drives + in-place cell updates. See dashboard.html for the rationale. #}
Loading devices...
@@ -61,7 +64,33 @@

Live Devices