Skip to content

Commit fcdf352

Browse files
committed
Add WebSocket transport (ws:// + wss://) for Remote Desktop
A new MessageChannel abstraction lets the host and viewer speak the existing typed-message protocol over either raw TCP framing or WebSocket BINARY frames. Each WS frame carries one full encoded typed message (magic + type + length + payload), so decode_frame_header / encode_frame are reused unchanged and only the wire layer changes. ws_protocol.py is a small RFC 6455 implementation (no extra deps): server / client handshake helpers, single-frame BINARY send, recv that transparently handles PING / PONG / CLOSE control frames, and explicit rejection of fragmented data frames so messages always fit in one ~16 MiB frame. Clients mask outgoing payloads as required; servers do not. WebSocketDesktopHost and WebSocketDesktopViewer are thin subclasses that override the channel-creation hook to perform the upgrade handshake before falling back to the shared auth + receive loop. The existing ssl_context plumbing stays in place — passing a context to WebSocketDesktopHost/Viewer transparently upgrades the connection to wss://, so no separate TLS-WS class is needed. Tests cover ws_protocol round trips (handshake, masked + unmasked binary frames, extended payload length, bad-request rejection) and end-to-end host<->viewer scenarios (auth, frame stream, input dispatch, host_id announce, mixed-transport rejection in both directions, path validation).
1 parent 7cb8e33 commit fcdf352

8 files changed

Lines changed: 728 additions & 57 deletions

File tree

je_auto_control/utils/remote_desktop/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,14 @@
2323
)
2424
from je_auto_control.utils.remote_desktop.registry import registry
2525
from je_auto_control.utils.remote_desktop.viewer import RemoteDesktopViewer
26+
from je_auto_control.utils.remote_desktop.ws_host import WebSocketDesktopHost
27+
from je_auto_control.utils.remote_desktop.ws_viewer import (
28+
WebSocketDesktopViewer,
29+
)
2630

2731
__all__ = [
2832
"RemoteDesktopHost", "RemoteDesktopViewer",
33+
"WebSocketDesktopHost", "WebSocketDesktopViewer",
2934
"InputDispatchError", "AuthenticationError", "ProtocolError",
3035
"MessageType", "encode_frame", "decode_frame_header",
3136
"dispatch_input", "registry",

je_auto_control/utils/remote_desktop/host.py

Lines changed: 36 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
)
2020
from je_auto_control.utils.remote_desktop.protocol import (
2121
AuthenticationError, MessageType, ProtocolError,
22-
encode_frame, read_message,
22+
)
23+
from je_auto_control.utils.remote_desktop.transport import (
24+
MessageChannel, TcpMessageChannel,
2325
)
2426

2527
FrameProvider = Callable[[], bytes]
@@ -51,12 +53,11 @@ def provide() -> bytes:
5153
class _ClientHandler:
5254
"""Per-connection auth + input-receive + frame-send state."""
5355

54-
def __init__(self, host: "RemoteDesktopHost", sock: socket.socket,
55-
address) -> None:
56+
def __init__(self, host: "RemoteDesktopHost",
57+
channel: MessageChannel, address) -> None:
5658
self._host = host
57-
self._sock = sock
59+
self._channel = channel
5860
self._address = address
59-
self._send_lock = threading.Lock()
6061
self._shutdown = threading.Event()
6162
self._sender_thread: Optional[threading.Thread] = None
6263
self._receiver_thread: Optional[threading.Thread] = None
@@ -95,27 +96,23 @@ def stop(self) -> None:
9596

9697
def _authenticate(self) -> None:
9798
nonce = make_nonce()
98-
self._sock.settimeout(_AUTH_TIMEOUT_S)
99-
self._send(MessageType.AUTH_CHALLENGE, nonce)
100-
msg_type, payload = read_message(self._sock)
99+
self._channel.settimeout(_AUTH_TIMEOUT_S)
100+
self._channel.send_typed(MessageType.AUTH_CHALLENGE, nonce)
101+
msg_type, payload = self._channel.read_typed()
101102
if msg_type is not MessageType.AUTH_RESPONSE:
102-
self._send(MessageType.AUTH_FAIL, b"expected AUTH_RESPONSE")
103+
self._channel.send_typed(MessageType.AUTH_FAIL,
104+
b"expected AUTH_RESPONSE")
103105
raise AuthenticationError(
104106
f"expected AUTH_RESPONSE, got {msg_type.name}"
105107
)
106108
if not verify_response(self._host._token, nonce, payload):
107-
self._send(MessageType.AUTH_FAIL, b"bad token")
109+
self._channel.send_typed(MessageType.AUTH_FAIL, b"bad token")
108110
raise AuthenticationError("bad token")
109111
ok_payload = json.dumps(
110112
{"host_id": self._host.host_id}, ensure_ascii=False,
111113
).encode("utf-8")
112-
self._send(MessageType.AUTH_OK, ok_payload)
113-
self._sock.settimeout(None)
114-
115-
def _send(self, message_type: MessageType, payload: bytes) -> None:
116-
data = encode_frame(message_type, payload)
117-
with self._send_lock:
118-
self._sock.sendall(data)
114+
self._channel.send_typed(MessageType.AUTH_OK, ok_payload)
115+
self._channel.settimeout(None)
119116

120117
def _send_loop(self) -> None:
121118
last_sent = 0
@@ -131,7 +128,7 @@ def _send_loop(self) -> None:
131128
if frame is None:
132129
continue
133130
try:
134-
self._send(MessageType.FRAME, frame)
131+
self._channel.send_typed(MessageType.FRAME, frame)
135132
except (OSError, ConnectionError) as error:
136133
autocontrol_logger.info(
137134
"remote_desktop send to %s failed: %r",
@@ -144,7 +141,7 @@ def _send_loop(self) -> None:
144141
def _recv_loop(self) -> None:
145142
while not self._shutdown.is_set():
146143
try:
147-
msg_type, payload = read_message(self._sock)
144+
msg_type, payload = self._channel.read_typed()
148145
except (OSError, ConnectionError, ProtocolError) as error:
149146
if not self._shutdown.is_set():
150147
autocontrol_logger.info(
@@ -186,14 +183,7 @@ def _handle_input_payload(self, payload: bytes) -> None:
186183
)
187184

188185
def _close(self) -> None:
189-
try:
190-
self._sock.shutdown(socket.SHUT_RDWR)
191-
except OSError:
192-
pass
193-
try:
194-
self._sock.close()
195-
except OSError:
196-
pass
186+
self._channel.close()
197187

198188

199189
class RemoteDesktopHost:
@@ -338,7 +328,19 @@ def _accept_loop(self) -> None:
338328
wrapped = self._maybe_wrap_tls(client_sock, address)
339329
if wrapped is None:
340330
continue
341-
handler = _ClientHandler(self, wrapped, address)
331+
try:
332+
channel = self._build_channel(wrapped, address)
333+
except (OSError, RuntimeError) as error:
334+
autocontrol_logger.info(
335+
"remote_desktop channel handshake from %s failed: %r",
336+
address, error,
337+
)
338+
try:
339+
wrapped.close()
340+
except OSError:
341+
pass
342+
continue
343+
handler = _ClientHandler(self, channel, address)
342344
with self._clients_lock:
343345
if len(self._clients) >= self._max_clients:
344346
autocontrol_logger.info(
@@ -351,6 +353,12 @@ def _accept_loop(self) -> None:
351353
handler.start()
352354
self._reap_dead_clients()
353355

356+
def _build_channel(self, sock: socket.socket,
357+
address) -> MessageChannel:
358+
"""Hook for transports: TCP wraps directly, WS overrides this."""
359+
del address
360+
return TcpMessageChannel(sock)
361+
354362
def _maybe_wrap_tls(self, client_sock: socket.socket,
355363
address) -> Optional[socket.socket]:
356364
"""Return a TLS-wrapped socket when an ssl_context is configured."""
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
"""Pluggable typed-message transport for the remote-desktop protocol.
2+
3+
The host and viewer always exchange the same typed messages
4+
(``MessageType`` from :mod:`protocol`), but the wire layer can be either
5+
the original raw-TCP framing or WebSocket binary frames. ``MessageChannel``
6+
hides that distinction so the rest of the codebase deals with
7+
``send_typed`` / ``read_typed`` only.
8+
"""
9+
import socket
10+
import threading
11+
from typing import Tuple
12+
13+
from je_auto_control.utils.remote_desktop.protocol import (
14+
HEADER_SIZE, MessageType, ProtocolError,
15+
decode_frame_header, encode_frame, read_message,
16+
)
17+
from je_auto_control.utils.remote_desktop.ws_protocol import (
18+
recv_message as ws_recv_message,
19+
send_binary as ws_send_binary,
20+
send_close as ws_send_close,
21+
)
22+
23+
24+
class MessageChannel:
25+
"""Abstract bidirectional typed-message endpoint."""
26+
27+
def send_typed(self, message_type: MessageType, payload: bytes) -> None:
28+
raise NotImplementedError
29+
30+
def read_typed(self) -> Tuple[MessageType, bytes]:
31+
raise NotImplementedError
32+
33+
def settimeout(self, timeout) -> None:
34+
raise NotImplementedError
35+
36+
def close(self) -> None:
37+
raise NotImplementedError
38+
39+
40+
class TcpMessageChannel(MessageChannel):
41+
"""Original transport: each typed message is one length-prefixed frame."""
42+
43+
def __init__(self, sock: socket.socket) -> None:
44+
self._sock = sock
45+
self._send_lock = threading.Lock()
46+
47+
def send_typed(self, message_type: MessageType, payload: bytes) -> None:
48+
data = encode_frame(message_type, payload)
49+
with self._send_lock:
50+
self._sock.sendall(data)
51+
52+
def read_typed(self) -> Tuple[MessageType, bytes]:
53+
return read_message(self._sock)
54+
55+
def settimeout(self, timeout) -> None:
56+
self._sock.settimeout(timeout)
57+
58+
def close(self) -> None:
59+
try:
60+
self._sock.shutdown(socket.SHUT_RDWR)
61+
except OSError:
62+
pass
63+
try:
64+
self._sock.close()
65+
except OSError:
66+
pass
67+
68+
@property
69+
def sock(self) -> socket.socket:
70+
return self._sock
71+
72+
73+
class WsMessageChannel(MessageChannel):
74+
"""WebSocket transport: each WS BINARY frame carries one typed message.
75+
76+
The WS payload is the existing typed-frame encoding (magic + type +
77+
length + body), so :func:`decode_frame_header` and :func:`encode_frame`
78+
are reused unchanged. ``mask_outgoing`` follows RFC 6455: clients must
79+
mask, servers must not.
80+
"""
81+
82+
def __init__(self, sock: socket.socket, mask_outgoing: bool) -> None:
83+
self._sock = sock
84+
self._mask = bool(mask_outgoing)
85+
self._send_lock = threading.Lock()
86+
87+
def send_typed(self, message_type: MessageType, payload: bytes) -> None:
88+
data = encode_frame(message_type, payload)
89+
with self._send_lock:
90+
ws_send_binary(self._sock, data, mask=self._mask)
91+
92+
def read_typed(self) -> Tuple[MessageType, bytes]:
93+
ws_payload = ws_recv_message(self._sock)
94+
if len(ws_payload) < HEADER_SIZE:
95+
raise ProtocolError("WS payload too short to contain typed header")
96+
msg_type, length = decode_frame_header(ws_payload[:HEADER_SIZE])
97+
body = ws_payload[HEADER_SIZE:HEADER_SIZE + length]
98+
if len(body) != length:
99+
raise ProtocolError(
100+
f"declared length {length} but ws payload had {len(body)}"
101+
)
102+
return msg_type, body
103+
104+
def settimeout(self, timeout) -> None:
105+
self._sock.settimeout(timeout)
106+
107+
def close(self) -> None:
108+
try:
109+
ws_send_close(self._sock, mask=self._mask)
110+
except OSError:
111+
pass
112+
try:
113+
self._sock.shutdown(socket.SHUT_RDWR)
114+
except OSError:
115+
pass
116+
try:
117+
self._sock.close()
118+
except OSError:
119+
pass
120+
121+
@property
122+
def sock(self) -> socket.socket:
123+
return self._sock

0 commit comments

Comments
 (0)