diff --git a/README.md b/README.md index 6c6fb89..278642f 100644 --- a/README.md +++ b/README.md @@ -158,7 +158,9 @@ If your Elgato app lives somewhere other than `/Applications/Elgato Stream Deck. The original USB-direct server is preserved for backwards compatibility — useful when you'd rather have the MCP server own the hardware directly (Linux, headless setups, or environments where the Elgato app isn't running). It exposes a different tool surface focused on direct hardware control: -`streamdeck_connect`, `streamdeck_info`, `streamdeck_set_button`, `streamdeck_set_buttons`, `streamdeck_clear_button`, `streamdeck_get_button`, `streamdeck_clear_all`, `streamdeck_set_brightness`, `streamdeck_create_page`, `streamdeck_switch_page`, `streamdeck_list_pages`, `streamdeck_delete_page`, `streamdeck_disconnect`. +`streamdeck_list_devices`, `streamdeck_connect`, `streamdeck_info`, `streamdeck_set_button`, `streamdeck_set_buttons`, `streamdeck_clear_button`, `streamdeck_get_button`, `streamdeck_clear_all`, `streamdeck_set_brightness`, `streamdeck_create_page`, `streamdeck_switch_page`, `streamdeck_list_pages`, `streamdeck_delete_page`, `streamdeck_disconnect`. + +When multiple decks are attached, call `streamdeck_list_devices` first, then pass the desired `serial` to `streamdeck_connect`. Omitting `serial` preserves the legacy behavior of opening the first enumerated deck. Run via: diff --git a/server.py b/server.py index 731f5b9..8ddfe85 100644 --- a/server.py +++ b/server.py @@ -258,9 +258,72 @@ def _check_deck_connected(self) -> None: "Stream Deck disconnected. Reconnect with streamdeck_connect." ) - def connect(self) -> dict[str, Any]: + def _enumerate_decks(self) -> list[Any]: + """Enumerate attached Stream Deck devices.""" + try: + return DeviceManager().enumerate() + except Exception as e: + logger.error(f"Failed to enumerate devices: {e}") + raise StreamDeckError(f"Failed to scan for Stream Deck devices: {e}") + + def _read_deck_serial(self, deck: Any) -> tuple[str, bool]: + """ + Read a deck serial, opening the device if needed. + + Returns: + Tuple of serial and whether this call opened the device. + """ + opened_here = False + try: + if not deck.is_open(): + deck.open() + opened_here = True + return deck.get_serial_number(), opened_here + except Exception as e: + if opened_here: + try: + deck.close() + except Exception as close_error: + logger.warning( + f"Failed to close Stream Deck after serial read error: {close_error}" + ) + raise StreamDeckError(f"Failed to read Stream Deck serial: {e}") + + def list_devices(self) -> list[dict[str, Any]]: + """ + List attached Stream Deck devices. + + Returns: + List of discovered deck metadata + + Raises: + StreamDeckError: If enumeration fails """ - Connect to the first available Stream Deck. + if not HAS_STREAMDECK: + raise StreamDeckError( + "streamdeck library not installed. Run: pip install streamdeck pillow" + ) + + devices = [] + for deck in self._enumerate_decks(): + serial, opened_here = self._read_deck_serial(deck) + try: + devices.append( + { + "serial": serial, + "deck_type": deck.deck_type(), + "key_count": deck.key_count(), + } + ) + finally: + if opened_here: + deck.close() + + return devices + + def connect(self, serial: str | None = None) -> dict[str, Any]: + """ + Connect to an available Stream Deck. Returns: Dict with connection result and deck info @@ -279,18 +342,38 @@ def connect(self) -> dict[str, Any]: time.sleep(RECONNECT_DELAY_BASE) self._last_connect_attempt = now - try: - decks = DeviceManager().enumerate() - except Exception as e: - logger.error(f"Failed to enumerate devices: {e}") - raise StreamDeckError(f"Failed to scan for Stream Deck devices: {e}") - + decks = self._enumerate_decks() if not decks: raise StreamDeckError("No Stream Deck found. Check USB connection and permissions.") + selected_deck = decks[0] + selected_is_open = False + + if serial is not None: + selected_deck = None + available_serials = [] + + for deck in decks: + deck_serial, opened_here = self._read_deck_serial(deck) + available_serials.append(deck_serial) + + if deck_serial == serial: + selected_deck = deck + selected_is_open = opened_here or deck.is_open() + break + + if opened_here: + deck.close() + + if selected_deck is None: + raise StreamDeckError( + f"No Stream Deck with serial {serial!r}. Available serials: {available_serials}" + ) + try: - self.deck = decks[0] - self.deck.open() + self.deck = selected_deck + if not selected_is_open: + self.deck.open() self.deck.reset() self.deck.set_brightness(self._brightness) self.deck.set_key_callback(self._key_callback) @@ -308,6 +391,13 @@ def connect(self) -> dict[str, Any]: except Exception as e: self._connect_attempts += 1 logger.error(f"Connection attempt {self._connect_attempts} failed: {e}") + if self.deck: + try: + self.deck.close() + except Exception as close_error: + logger.warning( + f"Failed to close Stream Deck after connection error: {close_error}" + ) self.deck = None if self._connect_attempts >= MAX_RECONNECT_ATTEMPTS: @@ -853,7 +943,23 @@ async def list_tools() -> list[Tool]: return [ Tool( name="streamdeck_connect", - description="Connect to a Stream Deck device. Call this first before other operations.", + description=( + "Connect to a Stream Deck device. Call this first before other operations. " + "Use serial to choose a specific deck when multiple are attached." + ), + inputSchema={ + "type": "object", + "properties": { + "serial": { + "type": "string", + "description": "Optional Stream Deck serial number to connect to", + }, + }, + }, + ), + Tool( + name="streamdeck_list_devices", + description="List attached Stream Deck devices without requiring an active connection", inputSchema={ "type": "object", "properties": {}, @@ -1099,7 +1205,7 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: ) ] - info = state.connect() + info = state.connect(serial=arguments.get("serial")) return [ TextContent( type="text", @@ -1110,6 +1216,21 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: ) ] + elif name == "streamdeck_list_devices": + if not HAS_STREAMDECK: + return [ + TextContent( + type="text", + text=( + "❌ streamdeck library not installed. " + "Run: pip install streamdeck pillow" + ), + ) + ] + + devices = state.list_devices() + return [TextContent(type="text", text=json.dumps(devices, indent=2))] + elif name == "streamdeck_info": info = state.get_deck_info() return [TextContent(type="text", text=json.dumps(info, indent=2))] diff --git a/tests/test_server.py b/tests/test_server.py index f677869..f607061 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -24,7 +24,9 @@ from server import ( # noqa: E402,I001 DeckNotConnectedError, StreamDeckState, + StreamDeckError, ValidationError, + list_tools, subprocess as server_subprocess, ) @@ -32,6 +34,25 @@ class TestStreamDeckState: """Tests for StreamDeckState class.""" + def _mock_deck( + self, + *, + serial: str, + deck_type: str = "Stream Deck Original", + key_count: int = 15, + is_open: bool = False, + ) -> MagicMock: + deck = MagicMock() + deck.deck_type.return_value = deck_type + deck.key_count.return_value = key_count + deck.get_serial_number.return_value = serial + deck.is_open.return_value = is_open + deck.key_layout.return_value = (3, 5) + deck.id.return_value = f"id-{serial}" + deck.get_firmware_version.return_value = "1.0.0" + deck.key_image_format.return_value = {"size": (72, 72), "format": "JPEG"} + return deck + @pytest.fixture def temp_config_dir(self, tmp_path: Path): """Create a temporary config directory.""" @@ -230,6 +251,97 @@ def test_get_deck_info_not_connected(self, state: StreamDeckState): info = state.get_deck_info() assert info["connected"] is False + def test_list_devices_returns_serial_type_and_key_count(self, state: StreamDeckState): + """Should enumerate connected deck metadata without keeping devices open.""" + original = self._mock_deck(serial="ORIGINAL123", deck_type="Stream Deck Original") + plus = self._mock_deck(serial="PLUS456", deck_type="Stream Deck Plus", key_count=8) + + with patch("server.DeviceManager") as mock_device_manager: + mock_device_manager.return_value.enumerate.return_value = [original, plus] + + devices = state.list_devices() + + assert devices == [ + { + "serial": "ORIGINAL123", + "deck_type": "Stream Deck Original", + "key_count": 15, + }, + { + "serial": "PLUS456", + "deck_type": "Stream Deck Plus", + "key_count": 8, + }, + ] + original.open.assert_called_once_with() + original.close.assert_called_once_with() + plus.open.assert_called_once_with() + plus.close.assert_called_once_with() + + def test_connect_without_serial_opens_first_deck(self, state: StreamDeckState): + """Default connect behavior should preserve the first enumerated deck choice.""" + original = self._mock_deck(serial="ORIGINAL123") + plus = self._mock_deck(serial="PLUS456", deck_type="Stream Deck Plus", key_count=8) + + with patch("server.DeviceManager") as mock_device_manager: + mock_device_manager.return_value.enumerate.return_value = [original, plus] + with patch.object(state, "_render_current_page"): + info = state.connect() + + assert state.deck is original + assert info["serial"] == "ORIGINAL123" + original.open.assert_called_once_with() + plus.open.assert_not_called() + + def test_connect_with_serial_opens_matching_deck(self, state: StreamDeckState): + """A provided serial should select that device from the enumerated decks.""" + original = self._mock_deck(serial="ORIGINAL123") + plus = self._mock_deck(serial="PLUS456", deck_type="Stream Deck Plus", key_count=8) + + with patch("server.DeviceManager") as mock_device_manager: + mock_device_manager.return_value.enumerate.return_value = [original, plus] + with patch.object(state, "_render_current_page"): + info = state.connect(serial="PLUS456") + + assert state.deck is plus + assert info["serial"] == "PLUS456" + original.open.assert_called_once_with() + original.close.assert_called_once_with() + plus.open.assert_called_once_with() + plus.close.assert_not_called() + + def test_connect_with_unknown_serial_lists_available_serials(self, state: StreamDeckState): + """Unknown serial errors should include available serials for self-correction.""" + original = self._mock_deck(serial="ORIGINAL123") + plus = self._mock_deck(serial="PLUS456", deck_type="Stream Deck Plus", key_count=8) + + with patch("server.DeviceManager") as mock_device_manager: + mock_device_manager.return_value.enumerate.return_value = [original, plus] + with pytest.raises(StreamDeckError, match="MISSING789") as exc_info: + state.connect(serial="MISSING789") + + message = str(exc_info.value) + assert "ORIGINAL123" in message + assert "PLUS456" in message + assert state.deck is None + original.close.assert_called_once_with() + plus.close.assert_called_once_with() + + def test_connect_with_serial_closes_match_when_setup_fails(self, state: StreamDeckState): + """A serial-selected deck opened during probing should be closed on setup failure.""" + original = self._mock_deck(serial="ORIGINAL123") + plus = self._mock_deck(serial="PLUS456", deck_type="Stream Deck Plus", key_count=8) + plus.reset.side_effect = RuntimeError("reset failed") + + with patch("server.DeviceManager") as mock_device_manager: + mock_device_manager.return_value.enumerate.return_value = [original, plus] + with pytest.raises(StreamDeckError, match="reset failed"): + state.connect(serial="PLUS456") + + assert state.deck is None + original.close.assert_called_once_with() + plus.close.assert_called_once_with() + # ======================================================================== # Button Action Tests # ======================================================================== @@ -308,6 +420,18 @@ async def test_tool_error_handling(self): # Left as placeholder for integration tests pass + @pytest.mark.asyncio + async def test_tool_schema_includes_device_listing_and_connect_serial(self): + """MCP schema should expose device listing and optional serial connect.""" + tools = await list_tools() + tools_by_name = {tool.name: tool for tool in tools} + + assert "streamdeck_list_devices" in tools_by_name + + connect_schema = tools_by_name["streamdeck_connect"].inputSchema + assert "serial" in connect_schema["properties"] + assert "serial" not in connect_schema.get("required", []) + # Run with: pytest tests/test_server.py -v if __name__ == "__main__":