Skip to content

Commit 7cb8e33

Browse files
committed
Add TLS transport for Remote Desktop host and viewer
RemoteDesktopHost and RemoteDesktopViewer now accept an ssl.SSLContext; when provided, the host wraps each accepted connection server-side and the viewer wraps the connect socket client-side. Failed handshakes on the host are logged and the raw socket is closed before the client handler is registered, so a TLS-only host can be hit by plain TCP viewers without leaking entries into the connected_clients counter. Tests use a self-signed loopback certificate generated with cryptography to cover: full TLS round-trip with both a trusting and an insecure client context, plain viewer rejected against a TLS host, TLS-only viewer rejected against a plain host, and confirmation that the wrapped socket is an SSLSocket after connect.
1 parent 4403537 commit 7cb8e33

4 files changed

Lines changed: 252 additions & 8 deletions

File tree

je_auto_control/utils/remote_desktop/host.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""TCP host that streams JPEG frames and applies viewer input."""
22
import json
33
import socket
4+
import ssl
45
import threading
56
import time
67
from io import BytesIO
@@ -214,6 +215,7 @@ def __init__(self, token: str,
214215
frame_provider: Optional[FrameProvider] = None,
215216
input_dispatcher: Optional[InputDispatcher] = None,
216217
host_id: Optional[str] = None,
218+
ssl_context: Optional[ssl.SSLContext] = None,
217219
) -> None:
218220
if not isinstance(token, str) or not token:
219221
raise ValueError("token must be a non-empty string")
@@ -224,6 +226,7 @@ def __init__(self, token: str,
224226
self._host_id = (validate_host_id(host_id) if host_id
225227
else load_or_create_host_id())
226228
self._token = token
229+
self._ssl_context = ssl_context
227230
self._bind = bind
228231
self._requested_port = int(port)
229232
self._period = 1.0 / float(fps)
@@ -332,7 +335,10 @@ def _accept_loop(self) -> None:
332335
continue
333336
except OSError:
334337
return
335-
handler = _ClientHandler(self, client_sock, address)
338+
wrapped = self._maybe_wrap_tls(client_sock, address)
339+
if wrapped is None:
340+
continue
341+
handler = _ClientHandler(self, wrapped, address)
336342
with self._clients_lock:
337343
if len(self._clients) >= self._max_clients:
338344
autocontrol_logger.info(
@@ -345,6 +351,29 @@ def _accept_loop(self) -> None:
345351
handler.start()
346352
self._reap_dead_clients()
347353

354+
def _maybe_wrap_tls(self, client_sock: socket.socket,
355+
address) -> Optional[socket.socket]:
356+
"""Return a TLS-wrapped socket when an ssl_context is configured."""
357+
if self._ssl_context is None:
358+
return client_sock
359+
try:
360+
client_sock.settimeout(_AUTH_TIMEOUT_S)
361+
wrapped = self._ssl_context.wrap_socket(
362+
client_sock, server_side=True,
363+
)
364+
wrapped.settimeout(None)
365+
return wrapped
366+
except (ssl.SSLError, OSError) as error:
367+
autocontrol_logger.info(
368+
"remote_desktop TLS handshake from %s failed: %r",
369+
address, error,
370+
)
371+
try:
372+
client_sock.close()
373+
except OSError:
374+
pass
375+
return None
376+
348377
def _capture_loop(self) -> None:
349378
next_tick = time.monotonic()
350379
while not self._shutdown.is_set():

je_auto_control/utils/remote_desktop/registry.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
references here keeps :mod:`action_executor` thin and avoids circular
66
imports between the executor and the host/viewer classes.
77
"""
8+
import ssl
89
from typing import Any, Callable, Dict, Optional, Sequence
910

1011
from je_auto_control.utils.remote_desktop.host import RemoteDesktopHost
@@ -36,14 +37,16 @@ def start_host(self, token: str,
3637
quality: int = 70,
3738
region: Optional[Sequence[int]] = None,
3839
max_clients: int = 4,
39-
host_id: Optional[str] = None) -> Dict[str, Any]:
40+
host_id: Optional[str] = None,
41+
ssl_context: Optional[ssl.SSLContext] = None,
42+
) -> Dict[str, Any]:
4043
"""Stop any existing host, then start a fresh one with the given config."""
4144
self.stop_host()
4245
host = RemoteDesktopHost(
4346
token=token, bind=bind, port=int(port),
4447
fps=float(fps), quality=int(quality),
4548
region=region, max_clients=int(max_clients),
46-
host_id=host_id,
49+
host_id=host_id, ssl_context=ssl_context,
4750
)
4851
host.start()
4952
self._host = host
@@ -75,19 +78,24 @@ def connect_viewer(self, host: str, port: int, token: str,
7578
on_frame: Optional[FrameCallback] = None,
7679
on_error: Optional[ErrorCallback] = None,
7780
expected_host_id: Optional[str] = None,
81+
ssl_context: Optional[ssl.SSLContext] = None,
82+
server_hostname: Optional[str] = None,
7883
) -> Dict[str, Any]:
7984
"""Disconnect any existing viewer, then connect a fresh one.
8085
8186
``on_frame`` and ``on_error`` are wired before the receiver
8287
thread starts, so no frame can arrive while the GUI is still
8388
attaching its callbacks. When ``expected_host_id`` is provided
8489
the handshake is rejected if the server reports a different ID.
90+
Pass an ``ssl_context`` to upgrade the connection to TLS.
8591
"""
8692
self.disconnect_viewer()
8793
viewer = RemoteDesktopViewer(
8894
host=host, port=int(port), token=token,
8995
on_frame=on_frame, on_error=on_error,
9096
expected_host_id=expected_host_id,
97+
ssl_context=ssl_context,
98+
server_hostname=server_hostname,
9199
)
92100
viewer.connect(timeout=float(timeout))
93101
self._viewer = viewer

je_auto_control/utils/remote_desktop/viewer.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""TCP viewer that receives JPEG frames and forwards input messages."""
22
import json
33
import socket
4+
import ssl
45
import threading
56
from typing import Any, Callable, Mapping, Optional
67

@@ -43,6 +44,8 @@ def __init__(self, host: str, port: int, token: str,
4344
on_frame: Optional[FrameCallback] = None,
4445
on_error: Optional[ErrorCallback] = None,
4546
expected_host_id: Optional[str] = None,
47+
ssl_context: Optional[ssl.SSLContext] = None,
48+
server_hostname: Optional[str] = None,
4649
) -> None:
4750
if not isinstance(host, str) or not host:
4851
raise ValueError("host must be a non-empty string")
@@ -56,6 +59,8 @@ def __init__(self, host: str, port: int, token: str,
5659
self._expected_host_id = (validate_host_id(expected_host_id)
5760
if expected_host_id else None)
5861
self._remote_host_id: Optional[str] = None
62+
self._ssl_context = ssl_context
63+
self._server_hostname = server_hostname
5964
self._sock: Optional[socket.socket] = None
6065
self._send_lock = threading.Lock()
6166
self._shutdown = threading.Event()
@@ -72,22 +77,23 @@ def remote_host_id(self) -> Optional[str]:
7277
return self._remote_host_id
7378

7479
def connect(self, timeout: float = _DEFAULT_CONNECT_TIMEOUT_S) -> None:
75-
"""Open the TCP connection and complete the auth handshake.
80+
"""Open the (optionally TLS) connection and complete the auth handshake.
7681
7782
Spawns a receiver thread on success. Raises
7883
:class:`AuthenticationError` if the handshake fails.
7984
"""
8085
if self._connected:
8186
return
82-
sock = socket.create_connection(
87+
raw_sock = socket.create_connection(
8388
(self._host, self._port), timeout=timeout,
8489
)
85-
sock.settimeout(_DEFAULT_AUTH_TIMEOUT_S)
90+
raw_sock.settimeout(_DEFAULT_AUTH_TIMEOUT_S)
8691
try:
92+
sock = self._maybe_wrap_tls(raw_sock)
8793
self._handshake(sock)
88-
except (AuthenticationError, ProtocolError, OSError):
94+
except (AuthenticationError, ProtocolError, OSError, ssl.SSLError):
8995
try:
90-
sock.close()
96+
raw_sock.close()
9197
except OSError:
9298
pass
9399
raise
@@ -100,6 +106,19 @@ def connect(self, timeout: float = _DEFAULT_CONNECT_TIMEOUT_S) -> None:
100106
)
101107
self._receiver.start()
102108

109+
def _maybe_wrap_tls(self, raw_sock: socket.socket) -> socket.socket:
110+
"""Return a TLS-wrapped socket when an ssl_context was configured."""
111+
if self._ssl_context is None:
112+
return raw_sock
113+
hostname = self._server_hostname or self._host
114+
if (self._ssl_context.check_hostname is False
115+
and self._ssl_context.verify_mode == ssl.CERT_NONE):
116+
# ``wrap_socket`` rejects server_hostname when verification is off.
117+
hostname = None
118+
return self._ssl_context.wrap_socket(
119+
raw_sock, server_hostname=hostname,
120+
)
121+
103122
def disconnect(self, timeout: float = 2.0) -> None:
104123
"""Close the connection and join the receiver thread."""
105124
self._shutdown.set()
Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
"""End-to-end TLS tests using a self-signed loopback certificate."""
2+
import datetime
3+
import ipaddress
4+
import socket
5+
import ssl
6+
import time
7+
from pathlib import Path
8+
from typing import Tuple
9+
10+
import pytest
11+
12+
cryptography = pytest.importorskip("cryptography")
13+
14+
from cryptography import x509 # noqa: E402
15+
from cryptography.hazmat.primitives import hashes, serialization # noqa: E402
16+
from cryptography.hazmat.primitives.asymmetric import rsa # noqa: E402
17+
from cryptography.x509.oid import NameOID # noqa: E402
18+
19+
from je_auto_control.utils.remote_desktop import (
20+
RemoteDesktopHost, RemoteDesktopViewer,
21+
)
22+
from je_auto_control.utils.remote_desktop.protocol import AuthenticationError
23+
24+
25+
def _wait_until(predicate, timeout: float = 2.0,
26+
interval: float = 0.02) -> bool:
27+
deadline = time.monotonic() + timeout
28+
while time.monotonic() < deadline:
29+
if predicate():
30+
return True
31+
time.sleep(interval)
32+
return predicate()
33+
34+
35+
def _generate_self_signed(tmp_path: Path) -> Tuple[Path, Path]:
36+
"""Write a self-signed cert + key for ``127.0.0.1`` to ``tmp_path``."""
37+
key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
38+
name = x509.Name([
39+
x509.NameAttribute(NameOID.COMMON_NAME, "remote-desktop-test"),
40+
])
41+
now = datetime.datetime.now(datetime.timezone.utc)
42+
cert = (
43+
x509.CertificateBuilder()
44+
.subject_name(name)
45+
.issuer_name(name)
46+
.public_key(key.public_key())
47+
.serial_number(x509.random_serial_number())
48+
.not_valid_before(now - datetime.timedelta(minutes=1))
49+
.not_valid_after(now + datetime.timedelta(days=1))
50+
.add_extension(
51+
x509.SubjectAlternativeName([
52+
x509.IPAddress(ipaddress.IPv4Address("127.0.0.1")),
53+
x509.DNSName("localhost"),
54+
]),
55+
critical=False,
56+
)
57+
.sign(private_key=key, algorithm=hashes.SHA256())
58+
)
59+
cert_path = tmp_path / "cert.pem"
60+
key_path = tmp_path / "key.pem"
61+
cert_path.write_bytes(
62+
cert.public_bytes(serialization.Encoding.PEM)
63+
)
64+
key_path.write_bytes(
65+
key.private_bytes(
66+
encoding=serialization.Encoding.PEM,
67+
format=serialization.PrivateFormat.TraditionalOpenSSL,
68+
encryption_algorithm=serialization.NoEncryption(),
69+
)
70+
)
71+
return cert_path, key_path
72+
73+
74+
def _server_context(cert_path: Path, key_path: Path) -> ssl.SSLContext:
75+
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
76+
ctx.load_cert_chain(certfile=str(cert_path), keyfile=str(key_path))
77+
return ctx
78+
79+
80+
def _trusting_client_context(ca_path: Path) -> ssl.SSLContext:
81+
"""Verifying client context that trusts only the supplied test CA cert."""
82+
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
83+
ctx.load_verify_locations(cafile=str(ca_path))
84+
ctx.check_hostname = True
85+
ctx.verify_mode = ssl.CERT_REQUIRED
86+
return ctx
87+
88+
89+
def _insecure_client_context() -> ssl.SSLContext:
90+
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
91+
ctx.check_hostname = False
92+
ctx.verify_mode = ssl.CERT_NONE
93+
return ctx
94+
95+
96+
def _start_tls_host(tmp_path: Path) -> Tuple[RemoteDesktopHost, Path, Path]:
97+
cert_path, key_path = _generate_self_signed(tmp_path)
98+
server_ctx = _server_context(cert_path, key_path)
99+
host = RemoteDesktopHost(
100+
token="tok", bind="127.0.0.1", port=0, fps=50.0,
101+
frame_provider=lambda: b"tls-frame",
102+
input_dispatcher=lambda *_a, **_k: None,
103+
host_id="111111111", ssl_context=server_ctx,
104+
)
105+
host.start()
106+
return host, cert_path, key_path
107+
108+
109+
def test_tls_round_trip_with_trusting_client(tmp_path):
110+
host, cert_path, _ = _start_tls_host(tmp_path)
111+
try:
112+
client_ctx = _trusting_client_context(cert_path)
113+
received = []
114+
viewer = RemoteDesktopViewer(
115+
host="127.0.0.1", port=host.port, token="tok",
116+
on_frame=received.append,
117+
ssl_context=client_ctx,
118+
)
119+
viewer.connect(timeout=2.0)
120+
assert _wait_until(lambda: len(received) >= 1, timeout=2.0)
121+
assert all(frame == b"tls-frame" for frame in received)
122+
viewer.disconnect()
123+
finally:
124+
host.stop(timeout=1.0)
125+
126+
127+
def test_tls_round_trip_with_insecure_client(tmp_path):
128+
host, _, _ = _start_tls_host(tmp_path)
129+
try:
130+
viewer = RemoteDesktopViewer(
131+
host="127.0.0.1", port=host.port, token="tok",
132+
ssl_context=_insecure_client_context(),
133+
)
134+
viewer.connect(timeout=2.0)
135+
assert viewer.connected
136+
viewer.disconnect()
137+
finally:
138+
host.stop(timeout=1.0)
139+
140+
141+
def test_plain_viewer_against_tls_host_fails(tmp_path):
142+
"""A non-TLS viewer cannot finish the handshake against a TLS host."""
143+
host, _, _ = _start_tls_host(tmp_path)
144+
try:
145+
viewer = RemoteDesktopViewer(
146+
host="127.0.0.1", port=host.port, token="tok",
147+
)
148+
with pytest.raises((OSError, AuthenticationError)):
149+
viewer.connect(timeout=2.0)
150+
# Host should refuse to count an incomplete handshake as connected.
151+
assert _wait_until(lambda: host.connected_clients == 0, timeout=2.0)
152+
finally:
153+
host.stop(timeout=1.0)
154+
155+
156+
def test_tls_client_against_plain_host_fails():
157+
"""A TLS-only viewer cannot speak to a plain TCP host."""
158+
host = RemoteDesktopHost(
159+
token="tok", bind="127.0.0.1", port=0, fps=50.0,
160+
frame_provider=lambda: b"plain",
161+
input_dispatcher=lambda *_a, **_k: None,
162+
host_id="222222222",
163+
)
164+
host.start()
165+
try:
166+
viewer = RemoteDesktopViewer(
167+
host="127.0.0.1", port=host.port, token="tok",
168+
ssl_context=_insecure_client_context(),
169+
)
170+
with pytest.raises((OSError, ssl.SSLError, AuthenticationError)):
171+
viewer.connect(timeout=2.0)
172+
finally:
173+
host.stop(timeout=1.0)
174+
175+
176+
def test_tls_uses_socket_class_after_wrap(tmp_path):
177+
"""After connect, the viewer's socket should be an SSLSocket."""
178+
host, cert_path, _ = _start_tls_host(tmp_path)
179+
try:
180+
viewer = RemoteDesktopViewer(
181+
host="127.0.0.1", port=host.port, token="tok",
182+
ssl_context=_trusting_client_context(cert_path),
183+
)
184+
viewer.connect(timeout=2.0)
185+
assert isinstance(viewer._sock, ssl.SSLSocket) # noqa: SLF001
186+
viewer.disconnect()
187+
finally:
188+
host.stop(timeout=1.0)

0 commit comments

Comments
 (0)