From eab997dfe2490c2a4d01886042dae0d2dc567706 Mon Sep 17 00:00:00 2001 From: "praisonai-triage-agent[bot]" <272766704+praisonai-triage-agent[bot]@users.noreply.github.com> Date: Sun, 21 Jun 2026 09:21:05 +0000 Subject: [PATCH 1/2] fix: add protocol version negotiation and resilient reconnection to gateway (fixes #2130) - Core SDK: Added protocol version constants and negotiation types - Core SDK: Added sequence numbers to GatewayEvent for gap detection - Server: Implemented version negotiation in join handshake - Server: Added sequence numbering and gap detection support - Server: Enhanced resume with presence/health snapshot - Client: Created reconnecting Python client with exponential backoff - Client: Implemented gap detection and cursor-based resumption Co-authored-by: MervinPraison --- .../praisonaiagents/gateway/__init__.py | 22 + .../praisonaiagents/gateway/protocols.py | 61 ++- src/praisonai/praisonai/gateway/__init__.py | 9 + src/praisonai/praisonai/gateway/client.py | 409 ++++++++++++++++++ src/praisonai/praisonai/gateway/server.py | 44 +- 5 files changed, 542 insertions(+), 3 deletions(-) create mode 100644 src/praisonai/praisonai/gateway/client.py diff --git a/src/praisonai-agents/praisonaiagents/gateway/__init__.py b/src/praisonai-agents/praisonaiagents/gateway/__init__.py index 546d10379..7a5b15176 100644 --- a/src/praisonai-agents/praisonaiagents/gateway/__init__.py +++ b/src/praisonai-agents/praisonaiagents/gateway/__init__.py @@ -25,6 +25,17 @@ OutboundDeliveryProtocol, ChannelInfo, PresenceInfo, + # Home channel and delivery protocols + HomeChannelRegistryProtocol, + DeliveryResolverProtocol, + # Protocol version negotiation + PROTOCOL_VERSION, + MIN_PROTOCOL_VERSION, + MAX_PROTOCOL_VERSION, + ProtocolHello, + ProtocolHelloOk, + GapInfo, + ResumeSnapshot, ) from .config import ( GatewayConfig, @@ -93,6 +104,17 @@ def __getattr__(name: str): "OutboundDeliveryProtocol", "ChannelInfo", "PresenceInfo", + # Home channel and delivery protocols + "HomeChannelRegistryProtocol", + "DeliveryResolverProtocol", + # Protocol version negotiation + "PROTOCOL_VERSION", + "MIN_PROTOCOL_VERSION", + "MAX_PROTOCOL_VERSION", + "ProtocolHello", + "ProtocolHelloOk", + "GapInfo", + "ResumeSnapshot", # Config (always available) "GatewayConfig", "SessionConfig", diff --git a/src/praisonai-agents/praisonaiagents/gateway/protocols.py b/src/praisonai-agents/praisonaiagents/gateway/protocols.py index c01ee0294..b250ed49b 100644 --- a/src/praisonai-agents/praisonaiagents/gateway/protocols.py +++ b/src/praisonai-agents/praisonaiagents/gateway/protocols.py @@ -24,6 +24,7 @@ Literal, Optional, Protocol, + TypedDict, Union, runtime_checkable, ) @@ -173,6 +174,7 @@ class GatewayEvent: timestamp: Event creation time source: Source identifier (agent_id, client_id, etc.) target: Target identifier (optional, for directed events) + sequence: Monotonic sequence number for gap detection (optional) """ type: Union[EventType, str] @@ -181,10 +183,11 @@ class GatewayEvent: timestamp: float = field(default_factory=time.time) source: Optional[str] = None target: Optional[str] = None + sequence: Optional[int] = None # Monotonic sequence for gap detection def to_dict(self) -> Dict[str, Any]: """Convert to dictionary for serialization.""" - return { + result = { "type": self.type.value if isinstance(self.type, EventType) else self.type, "data": self.data, "event_id": self.event_id, @@ -192,6 +195,9 @@ def to_dict(self) -> Dict[str, Any]: "source": self.source, "target": self.target, } + if self.sequence is not None: + result["sequence"] = self.sequence + return result @classmethod def from_dict(cls, data: Dict[str, Any]) -> "GatewayEvent": @@ -209,6 +215,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "GatewayEvent": timestamp=data.get("timestamp", time.time()), source=data.get("source"), target=data.get("target"), + sequence=data.get("sequence"), ) @@ -1115,7 +1122,6 @@ def lookup(self, session_id: str) -> Optional[Dict[str, Any]]: ... -# --------------------------------------------------------------------------- # Home Channel and Delivery Routing Protocols # --------------------------------------------------------------------------- @@ -1193,3 +1199,54 @@ def resolve( List of concrete delivery targets """ ... + + +# --------------------------------------------------------------------------- +# Protocol Version Negotiation (Issue #2130) +# --------------------------------------------------------------------------- + +# Protocol version constants +PROTOCOL_VERSION = 1 +MIN_PROTOCOL_VERSION = 1 +MAX_PROTOCOL_VERSION = 1 + + +class ProtocolHello(TypedDict, total=False): + """Protocol version negotiation handshake request. + + Sent by client during join to negotiate protocol version. + """ + min_version: int # Minimum protocol version client supports + max_version: int # Maximum protocol version client supports + features: List[str] # Optional feature flags + + +class ProtocolHelloOk(TypedDict): + """Protocol version negotiation response. + + Server's response to protocol negotiation. + """ + protocol_version: int # Negotiated protocol version + server_min_version: int # Server's minimum supported version + server_max_version: int # Server's maximum supported version + features: List[str] # Enabled feature flags + + +class GapInfo(TypedDict): + """Information about a gap in the event sequence.""" + expected_seq: int # Expected sequence number + received_seq: int # Received sequence number + missed_count: int # Number of events missed + + +class ResumeSnapshot(TypedDict, total=False): + """Complete snapshot for session resumption. + + Provides all necessary state for one-round-trip reconnection. + """ + cursor: int # Resume cursor position + sequence: int # Current sequence number for gap detection + events: List[Dict[str, Any]] # Replayed events since cursor + presence: List[Dict[str, Any]] # Current presence information + health: Dict[str, Any] # Gateway health status + session_state: Dict[str, Any] # Session-specific state diff --git a/src/praisonai/praisonai/gateway/__init__.py b/src/praisonai/praisonai/gateway/__init__.py index 002e9c0b2..eee1d80df 100644 --- a/src/praisonai/praisonai/gateway/__init__.py +++ b/src/praisonai/praisonai/gateway/__init__.py @@ -9,6 +9,7 @@ if TYPE_CHECKING: from .server import WebSocketGateway, GatewaySession + from .client import GatewayClient, BackoffConfig from .rate_limiter import AuthRateLimiter from .pairing import PairingStore from .exec_approval import ExecApprovalManager, get_exec_approval_manager @@ -23,6 +24,12 @@ def __getattr__(name: str): if name == "GatewaySession": from .server import GatewaySession return GatewaySession + if name == "GatewayClient": + from .client import GatewayClient + return GatewayClient + if name == "BackoffConfig": + from .client import BackoffConfig + return BackoffConfig # Security / approval primitives if name == "AuthRateLimiter": from .rate_limiter import AuthRateLimiter @@ -50,6 +57,8 @@ def __getattr__(name: str): __all__ = [ "WebSocketGateway", "GatewaySession", + "GatewayClient", + "BackoffConfig", "AuthRateLimiter", "PairingStore", "ExecApprovalManager", diff --git a/src/praisonai/praisonai/gateway/client.py b/src/praisonai/praisonai/gateway/client.py new file mode 100644 index 000000000..e4722b10d --- /dev/null +++ b/src/praisonai/praisonai/gateway/client.py @@ -0,0 +1,409 @@ +""" +Reconnecting Gateway Client for PraisonAI. + +Provides automatic reconnection with exponential backoff, +protocol version negotiation, and gap detection. +""" + +import asyncio +import json +import logging +import random +import time +from dataclasses import dataclass, field +from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Union + +try: + import websockets + from websockets.client import WebSocketClientProtocol +except ImportError: + raise ImportError("websockets is required. Install with: pip install websockets") + +from praisonaiagents.gateway import ( + GatewayEvent, + GatewayMessage, + EventType, + PROTOCOL_VERSION, + MIN_PROTOCOL_VERSION, +) + +logger = logging.getLogger(__name__) + + +@dataclass +class BackoffConfig: + """Configuration for exponential backoff.""" + initial: float = 1.0 # Initial delay in seconds + max: float = 30.0 # Maximum delay in seconds + multiplier: float = 2.0 # Backoff multiplier + jitter: float = 0.2 # Random jitter factor (0-1) + + +class ConnectionState: + """Connection state enumeration.""" + DISCONNECTED = "disconnected" + CONNECTING = "connecting" + CONNECTED = "connected" + RECONNECTING = "reconnecting" + + +class GatewayClient: + """Reconnecting WebSocket client for PraisonAI Gateway. + + Features: + - Automatic reconnection with exponential backoff + - Protocol version negotiation + - Gap detection via sequence numbers + - Cursor-based event resumption + + Example: + client = GatewayClient( + url="ws://localhost:8765", + agent_id="my-agent", + reconnect=True, + backoff=BackoffConfig(initial=1, max=30) + ) + + # Set up gap handler + client.on_gap = lambda expected, received: print(f"Gap detected: {expected}->{received}") + + # Connect and handle events + await client.connect() + async for event in client.events(): + print(f"Event: {event}") + """ + + def __init__( + self, + url: str, + agent_id: str, + token: Optional[str] = None, + reconnect: bool = True, + backoff: Optional[BackoffConfig] = None, + max_reconnect_attempts: Optional[int] = None, + ): + """Initialize the gateway client. + + Args: + url: WebSocket URL to connect to + agent_id: Agent ID to join as + token: Optional authentication token + reconnect: Whether to auto-reconnect on disconnect + backoff: Backoff configuration + max_reconnect_attempts: Max reconnection attempts (None = infinite) + """ + self.url = url + self.agent_id = agent_id + self.token = token + self.reconnect = reconnect + self.backoff = backoff or BackoffConfig() + self.max_reconnect_attempts = max_reconnect_attempts + + self._ws: Optional[WebSocketClientProtocol] = None + self._state = ConnectionState.DISCONNECTED + self._session_id: Optional[str] = None + self._cursor: int = 0 + self._sequence: int = 0 + self._expected_sequence: int = 0 + self._protocol_version: int = PROTOCOL_VERSION + self._reconnect_attempts: int = 0 + self._event_queue: asyncio.Queue = asyncio.Queue() + self._running = False + self._receive_task: Optional[asyncio.Task] = None + + # Callbacks + self.on_gap: Optional[Callable[[int, int], None]] = None + self.on_state_change: Optional[Callable[[str], None]] = None + + @property + def state(self) -> str: + """Current connection state.""" + return self._state + + @property + def is_connected(self) -> bool: + """Whether the client is currently connected.""" + return self._state == ConnectionState.CONNECTED + + def _set_state(self, state: str) -> None: + """Set connection state and notify callback.""" + if self._state != state: + self._state = state + if self.on_state_change: + self.on_state_change(state) + + def _calculate_backoff(self) -> float: + """Calculate next backoff delay with jitter.""" + delay = min( + self.backoff.initial * (self.backoff.multiplier ** self._reconnect_attempts), + self.backoff.max + ) + # Add jitter + jitter = delay * self.backoff.jitter * (2 * random.random() - 1) + return max(0, delay + jitter) + + async def connect(self) -> None: + """Connect to the gateway with automatic reconnection.""" + if self._running: + return + + self._running = True + self._reconnect_attempts = 0 + + while self._running: + try: + await self._connect_once() + + # Reset reconnect attempts on successful connection + self._reconnect_attempts = 0 + + # Start receive loop + self._receive_task = asyncio.create_task(self._receive_loop()) + + # Wait for disconnect or stop + await self._receive_task + + except Exception as e: + logger.error(f"Connection error: {e}") + + if not self._running or not self.reconnect: + break + + # Check max attempts + if self.max_reconnect_attempts and self._reconnect_attempts >= self.max_reconnect_attempts: + logger.error(f"Max reconnection attempts ({self.max_reconnect_attempts}) reached") + break + + # Calculate backoff delay + delay = self._calculate_backoff() + self._reconnect_attempts += 1 + + logger.info(f"Reconnecting in {delay:.1f}s (attempt {self._reconnect_attempts})") + self._set_state(ConnectionState.RECONNECTING) + + await asyncio.sleep(delay) + + async def _connect_once(self) -> None: + """Establish WebSocket connection and perform handshake.""" + self._set_state(ConnectionState.CONNECTING) + + # Build connection URL with token if provided + connect_url = self.url + if self.token: + separator = "&" if "?" in connect_url else "?" + connect_url = f"{connect_url}{separator}token={self.token}" + + # Connect to WebSocket + self._ws = await websockets.connect(connect_url) + + # Send join message with protocol version + join_msg = { + "type": "join", + "agent_id": self.agent_id, + "min_version": MIN_PROTOCOL_VERSION, + "max_version": PROTOCOL_VERSION, + } + + # Include session/cursor for reconnection + if self._session_id: + join_msg["session_id"] = self._session_id + if self._cursor > 0: + join_msg["since"] = self._cursor + + await self._ws.send(json.dumps(join_msg)) + + # Wait for join response + response = await self._ws.recv() + data = json.loads(response) + + if data.get("type") == "error": + error_code = data.get("code") + error_msg = data.get("message", "Unknown error") + + if error_code == "version_unsupported": + raise ValueError(f"Protocol version unsupported: {error_msg}") + else: + raise ConnectionError(f"Join failed: {error_msg}") + + elif data.get("type") == "joined": + # Store session info + self._session_id = data.get("session_id") + self._cursor = data.get("cursor", 0) + self._sequence = data.get("sequence", 0) + self._expected_sequence = self._sequence + 1 + self._protocol_version = data.get("protocol_version", PROTOCOL_VERSION) + + self._set_state(ConnectionState.CONNECTED) + + logger.info( + f"Connected to gateway (session={self._session_id}, " + f"protocol=v{self._protocol_version}, resumed={data.get('resumed', False)})" + ) + + async def _receive_loop(self) -> None: + """Receive messages from WebSocket.""" + try: + while self._ws and not self._ws.closed: + message = await self._ws.recv() + data = json.loads(message) + + # Handle different message types + msg_type = data.get("type") + + if msg_type == "replay": + # Handle replayed event + event_data = data.get("event", {}) + event = GatewayEvent.from_dict(event_data) + await self._handle_event(event) + + elif msg_type == "event": + # Handle regular event + event = GatewayEvent.from_dict(data) + await self._handle_event(event) + + else: + # Queue other messages as events + event = GatewayEvent( + type=msg_type or "message", + data=data + ) + await self._event_queue.put(event) + + except websockets.ConnectionClosed: + logger.info("WebSocket connection closed") + except Exception as e: + logger.error(f"Receive loop error: {e}") + finally: + self._set_state(ConnectionState.DISCONNECTED) + + async def _handle_event(self, event: GatewayEvent) -> None: + """Handle an event with gap detection.""" + # Check for sequence gap + if event.sequence is not None: + if event.sequence != self._expected_sequence: + # Gap detected + gap_size = event.sequence - self._expected_sequence + logger.warning( + f"Gap detected: expected seq {self._expected_sequence}, " + f"received {event.sequence} (missed {gap_size} events)" + ) + + if self.on_gap: + self.on_gap(self._expected_sequence, event.sequence) + + self._expected_sequence = event.sequence + 1 + + # Update cursor if present + cursor = event.data.get("cursor") + if cursor: + self._cursor = cursor + + # Queue the event + await self._event_queue.put(event) + + async def disconnect(self) -> None: + """Disconnect from the gateway.""" + self._running = False + + if self._receive_task: + self._receive_task.cancel() + try: + await self._receive_task + except asyncio.CancelledError: + pass + + if self._ws: + await self._ws.close() + self._ws = None + + self._set_state(ConnectionState.DISCONNECTED) + + async def send(self, message: Union[str, Dict[str, Any]]) -> None: + """Send a message to the gateway. + + Args: + message: Message content (string or dict) + """ + if not self.is_connected: + raise ConnectionError("Not connected to gateway") + + if isinstance(message, dict): + await self._ws.send(json.dumps(message)) + else: + await self._ws.send(json.dumps({ + "type": "message", + "content": message + })) + + async def events(self) -> AsyncIterator[GatewayEvent]: + """Iterate over received events. + + Yields: + Gateway events as they are received + """ + while self._running: + try: + # Use timeout to periodically check if still running + event = await asyncio.wait_for( + self._event_queue.get(), + timeout=1.0 + ) + yield event + except asyncio.TimeoutError: + continue + + async def resync(self) -> None: + """Force a full resynchronization. + + Useful when a gap is detected beyond the replay window. + """ + logger.info("Forcing resynchronization") + + # Reset cursor to trigger full resync + self._cursor = 0 + self._sequence = 0 + self._expected_sequence = 0 + + # Disconnect and reconnect + if self._ws: + await self._ws.close() + + +async def example_usage(): + """Example usage of the GatewayClient.""" + client = GatewayClient( + url="ws://localhost:8765", + agent_id="example-agent", + reconnect=True, + backoff=BackoffConfig(initial=1, max=30, jitter=0.2) + ) + + # Set up event handlers + def on_gap(expected: int, received: int): + print(f"Gap detected: expected {expected}, got {received}") + # Could trigger resync here if gap is too large + + def on_state_change(state: str): + print(f"Connection state: {state}") + + client.on_gap = on_gap + client.on_state_change = on_state_change + + try: + # Connect to gateway + await client.connect() + + # Process events + async for event in client.events(): + print(f"Event: {event.type}, Data: {event.data}") + + # Send a response + if event.type == EventType.MESSAGE: + await client.send("Got your message!") + + finally: + await client.disconnect() + + +if __name__ == "__main__": + asyncio.run(example_usage()) \ No newline at end of file diff --git a/src/praisonai/praisonai/gateway/server.py b/src/praisonai/praisonai/gateway/server.py index 7a2d2679a..c73833e6a 100644 --- a/src/praisonai/praisonai/gateway/server.py +++ b/src/praisonai/praisonai/gateway/server.py @@ -26,6 +26,9 @@ GatewayEvent, GatewayMessage, EventType, + PROTOCOL_VERSION, + MIN_PROTOCOL_VERSION, + MAX_PROTOCOL_VERSION, ) from praisonaiagents.gateway.protocols import ( ConnectErrorCode, @@ -59,6 +62,8 @@ class GatewaySession: _event_cursor: int = 0 # Monotonic cursor for event replay _events: List[GatewayEvent] = field(default_factory=list) # Event history for replay _was_resumed: bool = False # Track if session was resumed from persistence + _sequence: int = 0 # Monotonic sequence number for gap detection + _protocol_version: int = PROTOCOL_VERSION # Negotiated protocol version # Stepper & Concurrency logic _inbox: asyncio.Queue = field(default_factory=asyncio.Queue) @@ -112,7 +117,9 @@ def close(self) -> None: def add_event(self, event: GatewayEvent) -> int: """Add an event and return its cursor position.""" self._event_cursor += 1 + self._sequence += 1 event.data['cursor'] = self._event_cursor + event.sequence = self._sequence # Add sequence for gap detection self._events.append(event) self._last_activity = time.time() # Keep events bounded to prevent unbounded growth @@ -1265,6 +1272,24 @@ async def _handle_client_message(self, client_id: str, data: Dict[str, Any]) -> elif msg_type == "join": agent_id = data.get("agent_id") if agent_id and agent_id in self._agents: + # Protocol version negotiation + client_min_version = data.get("min_version", MIN_PROTOCOL_VERSION) + client_max_version = data.get("max_version", PROTOCOL_VERSION) + + # Check if we can negotiate a common version + if client_max_version < MIN_PROTOCOL_VERSION or client_min_version > MAX_PROTOCOL_VERSION: + await self._send_to_client(client_id, { + "type": "error", + "code": "version_unsupported", + "message": f"Protocol version mismatch. Server supports {MIN_PROTOCOL_VERSION}-{MAX_PROTOCOL_VERSION}, client supports {client_min_version}-{client_max_version}", + "server_min_version": MIN_PROTOCOL_VERSION, + "server_max_version": MAX_PROTOCOL_VERSION, + }) + return + + # Negotiate the highest common version + negotiated_version = min(client_max_version, MAX_PROTOCOL_VERSION) + # Support reconnection with existing session session_id = data.get("session_id") # Optional: existing session to resume since_cursor = data.get("since") # Optional: cursor for event replay @@ -1277,15 +1302,32 @@ async def _handle_client_message(self, client_id: str, data: Dict[str, Any]) -> since_cursor=since_cursor, ) + # Set negotiated protocol version for the session + session._protocol_version = negotiated_version + self._client_sessions[client_id] = session.session_id - # Send join confirmation (old format for backward compatibility) + # Build presence snapshot + presence_snapshot = [] + if hasattr(self, '_presence_manager'): + from .push_presence import PresenceManager + if isinstance(self._presence_manager, PresenceManager): + presence_info = self._presence_manager.get_all_presence() + presence_snapshot = [p.to_dict() for p in presence_info] + + # Send join confirmation with protocol info and snapshot await self._send_to_client(client_id, { "type": "joined", "session_id": session.session_id, "agent_id": agent_id, "resumed": session._was_resumed, "cursor": session._event_cursor, + "sequence": session._sequence, # Current sequence for gap detection + "protocol_version": negotiated_version, + "server_min_version": MIN_PROTOCOL_VERSION, + "server_max_version": MAX_PROTOCOL_VERSION, + "presence": presence_snapshot, # Presence snapshot + "health": self.health(), # Health status }) # Replay missed events if any From b1d4853474b079a8c3a56d2cb26598f6cceda497 Mon Sep 17 00:00:00 2001 From: "praisonai-triage-agent[bot]" <272766704+praisonai-triage-agent[bot]@users.noreply.github.com> Date: Sun, 21 Jun 2026 09:38:31 +0000 Subject: [PATCH 2/2] fix: address reviewer feedback on gateway client and server - Fix GatewayClient.connect() to be non-blocking by running connection loop in background task - Add ValueError catch for version_unsupported to prevent infinite retry on permanent errors - Fix cursor update logic to handle zero values correctly (use 'is not None' check) - Add session persistence for sequence and protocol_version fields - Add protocol version validation with proper type checking for untrusted JSON - Fix sequence number in joined message to align with replay events to prevent false gap detection - Fix off-by-one error in reconnect attempts counter - Add timeout to WebSocket join handshake to prevent indefinite blocking - Remove unused imports (time, field, GatewayMessage) Co-authored-by: Mervin Praison --- src/praisonai/praisonai/gateway/client.py | 46 ++++++++++++++++++----- src/praisonai/praisonai/gateway/server.py | 33 ++++++++++++++-- 2 files changed, 65 insertions(+), 14 deletions(-) diff --git a/src/praisonai/praisonai/gateway/client.py b/src/praisonai/praisonai/gateway/client.py index e4722b10d..fe418d1a1 100644 --- a/src/praisonai/praisonai/gateway/client.py +++ b/src/praisonai/praisonai/gateway/client.py @@ -9,8 +9,7 @@ import json import logging import random -import time -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Union try: @@ -21,7 +20,6 @@ from praisonaiagents.gateway import ( GatewayEvent, - GatewayMessage, EventType, PROTOCOL_VERSION, MIN_PROTOCOL_VERSION, @@ -143,13 +141,25 @@ def _calculate_backoff(self) -> float: return max(0, delay + jitter) async def connect(self) -> None: - """Connect to the gateway with automatic reconnection.""" + """Start the connection loop as a background task. + + This method starts the reconnection loop in the background and returns + immediately. Use events() to receive events after calling connect(). + + Example: + await client.connect() # Returns immediately + async for event in client.events(): + print(event) + """ if self._running: return self._running = True self._reconnect_attempts = 0 - + self._connect_task = asyncio.create_task(self._connection_loop()) + + async def _connection_loop(self) -> None: + """Connection loop with automatic reconnection.""" while self._running: try: await self._connect_once() @@ -163,6 +173,11 @@ async def connect(self) -> None: # Wait for disconnect or stop await self._receive_task + except ValueError as e: + # Protocol version mismatch is a permanent error + logger.error(f"Connection failed permanently: {e}") + self._running = False + raise except Exception as e: logger.error(f"Connection error: {e}") @@ -175,8 +190,8 @@ async def connect(self) -> None: break # Calculate backoff delay - delay = self._calculate_backoff() self._reconnect_attempts += 1 + delay = self._calculate_backoff() logger.info(f"Reconnecting in {delay:.1f}s (attempt {self._reconnect_attempts})") self._set_state(ConnectionState.RECONNECTING) @@ -212,8 +227,11 @@ async def _connect_once(self) -> None: await self._ws.send(json.dumps(join_msg)) - # Wait for join response - response = await self._ws.recv() + # Wait for join response with timeout + try: + response = await asyncio.wait_for(self._ws.recv(), timeout=10.0) + except asyncio.TimeoutError: + raise ConnectionError("Join handshake timed out") data = json.loads(response) if data.get("type") == "error": @@ -295,7 +313,7 @@ async def _handle_event(self, event: GatewayEvent) -> None: # Update cursor if present cursor = event.data.get("cursor") - if cursor: + if cursor is not None: self._cursor = cursor # Queue the event @@ -305,6 +323,14 @@ async def disconnect(self) -> None: """Disconnect from the gateway.""" self._running = False + # Cancel connect task if running + if hasattr(self, '_connect_task') and self._connect_task: + self._connect_task.cancel() + try: + await self._connect_task + except asyncio.CancelledError: + pass + if self._receive_task: self._receive_task.cancel() try: @@ -390,7 +416,7 @@ def on_state_change(state: str): client.on_state_change = on_state_change try: - # Connect to gateway + # Start connection (returns immediately) await client.connect() # Process events diff --git a/src/praisonai/praisonai/gateway/server.py b/src/praisonai/praisonai/gateway/server.py index c73833e6a..69924a1ab 100644 --- a/src/praisonai/praisonai/gateway/server.py +++ b/src/praisonai/praisonai/gateway/server.py @@ -165,6 +165,8 @@ def to_dict(self) -> Dict[str, Any]: "metadata": msg.metadata, } for msg in self._messages], "event_cursor": self._event_cursor, + "sequence": self._sequence, + "protocol_version": self._protocol_version, "events": [e.to_dict() for e in self._events[-100:]], # Keep last 100 events "pending_inbox": pending_inbox, "is_executing": self._is_executing, @@ -201,6 +203,8 @@ def from_dict(cls, data: Dict[str, Any], max_messages: int = 1000) -> 'GatewaySe # Restore event cursor and events session._event_cursor = data.get("event_cursor", 0) + session._sequence = data.get("sequence", session._event_cursor) + session._protocol_version = data.get("protocol_version", PROTOCOL_VERSION) for event_data in data.get("events", []): event = GatewayEvent.from_dict(event_data) session._events.append(event) @@ -1272,9 +1276,25 @@ async def _handle_client_message(self, client_id: str, data: Dict[str, Any]) -> elif msg_type == "join": agent_id = data.get("agent_id") if agent_id and agent_id in self._agents: - # Protocol version negotiation - client_min_version = data.get("min_version", MIN_PROTOCOL_VERSION) - client_max_version = data.get("max_version", PROTOCOL_VERSION) + # Protocol version negotiation with validation + try: + client_min_version = int(data.get("min_version", MIN_PROTOCOL_VERSION)) + client_max_version = int(data.get("max_version", PROTOCOL_VERSION)) + except (TypeError, ValueError): + await self._send_to_client(client_id, { + "type": "error", + "code": "invalid_protocol_hello", + "message": "Invalid protocol version fields. Expected integer min_version/max_version.", + }) + return + + if client_min_version > client_max_version: + await self._send_to_client(client_id, { + "type": "error", + "code": "invalid_protocol_hello", + "message": f"Invalid version range: min_version ({client_min_version}) > max_version ({client_max_version})", + }) + return # Check if we can negotiate a common version if client_max_version < MIN_PROTOCOL_VERSION or client_min_version > MAX_PROTOCOL_VERSION: @@ -1315,6 +1335,11 @@ async def _handle_client_message(self, client_id: str, data: Dict[str, Any]) -> presence_info = self._presence_manager.get_all_presence() presence_snapshot = [p.to_dict() for p in presence_info] + # Calculate correct sequence for replay + joined_sequence = session._sequence + if replay_events and replay_events[0].sequence is not None: + joined_sequence = replay_events[0].sequence - 1 + # Send join confirmation with protocol info and snapshot await self._send_to_client(client_id, { "type": "joined", @@ -1322,7 +1347,7 @@ async def _handle_client_message(self, client_id: str, data: Dict[str, Any]) -> "agent_id": agent_id, "resumed": session._was_resumed, "cursor": session._event_cursor, - "sequence": session._sequence, # Current sequence for gap detection + "sequence": joined_sequence, # Sequence aligned with replay events "protocol_version": negotiated_version, "server_min_version": MIN_PROTOCOL_VERSION, "server_max_version": MAX_PROTOCOL_VERSION,