Skip to content

Commit 4403537

Browse files
committed
Add persistent host ID handshake for Remote Desktop
Each host now exposes a stable 9-digit numeric ID — short enough to read aloud, persisted at ~/.je_auto_control/remote_host_id so it stays the same across restarts. The ID is announced inside AUTH_OK as JSON so only authenticated viewers see it. Viewers that pass expected_host_id raise AuthenticationError when the announced ID does not match, defending against TCP-level impersonation by a different process listening on the same address. The ID is *not* a substitute for the auth token — token-based HMAC gates the actual session; the ID is meant to be shared (token + ID together identify a host).
1 parent a0f62bd commit 4403537

6 files changed

Lines changed: 306 additions & 6 deletions

File tree

je_auto_control/utils/remote_desktop/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010
front-end.
1111
"""
1212
from je_auto_control.utils.remote_desktop.host import RemoteDesktopHost
13+
from je_auto_control.utils.remote_desktop.host_id import (
14+
HostIdError, format_host_id, generate_host_id, load_or_create_host_id,
15+
parse_host_id, validate_host_id,
16+
)
1317
from je_auto_control.utils.remote_desktop.input_dispatch import (
1418
InputDispatchError, dispatch_input,
1519
)
@@ -25,4 +29,6 @@
2529
"InputDispatchError", "AuthenticationError", "ProtocolError",
2630
"MessageType", "encode_frame", "decode_frame_header",
2731
"dispatch_input", "registry",
32+
"HostIdError", "format_host_id", "generate_host_id",
33+
"load_or_create_host_id", "parse_host_id", "validate_host_id",
2834
]

je_auto_control/utils/remote_desktop/host.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
from je_auto_control.utils.remote_desktop.auth import (
1111
NONCE_BYTES, make_nonce, verify_response,
1212
)
13+
from je_auto_control.utils.remote_desktop.host_id import (
14+
load_or_create_host_id, validate_host_id,
15+
)
1316
from je_auto_control.utils.remote_desktop.input_dispatch import (
1417
InputDispatchError, dispatch_input,
1518
)
@@ -102,7 +105,10 @@ def _authenticate(self) -> None:
102105
if not verify_response(self._host._token, nonce, payload):
103106
self._send(MessageType.AUTH_FAIL, b"bad token")
104107
raise AuthenticationError("bad token")
105-
self._send(MessageType.AUTH_OK, b"")
108+
ok_payload = json.dumps(
109+
{"host_id": self._host.host_id}, ensure_ascii=False,
110+
).encode("utf-8")
111+
self._send(MessageType.AUTH_OK, ok_payload)
106112
self._sock.settimeout(None)
107113

108114
def _send(self, message_type: MessageType, payload: bytes) -> None:
@@ -207,13 +213,16 @@ def __init__(self, token: str,
207213
max_clients: int = 4,
208214
frame_provider: Optional[FrameProvider] = None,
209215
input_dispatcher: Optional[InputDispatcher] = None,
216+
host_id: Optional[str] = None,
210217
) -> None:
211218
if not isinstance(token, str) or not token:
212219
raise ValueError("token must be a non-empty string")
213220
if fps <= 0:
214221
raise ValueError("fps must be positive")
215222
if not 1 <= int(quality) <= 95:
216223
raise ValueError("quality must be in [1, 95]")
224+
self._host_id = (validate_host_id(host_id) if host_id
225+
else load_or_create_host_id())
217226
self._token = token
218227
self._bind = bind
219228
self._requested_port = int(port)
@@ -236,6 +245,11 @@ def __init__(self, token: str,
236245

237246
# public API ----------------------------------------------------------
238247

248+
@property
249+
def host_id(self) -> str:
250+
"""The 9-digit numeric ID viewers use to verify this host."""
251+
return self._host_id
252+
239253
@property
240254
def port(self) -> int:
241255
return self._port
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
"""Stable, persistent host identifier exposed during the auth handshake.
2+
3+
Each host has a 9-digit numeric ID — short enough to read aloud, long
4+
enough to be hard to guess by chance. The ID is generated on first use
5+
and cached at ``~/.je_auto_control/remote_host_id`` so it stays the same
6+
across restarts; users hand the ID + token + address to the people they
7+
want to connect, and viewers can verify ``expected_host_id`` after auth
8+
to defend against TCP-level impersonation.
9+
10+
The ID is *not* a substitute for the auth token — it is broadcast in
11+
plain text inside ``AUTH_OK`` and is meant to be shared. Token-based
12+
HMAC auth gates the actual session.
13+
"""
14+
import os
15+
import re
16+
import secrets
17+
from pathlib import Path
18+
from typing import Optional
19+
20+
_HOST_ID_DIGITS = 9
21+
_DEFAULT_PATH_RELATIVE = ".je_auto_control/remote_host_id"
22+
_HOST_ID_PATTERN = re.compile(r"^\d{9}$")
23+
24+
25+
class HostIdError(ValueError):
26+
"""Raised when a host ID is malformed."""
27+
28+
29+
def generate_host_id() -> str:
30+
"""Return a fresh random 9-digit host ID (zero-padded)."""
31+
return f"{secrets.randbelow(10 ** _HOST_ID_DIGITS):0{_HOST_ID_DIGITS}d}"
32+
33+
34+
def validate_host_id(value: str) -> str:
35+
"""Return ``value`` unchanged if it is a valid 9-digit host ID."""
36+
if not isinstance(value, str) or _HOST_ID_PATTERN.fullmatch(value) is None:
37+
raise HostIdError(
38+
f"host_id must be {_HOST_ID_DIGITS} numeric digits, got {value!r}"
39+
)
40+
return value
41+
42+
43+
def format_host_id(value: str) -> str:
44+
"""Render a host ID with grouping for display (e.g. ``123 456 789``)."""
45+
digits = validate_host_id(value)
46+
return f"{digits[:3]} {digits[3:6]} {digits[6:]}"
47+
48+
49+
def parse_host_id(value: str) -> str:
50+
"""Strip whitespace / separators from user input and validate."""
51+
if not isinstance(value, str):
52+
raise HostIdError(f"host_id must be a string, got {type(value).__name__}")
53+
cleaned = re.sub(r"[\s\-_]", "", value)
54+
return validate_host_id(cleaned)
55+
56+
57+
def default_host_id_path() -> Path:
58+
"""Return the on-disk path used to persist the host ID."""
59+
home = Path(os.path.expanduser("~"))
60+
return home / _DEFAULT_PATH_RELATIVE
61+
62+
63+
def load_or_create_host_id(path: Optional[Path] = None) -> str:
64+
"""Return the persisted host ID, creating one on first call."""
65+
target = Path(path) if path is not None else default_host_id_path()
66+
if target.exists():
67+
try:
68+
existing = target.read_text(encoding="utf-8").strip()
69+
return validate_host_id(existing)
70+
except (OSError, HostIdError):
71+
# Corrupt / unreadable — regenerate rather than fail the host.
72+
pass
73+
new_id = generate_host_id()
74+
try:
75+
target.parent.mkdir(parents=True, exist_ok=True)
76+
target.write_text(new_id, encoding="utf-8")
77+
except OSError:
78+
# Persisting is best-effort; an in-memory ID still works for the
79+
# current process even if the home directory is read-only.
80+
pass
81+
return new_id

je_auto_control/utils/remote_desktop/registry.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,15 @@ def start_host(self, token: str,
3535
fps: float = 10.0,
3636
quality: int = 70,
3737
region: Optional[Sequence[int]] = None,
38-
max_clients: int = 4) -> Dict[str, Any]:
38+
max_clients: int = 4,
39+
host_id: Optional[str] = None) -> Dict[str, Any]:
3940
"""Stop any existing host, then start a fresh one with the given config."""
4041
self.stop_host()
4142
host = RemoteDesktopHost(
4243
token=token, bind=bind, port=int(port),
4344
fps=float(fps), quality=int(quality),
4445
region=region, max_clients=int(max_clients),
46+
host_id=host_id,
4547
)
4648
host.start()
4749
self._host = host
@@ -57,28 +59,35 @@ def stop_host(self, timeout: float = 2.0) -> Dict[str, Any]:
5759
def host_status(self) -> Dict[str, Any]:
5860
host = self._host
5961
if host is None:
60-
return {"running": False, "port": 0, "connected_clients": 0}
62+
return {
63+
"running": False, "port": 0, "connected_clients": 0,
64+
"host_id": None,
65+
}
6166
return {
6267
"running": host.is_running,
6368
"port": host.port,
6469
"connected_clients": host.connected_clients,
70+
"host_id": host.host_id,
6571
}
6672

6773
def connect_viewer(self, host: str, port: int, token: str,
6874
timeout: float = 5.0,
6975
on_frame: Optional[FrameCallback] = None,
7076
on_error: Optional[ErrorCallback] = None,
77+
expected_host_id: Optional[str] = None,
7178
) -> Dict[str, Any]:
7279
"""Disconnect any existing viewer, then connect a fresh one.
7380
7481
``on_frame`` and ``on_error`` are wired before the receiver
7582
thread starts, so no frame can arrive while the GUI is still
76-
attaching its callbacks.
83+
attaching its callbacks. When ``expected_host_id`` is provided
84+
the handshake is rejected if the server reports a different ID.
7785
"""
7886
self.disconnect_viewer()
7987
viewer = RemoteDesktopViewer(
8088
host=host, port=int(port), token=token,
8189
on_frame=on_frame, on_error=on_error,
90+
expected_host_id=expected_host_id,
8291
)
8392
viewer.connect(timeout=float(timeout))
8493
self._viewer = viewer
@@ -94,8 +103,11 @@ def disconnect_viewer(self, timeout: float = 2.0) -> Dict[str, Any]:
94103
def viewer_status(self) -> Dict[str, Any]:
95104
viewer = self._viewer
96105
if viewer is None:
97-
return {"connected": False}
98-
return {"connected": viewer.connected}
106+
return {"connected": False, "host_id": None}
107+
return {
108+
"connected": viewer.connected,
109+
"host_id": viewer.remote_host_id,
110+
}
99111

100112
def send_input(self, action: Dict[str, Any]) -> Dict[str, Any]:
101113
"""Forward ``action`` through the connected viewer, raise if offline."""

je_auto_control/utils/remote_desktop/viewer.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from je_auto_control.utils.logging.logging_instance import autocontrol_logger
88
from je_auto_control.utils.remote_desktop.auth import compute_response
9+
from je_auto_control.utils.remote_desktop.host_id import validate_host_id
910
from je_auto_control.utils.remote_desktop.protocol import (
1011
AuthenticationError, MessageType, ProtocolError,
1112
encode_frame, read_message,
@@ -18,6 +19,18 @@
1819
_DEFAULT_CONNECT_TIMEOUT_S = 5.0
1920

2021

22+
def _extract_host_id(payload: bytes) -> Optional[str]:
23+
"""Pull ``host_id`` out of an AUTH_OK payload (JSON or empty)."""
24+
if not payload:
25+
return None
26+
try:
27+
body = json.loads(payload.decode("utf-8"))
28+
except (UnicodeDecodeError, json.JSONDecodeError):
29+
return None
30+
value = body.get("host_id") if isinstance(body, dict) else None
31+
return value if isinstance(value, str) else None
32+
33+
2134
class RemoteDesktopViewer:
2235
"""Connect to a :class:`RemoteDesktopHost` and stream frames + input.
2336
@@ -29,6 +42,7 @@ class RemoteDesktopViewer:
2942
def __init__(self, host: str, port: int, token: str,
3043
on_frame: Optional[FrameCallback] = None,
3144
on_error: Optional[ErrorCallback] = None,
45+
expected_host_id: Optional[str] = None,
3246
) -> None:
3347
if not isinstance(host, str) or not host:
3448
raise ValueError("host must be a non-empty string")
@@ -39,6 +53,9 @@ def __init__(self, host: str, port: int, token: str,
3953
self._token = token
4054
self._on_frame = on_frame
4155
self._on_error = on_error
56+
self._expected_host_id = (validate_host_id(expected_host_id)
57+
if expected_host_id else None)
58+
self._remote_host_id: Optional[str] = None
4259
self._sock: Optional[socket.socket] = None
4360
self._send_lock = threading.Lock()
4461
self._shutdown = threading.Event()
@@ -49,6 +66,11 @@ def __init__(self, host: str, port: int, token: str,
4966
def connected(self) -> bool:
5067
return self._connected and not self._shutdown.is_set()
5168

69+
@property
70+
def remote_host_id(self) -> Optional[str]:
71+
"""The host ID announced in AUTH_OK; ``None`` until handshake completes."""
72+
return self._remote_host_id
73+
5274
def connect(self, timeout: float = _DEFAULT_CONNECT_TIMEOUT_S) -> None:
5375
"""Open the TCP connection and complete the auth handshake.
5476
@@ -138,6 +160,8 @@ def _handshake(self, sock: socket.socket) -> None:
138160
sock.sendall(encode_frame(MessageType.AUTH_RESPONSE, response))
139161
msg_type, payload = read_message(sock)
140162
if msg_type is MessageType.AUTH_OK:
163+
self._remote_host_id = _extract_host_id(payload)
164+
self._verify_host_id(self._remote_host_id)
141165
return
142166
if msg_type is MessageType.AUTH_FAIL:
143167
raise AuthenticationError(
@@ -147,6 +171,16 @@ def _handshake(self, sock: socket.socket) -> None:
147171
f"unexpected handshake reply {msg_type.name}"
148172
)
149173

174+
def _verify_host_id(self, announced: Optional[str]) -> None:
175+
"""Reject the connection when the server's ID does not match expectation."""
176+
if self._expected_host_id is None:
177+
return
178+
if announced != self._expected_host_id:
179+
raise AuthenticationError(
180+
f"host_id mismatch: expected {self._expected_host_id}, "
181+
f"got {announced!r}"
182+
)
183+
150184
def _recv_loop(self) -> None:
151185
sock = self._sock
152186
if sock is None:

0 commit comments

Comments
 (0)