diff --git a/README.md b/README.md index 88fadc3..d48e1a9 100644 --- a/README.md +++ b/README.md @@ -165,6 +165,14 @@ client = PolymarketUS( > **Note**: WebSocket connections are async-only due to their event-driven nature. > Use `asyncio.run()` when working with the sync client, or use `AsyncPolymarketUS` directly. +> **Reconnection**: connections automatically reconnect with exponential backoff +> on unexpected drops, re-sign the auth handshake, and replay every active +> subscription. A `reconnect` event fires after a successful reconnect. Reconnect +> stops on fatal auth failures (401/403/429). Disable with `auto_reconnect=False`. +> Note that `order`, `position`, and `trade` streams do not replay history on +> reconnect; resubscribe to `SUBSCRIPTION_TYPE_ORDER_SNAPSHOT` if you need current +> open orders, while market data and account balance snapshots are sent automatically. + ```python import asyncio import os @@ -297,6 +305,7 @@ WebSocket methods (`connect()`, `subscribe()`, `close()`) are async and must be - `account_balance_snapshot` - Initial balance - `account_balance_update` - Balance changes - `heartbeat` - Connection keepalive +- `reconnect` - Reconnected and resubscribed after a drop - `error` - Error events - `close` - Connection closed @@ -305,6 +314,7 @@ WebSocket methods (`connect()`, `subscribe()`, `close()`) are async and must be - `market_data_lite` - Lightweight price data - `trade` - Trade notifications - `heartbeat` - Connection keepalive +- `reconnect` - Reconnected and resubscribed after a drop - `error` - Error events - `close` - Connection closed diff --git a/polymarket_us/websocket/base.py b/polymarket_us/websocket/base.py index b2f3db1..42c7f7d 100644 --- a/polymarket_us/websocket/base.py +++ b/polymarket_us/websocket/base.py @@ -1,10 +1,11 @@ -"""Base WebSocket class.""" +"""Base WebSocket class with automatic reconnect and resubscribe.""" from __future__ import annotations import asyncio import contextlib import json +import random from collections.abc import Callable from typing import Any @@ -16,9 +17,49 @@ from .types import MarketSubscriptionType, PrivateSubscriptionType +# WebSocket upgrade failures with these statuses are fatal (bad credentials or +# rate limiting) and must not trigger reconnect attempts. +_FATAL_AUTH_STATUSES = frozenset({401, 403, 429}) + +_RECONNECT_INITIAL_SECONDS = 0.5 +_RECONNECT_MAX_SECONDS = 30.0 + + +def _reconnect_delay(attempt: int) -> float: + """Exponential reconnect backoff with equal jitter (attempt is 0-indexed).""" + capped = min(_RECONNECT_INITIAL_SECONDS * (2**attempt), _RECONNECT_MAX_SECONDS) + return capped / 2 + random.random() * (capped / 2) + + +def _upgrade_status(exc: Exception) -> int | None: + """Extract the HTTP status from a failed WebSocket upgrade, if available. + + Handles both the modern (``exc.response.status_code``) and legacy + (``exc.status_code``) ``websockets`` exception shapes. + """ + response = getattr(exc, "response", None) + if response is not None: + status = getattr(response, "status_code", None) + if isinstance(status, int): + return status + status = getattr(exc, "status_code", None) + return status if isinstance(status, int) else None + + +class _Subscription: + """A subscription the client should replay after a reconnect.""" + + def __init__( + self, + subscription_type: PrivateSubscriptionType | MarketSubscriptionType, + market_slugs: list[str] | None, + ) -> None: + self.subscription_type = subscription_type + self.market_slugs = market_slugs + class BaseWebSocket: - """Base WebSocket class with event emitter pattern.""" + """Base WebSocket class with an event emitter and resilient connection.""" def __init__( self, @@ -27,6 +68,8 @@ def __init__( secret_key: str, base_url: str = "wss://api.polymarket.us", path: str, + auto_reconnect: bool = True, + reconnect_max_attempts: int | None = None, ) -> None: """Initialize WebSocket. @@ -35,40 +78,117 @@ def __init__( secret_key: Base64-encoded Ed25519 secret key base_url: WebSocket base URL path: WebSocket endpoint path + auto_reconnect: Reconnect and replay subscriptions on unexpected drops + reconnect_max_attempts: Max reconnect attempts per drop (None = unlimited) """ self.key_id = key_id self.secret_key = secret_key self.base_url = base_url self.path = path + self.auto_reconnect = auto_reconnect + self.reconnect_max_attempts = reconnect_max_attempts self._ws: ClientConnection | None = None self._listeners: dict[str, list[Callable[..., Any]]] = {} self._once_listeners: dict[str, list[Callable[..., Any]]] = {} - self._message_task: asyncio.Task[None] | None = None + self._run_task: asyncio.Task[None] | None = None + self._subscriptions: dict[str, _Subscription] = {} + self._closed = False async def connect(self) -> None: - """Establish WebSocket connection.""" + """Establish the WebSocket connection and start processing messages.""" + self._closed = False + await self._open_socket() + self._emit("open") + self._run_task = asyncio.create_task(self._run()) + + async def _open_socket(self) -> None: + """Open a socket with a freshly signed auth handshake.""" url = f"{self.base_url}{self.path}" + # Re-sign on every (re)connect: the timestamp must be within the skew window. headers = create_auth_headers(self.key_id, self.secret_key, "GET", self.path) - self._ws = await websockets.connect(url, additional_headers=headers) - self._emit("open") - - # Start message handler - self._message_task = asyncio.create_task(self._message_loop()) - async def _message_loop(self) -> None: - """Process incoming messages.""" - if not self._ws: - return - try: - async for message in self._ws: - if isinstance(message, bytes): - message = message.decode("utf-8") - self._handle_message(message) - except websockets.ConnectionClosed: - self._emit("close") - except Exception as e: - self._emit("error", PolymarketUSError(str(e))) + async def _run(self) -> None: + """Read messages, reconnecting and resubscribing on unexpected drops.""" + while True: + try: + if self._ws is None: + break + async for message in self._ws: + if isinstance(message, bytes): + message = message.decode("utf-8") + self._handle_message(message) + except websockets.ConnectionClosed: + pass + except Exception as e: + self._emit("error", PolymarketUSError(str(e))) + + # The loop may have exited on a still-open socket (e.g. a handler + # error rather than a drop). Close it before reconnecting so the old + # connection isn't leaked when _open_socket overwrites self._ws. + with contextlib.suppress(Exception): + if self._ws is not None: + await self._ws.close(1000, "OK") + + if self._closed or not self.auto_reconnect: + if not self._closed: + self._emit("close") + return + + if not await self._reconnect(): + if not self._closed: + self._emit("close") + return + + async def _reconnect(self) -> bool: + """Reconnect with backoff and replay subscriptions. Returns success.""" + attempt = 0 + while not self._closed and ( + self.reconnect_max_attempts is None or attempt < self.reconnect_max_attempts + ): + await asyncio.sleep(_reconnect_delay(attempt)) + if self._closed: + return False + try: + await self._open_socket() + except Exception as e: + status = _upgrade_status(e) + if status in _FATAL_AUTH_STATUSES: + self._emit("error", PolymarketUSError(f"WebSocket auth failed ({status})")) + return False + attempt += 1 + continue + # The user may have called close() while the upgrade was in flight. + if self._closed: + return False + # If the fresh connection drops mid-replay, treat it as another + # failed attempt rather than letting the exception kill the task. + try: + await self._resubscribe() + except Exception: + # Close the just-opened socket before retrying so it isn't + # orphaned when the next attempt overwrites self._ws. + with contextlib.suppress(Exception): + if self._ws: + await self._ws.close(1000, "OK") + attempt += 1 + continue + self._emit("reconnect") + return True + return False + + async def _resubscribe(self) -> None: + """Replay all active subscriptions after a reconnect.""" + for request_id, sub in list(self._subscriptions.items()): + request: dict[str, Any] = { + "subscribe": { + "requestId": request_id, + "subscriptionType": sub.subscription_type, + } + } + if sub.market_slugs: + request["subscribe"]["marketSlugs"] = sub.market_slugs + await self.send(request) def _handle_message(self, data: str) -> None: """Handle incoming message (override in subclasses).""" @@ -92,6 +212,9 @@ async def subscribe( ) -> None: """Subscribe to a data stream. + The subscription is recorded so it can be replayed automatically after a + reconnect. + Args: request_id: Unique request ID subscription_type: Type of subscription @@ -106,6 +229,7 @@ async def subscribe( if market_slugs: request["subscribe"]["marketSlugs"] = market_slugs await self.send(request) + self._subscriptions[request_id] = _Subscription(subscription_type, market_slugs) async def unsubscribe(self, request_id: str) -> None: """Unsubscribe from a data stream. @@ -113,17 +237,22 @@ async def unsubscribe(self, request_id: str) -> None: Args: request_id: Request ID of the subscription to cancel """ + self._subscriptions.pop(request_id, None) await self.send({"unsubscribe": {"requestId": request_id}}) async def close(self) -> None: - """Close the WebSocket connection.""" - if self._message_task: - self._message_task.cancel() + """Close the WebSocket connection and stop reconnecting.""" + self._closed = True + # Cancel first so an in-flight reconnect (sleeping or mid-handshake) is + # interrupted rather than left to open a socket close() never sees. + if self._run_task: + self._run_task.cancel() with contextlib.suppress(asyncio.CancelledError): - await self._message_task + await self._run_task + self._run_task = None if self._ws: await self._ws.close(1000, "OK") - self._ws = None + self._ws = None @property def is_connected(self) -> bool: diff --git a/polymarket_us/websocket/markets.py b/polymarket_us/websocket/markets.py index aac0a05..c4c8988 100644 --- a/polymarket_us/websocket/markets.py +++ b/polymarket_us/websocket/markets.py @@ -1,6 +1,7 @@ """Markets WebSocket.""" import json +from typing import Any from polymarket_us.errors import PolymarketUSError, WebSocketError @@ -11,7 +12,7 @@ class MarketsWebSocket(BaseWebSocket): """WebSocket for market data (order book, trades).""" - def __init__(self, **kwargs: str) -> None: + def __init__(self, **kwargs: Any) -> None: """Initialize markets WebSocket.""" super().__init__(path="/v1/ws/markets", **kwargs) diff --git a/polymarket_us/websocket/private.py b/polymarket_us/websocket/private.py index 5254f04..03c8bf4 100644 --- a/polymarket_us/websocket/private.py +++ b/polymarket_us/websocket/private.py @@ -1,6 +1,7 @@ """Private WebSocket.""" import json +from typing import Any from polymarket_us.errors import PolymarketUSError, WebSocketError @@ -11,7 +12,7 @@ class PrivateWebSocket(BaseWebSocket): """WebSocket for private data (orders, positions, balances).""" - def __init__(self, **kwargs: str) -> None: + def __init__(self, **kwargs: Any) -> None: """Initialize private WebSocket.""" super().__init__(path="/v1/ws/private", **kwargs) diff --git a/tests/test_auth.py b/tests/test_auth.py index eb70c09..6c16d15 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -64,12 +64,15 @@ def test_signature_is_base64(self) -> None: def test_handles_64_byte_key(self) -> None: """Should handle 64-byte keys (uses first 32 bytes).""" + import base64 + # 64-byte key (seed + public key), base64 encoded - secret_key_64 = "nWGxne/9WmC6hEr0kuwsxERJxWl7MmkZcDusAxyuf2A=" * 2 + seed = base64.b64decode("nWGxne/9WmC6hEr0kuwsxERJxWl7MmkZcDusAxyuf2A=") + secret_key_64 = base64.b64encode(seed + seed).decode() # Should not raise headers = create_auth_headers( key_id="test", - secret_key=secret_key_64[:88], # 64 bytes in base64 + secret_key=secret_key_64, method="GET", path="/v1/test", ) diff --git a/tests/test_websocket_reconnect.py b/tests/test_websocket_reconnect.py new file mode 100644 index 0000000..8601aa3 --- /dev/null +++ b/tests/test_websocket_reconnect.py @@ -0,0 +1,225 @@ +"""Tests for WebSocket auto-reconnect, resubscribe, and auth handling.""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from polymarket_us import PolymarketUS +from polymarket_us.websocket.base import _upgrade_status + +TEST_SECRET_KEY = "nWGxne/9WmC6hEr0kuwsxERJxWl7MmkZcDusAxyuf2A=" + + +@pytest.fixture +def client() -> PolymarketUS: + return PolymarketUS(key_id="test-key", secret_key=TEST_SECRET_KEY) + + +class TestSubscriptionTracking: + """Subscriptions are recorded for replay and cleared on unsubscribe.""" + + async def test_subscribe_records_subscription(self, client: PolymarketUS) -> None: + ws = client.ws.private() + ws._ws = AsyncMock() + + await ws.subscribe_orders("ord-1", ["mkt-a"]) + + assert "ord-1" in ws._subscriptions + assert ws._subscriptions["ord-1"].subscription_type == "SUBSCRIPTION_TYPE_ORDER" + assert ws._subscriptions["ord-1"].market_slugs == ["mkt-a"] + + async def test_unsubscribe_clears_subscription(self, client: PolymarketUS) -> None: + ws = client.ws.markets() + ws._ws = AsyncMock() + + await ws.subscribe_market_data("md-1", ["mkt-a"]) + assert "md-1" in ws._subscriptions + + await ws.unsubscribe("md-1") + assert "md-1" not in ws._subscriptions + + async def test_resubscribe_replays_all_subscriptions(self, client: PolymarketUS) -> None: + ws = client.ws.markets() + ws._ws = AsyncMock() + + await ws.subscribe_market_data("md-1", ["mkt-a"]) + await ws.subscribe_trades("tr-1", ["mkt-b"]) + ws._ws.send.reset_mock() + + await ws._resubscribe() + + assert ws._ws.send.call_count == 2 + + +class TestReconnect: + """Reconnect loop honors backoff, resubscribe, and fatal auth failures.""" + + @patch("asyncio.sleep", new_callable=AsyncMock) + async def test_reconnects_after_transient_failure( + self, _sleep: AsyncMock, client: PolymarketUS + ) -> None: + ws = client.ws.private() + ws._open_socket = AsyncMock(side_effect=[ConnectionError("dropped"), None]) + reconnected = [] + ws.on("reconnect", lambda: reconnected.append(True)) + + result = await ws._reconnect() + + assert result is True + assert ws._open_socket.call_count == 2 + assert reconnected == [True] + + @patch("asyncio.sleep", new_callable=AsyncMock) + async def test_does_not_reconnect_on_auth_failure( + self, _sleep: AsyncMock, client: PolymarketUS + ) -> None: + class _Resp: + status_code = 401 + + class _UpgradeError(Exception): + response = _Resp() + + ws = client.ws.private() + ws._open_socket = AsyncMock(side_effect=_UpgradeError()) + errors = [] + ws.on("error", lambda e: errors.append(e)) + + result = await ws._reconnect() + + assert result is False + assert len(errors) == 1 + + @patch("asyncio.sleep", new_callable=AsyncMock) + async def test_stops_reconnect_after_max_attempts( + self, _sleep: AsyncMock, client: PolymarketUS + ) -> None: + ws = client.ws.private() + ws.reconnect_max_attempts = 3 + ws._open_socket = AsyncMock(side_effect=ConnectionError("dropped")) + + result = await ws._reconnect() + + assert result is False + assert ws._open_socket.call_count == 3 + + +class TestUpgradeStatus: + """The upgrade-status helper supports both websockets exception shapes.""" + + def test_modern_response_shape(self) -> None: + class _Resp: + status_code = 429 + + class _Err(Exception): + response = _Resp() + + assert _upgrade_status(_Err()) == 429 + + def test_legacy_status_code_shape(self) -> None: + class _Err(Exception): + status_code = 403 + + assert _upgrade_status(_Err()) == 403 + + def test_no_status(self) -> None: + assert _upgrade_status(ConnectionError("network")) is None + + +class TestReconnectRobustness: + """Reconnect survives resubscribe drops and honors close during handshake.""" + + @patch("asyncio.sleep", new_callable=AsyncMock) + async def test_retries_when_resubscribe_fails( + self, _sleep: AsyncMock, client: PolymarketUS + ) -> None: + ws = client.ws.private() + sockets: list[AsyncMock] = [] + + async def _open() -> None: + socket = AsyncMock() + sockets.append(socket) + ws._ws = socket + + ws._open_socket = AsyncMock(side_effect=_open) + ws._resubscribe = AsyncMock(side_effect=[RuntimeError("dropped"), None]) + + result = await ws._reconnect() + + assert result is True + assert ws._open_socket.call_count == 2 + assert ws._resubscribe.call_count == 2 + # The socket whose resubscribe failed must be closed, not leaked. + sockets[0].close.assert_awaited() + + @patch("asyncio.sleep", new_callable=AsyncMock) + async def test_aborts_when_closed_after_open( + self, _sleep: AsyncMock, client: PolymarketUS + ) -> None: + ws = client.ws.private() + + async def _open() -> None: + ws._closed = True + + ws._open_socket = AsyncMock(side_effect=_open) + ws._resubscribe = AsyncMock() + + result = await ws._reconnect() + + assert result is False + ws._resubscribe.assert_not_called() + + +class _OneMessageSocket: + """Minimal async-iterable socket that yields one message then stops.""" + + def __init__(self) -> None: + self.close = AsyncMock() + self._yielded = False + + def __aiter__(self) -> "_OneMessageSocket": + return self + + async def __anext__(self) -> str: + if self._yielded: + raise StopAsyncIteration + self._yielded = True + return "msg" + + +class TestRunSocketCleanup: + """_run closes a still-open socket before reconnecting.""" + + async def test_closes_socket_before_reconnect_on_handler_error( + self, client: PolymarketUS + ) -> None: + ws = client.ws.private() + socket = _OneMessageSocket() + ws._ws = socket # type: ignore[assignment] + ws._handle_message = MagicMock(side_effect=RuntimeError("boom")) + ws._reconnect = AsyncMock(return_value=False) + + await ws._run() + + socket.close.assert_awaited() + ws._reconnect.assert_awaited() + + +class TestClose: + """close() interrupts an in-flight run task.""" + + async def test_close_cancels_run_task(self, client: PolymarketUS) -> None: + ws = client.ws.private() + ws._ws = AsyncMock() + + async def _forever() -> None: + await asyncio.sleep(3600) + + task = asyncio.create_task(_forever()) + ws._run_task = task + + await ws.close() + + assert ws._closed is True + assert task.cancelled() + assert ws._run_task is None