|
| 1 | +"""Chunked file transfer over the typed-message channel. |
| 2 | +
|
| 3 | +Three message types form a transfer: |
| 4 | +
|
| 5 | +* ``FILE_BEGIN`` — JSON ``{transfer_id, dest_path, size}`` announces a new |
| 6 | + stream. ``transfer_id`` is a 36-character hyphenated UUID string so the |
| 7 | + receiver can demultiplex multiple in-flight transfers on one channel. |
| 8 | +* ``FILE_CHUNK`` — first 36 bytes are the ASCII transfer id, the rest is |
| 9 | + raw payload. Chunks arrive in order; the receiver writes them |
| 10 | + sequentially and accumulates ``bytes_done``. |
| 11 | +* ``FILE_END`` — JSON ``{transfer_id, status, error?}`` finalises the |
| 12 | + stream. The receiver closes the file and fires ``on_complete`` with |
| 13 | + success / failure info. |
| 14 | +
|
| 15 | +There is no central per-host file-size limit — operators relying on |
| 16 | +this should keep ``trusted token holders == trusted users`` in mind, and |
| 17 | +treat the dropbox / destination filesystem accordingly. |
| 18 | +""" |
| 19 | +import json |
| 20 | +import os |
| 21 | +import threading |
| 22 | +import uuid |
| 23 | +from dataclasses import dataclass, field |
| 24 | +from pathlib import Path |
| 25 | +from typing import Any, Callable, Dict, Optional, Tuple |
| 26 | + |
| 27 | +from je_auto_control.utils.logging.logging_instance import autocontrol_logger |
| 28 | +from je_auto_control.utils.remote_desktop.protocol import MessageType |
| 29 | + |
# Default number of payload bytes carried by one FILE_CHUNK message (256 KiB).
DEFAULT_CHUNK_SIZE = 1024 * 256
# Length of the ASCII transfer-id header, i.e. len(str(uuid.uuid4())).
TRANSFER_ID_LEN = 36

# Called as on_progress(transfer_id, bytes_done, total_size).
ProgressCallback = Callable[[str, int, int], None]
# Called as on_complete(transfer_id, ok, error_message, dest_path).
CompleteCallback = Callable[[str, bool, Optional[str], str], None]
| 35 | + |
| 36 | + |
class FileTransferError(RuntimeError):
    """Error raised for malformed file-transfer payloads or arguments."""
| 39 | + |
| 40 | + |
def new_transfer_id() -> str:
    """Generate a random UUID4 and return its 36-character string form."""
    fresh_uuid = uuid.uuid4()
    return str(fresh_uuid)
| 44 | + |
| 45 | + |
def encode_begin(transfer_id: str, dest_path: str, size: int) -> bytes:
    """Build the UTF-8 JSON payload of a ``FILE_BEGIN`` message.

    The JSON object carries ``transfer_id``, ``dest_path`` (coerced with
    ``str``) and ``size`` (coerced with ``int``).

    Raises:
        FileTransferError: if ``transfer_id`` is not exactly
            ``TRANSFER_ID_LEN`` characters long.
    """
    id_length_ok = len(transfer_id) == TRANSFER_ID_LEN
    if not id_length_ok:
        raise FileTransferError("transfer_id must be a 36-char UUID string")
    announcement = {
        "transfer_id": transfer_id,
        "dest_path": str(dest_path),
        "size": int(size),
    }
    # ensure_ascii=False keeps non-ASCII path characters as UTF-8 text.
    text = json.dumps(announcement, ensure_ascii=False)
    return text.encode("utf-8")
| 54 | + |
| 55 | + |
def decode_begin(payload: bytes) -> Tuple[str, str, int]:
    """Parse and validate the payload of a ``FILE_BEGIN`` message.

    Args:
        payload: UTF-8 encoded JSON object with ``transfer_id``,
            ``dest_path`` and ``size`` keys.

    Returns:
        ``(transfer_id, dest_path, size)``.

    Raises:
        FileTransferError: if the payload is not a JSON object, the
            transfer id is not a ``TRANSFER_ID_LEN``-character string,
            ``dest_path`` is not a non-empty string, or ``size`` is not a
            non-negative integer.
    """
    body = _decode_json(payload)
    transfer_id = body.get("transfer_id")
    dest_path = body.get("dest_path")
    size = body.get("size")
    if (not isinstance(transfer_id, str)
            or len(transfer_id) != TRANSFER_ID_LEN):
        raise FileTransferError("FILE_BEGIN missing valid transfer_id")
    if not isinstance(dest_path, str) or not dest_path:
        raise FileTransferError("FILE_BEGIN missing dest_path")
    # bool is a subclass of int, so JSON true/false would otherwise be
    # accepted here as a size of 1/0.
    if isinstance(size, bool) or not isinstance(size, int) or size < 0:
        raise FileTransferError("FILE_BEGIN missing valid size")
    return transfer_id, dest_path, size
| 69 | + |
| 70 | + |
def encode_chunk(transfer_id: str, chunk: bytes) -> bytes:
    """Build a ``FILE_CHUNK`` payload: ASCII transfer id, then raw data.

    Raises:
        FileTransferError: if ``transfer_id`` is not exactly
            ``TRANSFER_ID_LEN`` characters long.
    """
    if len(transfer_id) != TRANSFER_ID_LEN:
        raise FileTransferError("transfer_id must be a 36-char UUID string")
    header = transfer_id.encode("ascii")
    data = bytes(chunk)
    return header + data
| 75 | + |
| 76 | + |
def decode_chunk(payload: bytes) -> Tuple[str, bytes]:
    """Split a ``FILE_CHUNK`` payload into ``(transfer_id, data)``.

    The first ``TRANSFER_ID_LEN`` bytes are decoded as ASCII, with
    undecodable bytes replaced rather than rejected; the remainder is
    returned as ``bytes``.

    Raises:
        FileTransferError: if the payload is shorter than the id header.
    """
    header_len = TRANSFER_ID_LEN
    if len(payload) < header_len:
        raise FileTransferError("FILE_CHUNK shorter than transfer id header")
    header, data = payload[:header_len], payload[header_len:]
    return header.decode("ascii", errors="replace"), bytes(data)
| 82 | + |
| 83 | + |
def encode_end(transfer_id: str, status: str = "ok",
               error: Optional[str] = None) -> bytes:
    """Build the UTF-8 JSON payload of a ``FILE_END`` message.

    The ``error`` key is only present when ``error`` is not ``None``; its
    value is coerced with ``str``.

    Raises:
        FileTransferError: if ``transfer_id`` is not exactly
            ``TRANSFER_ID_LEN`` characters long.
    """
    if len(transfer_id) != TRANSFER_ID_LEN:
        raise FileTransferError("transfer_id must be a 36-char UUID string")
    summary: Dict[str, Any] = {"transfer_id": transfer_id, "status": status}
    if error is not None:
        summary.update(error=str(error))
    encoded_text = json.dumps(summary, ensure_ascii=False)
    return encoded_text.encode("utf-8")
| 92 | + |
| 93 | + |
def decode_end(payload: bytes) -> Tuple[str, str, Optional[str]]:
    """Parse and validate the payload of a ``FILE_END`` message.

    Returns:
        ``(transfer_id, status, error)``. ``status`` defaults to ``"ok"``
        when absent; ``error`` is ``None`` unless the payload carries a
        string under the ``error`` key.

    Raises:
        FileTransferError: if the payload is not a JSON object, the
            transfer id is invalid, or ``status`` is not a string.
    """
    fields = _decode_json(payload)
    transfer_id = fields.get("transfer_id")
    status = fields.get("status", "ok")
    id_valid = (isinstance(transfer_id, str)
                and len(transfer_id) == TRANSFER_ID_LEN)
    if not id_valid:
        raise FileTransferError("FILE_END missing valid transfer_id")
    if not isinstance(status, str):
        raise FileTransferError("FILE_END status must be a string")
    reason = fields.get("error")
    if not isinstance(reason, str):
        reason = None
    return transfer_id, status, reason
| 105 | + |
| 106 | + |
def _decode_json(payload: bytes) -> Dict[str, Any]:
    """Decode ``payload`` as UTF-8 JSON and require a top-level object.

    Raises:
        FileTransferError: if the bytes are not valid UTF-8, not valid
            JSON, or the decoded value is not a ``dict``.
    """
    try:
        text = payload.decode("utf-8")
        parsed = json.loads(text)
    except (UnicodeDecodeError, json.JSONDecodeError) as error:
        raise FileTransferError(f"invalid JSON: {error}") from error
    if isinstance(parsed, dict):
        return parsed
    raise FileTransferError("payload must be a JSON object")
| 115 | + |
| 116 | + |
@dataclass
class _Incoming:
    """Mutable bookkeeping for one inbound transfer.

    Instances are created and updated by ``FileReceiver``.
    """

    transfer_id: str  # id announced by FILE_BEGIN
    dest_path: Path  # file the payload is written to
    total_size: int  # size announced by the sender
    handle: Any  # open file object for dest_path
    bytes_done: int = 0  # payload bytes written so far
    error: Optional[str] = None  # message of a local I/O failure, if any
| 127 | + |
| 128 | + |
class FileReceiver:
    """Demultiplex incoming FILE_* messages into one or more file writes.

    Each ``handle_*`` method takes the raw payload of the matching message
    type. Per-transfer state lives in ``_active`` keyed by transfer id;
    access to that dict is guarded by ``_lock``.

    ``on_progress(transfer_id, bytes_done, total_size)`` is called once
    when a transfer starts and again after every chunk written.
    ``on_complete(transfer_id, ok, error, dest_path)`` is called when a
    transfer ends or fails; exceptions it raises are logged, not
    propagated.
    """

    def __init__(self, on_progress: Optional[ProgressCallback] = None,
                 on_complete: Optional[CompleteCallback] = None) -> None:
        self._on_progress = on_progress
        self._on_complete = on_complete
        self._active: Dict[str, _Incoming] = {}
        self._lock = threading.Lock()

    def handle_begin(self, payload: bytes) -> None:
        """Open the destination file for a newly announced transfer.

        Failure to create the parent directory or open the file is
        reported through ``on_complete`` with ``ok=False``.

        Raises:
            FileTransferError: if the payload is malformed.
        """
        transfer_id, dest_path, total_size = decode_begin(payload)
        # NOTE(security): dest_path is taken from the peer as-is (after
        # ``~`` expansion); no sandboxing is applied here.
        path = Path(os.path.expanduser(dest_path))
        # A repeated FILE_BEGIN for an id that is still active would
        # otherwise overwrite the entry and leak the earlier open handle.
        with self._lock:
            stale = self._active.pop(transfer_id, None)
        if stale is not None:
            autocontrol_logger.info(
                "remote_desktop FILE_BEGIN restarts active transfer %s",
                transfer_id,
            )
            self._close_quietly(stale)
        try:
            # mkdir is inside the try so a directory-creation failure is
            # reported like an open() failure instead of escaping.
            path.parent.mkdir(parents=True, exist_ok=True)
            handle = open(path, "wb")  # noqa: SIM115 managed manually
        except OSError as error:
            self._fire_complete(transfer_id, False, str(error), str(path))
            return
        with self._lock:
            self._active[transfer_id] = _Incoming(
                transfer_id=transfer_id, dest_path=path,
                total_size=total_size, handle=handle,
            )
        if self._on_progress is not None:
            self._on_progress(transfer_id, 0, total_size)

    def handle_chunk(self, payload: bytes) -> None:
        """Append one chunk to its transfer's file and report progress.

        Chunks for unknown transfer ids are logged and dropped. A write
        failure aborts the transfer and fires ``on_complete`` with
        ``ok=False``.

        Raises:
            FileTransferError: if the payload is shorter than the id
                header.
        """
        transfer_id, chunk = decode_chunk(payload)
        with self._lock:
            incoming = self._active.get(transfer_id)
        if incoming is None:
            autocontrol_logger.info(
                "remote_desktop FILE_CHUNK for unknown transfer %s",
                transfer_id,
            )
            return
        try:
            incoming.handle.write(chunk)
        except OSError as error:
            incoming.error = str(error)
            self._abort(incoming)
            return
        incoming.bytes_done += len(chunk)
        if self._on_progress is not None:
            self._on_progress(
                transfer_id, incoming.bytes_done, incoming.total_size,
            )

    def handle_end(self, payload: bytes) -> None:
        """Close the transfer's file and fire ``on_complete``.

        The transfer is reported as successful only when the sender's
        status is ``"ok"`` and no local error (including a failure while
        closing the file) was recorded. ``FILE_END`` for an unknown id is
        ignored.

        Raises:
            FileTransferError: if the payload is malformed.
        """
        transfer_id, status, error = decode_end(payload)
        with self._lock:
            incoming = self._active.pop(transfer_id, None)
        if incoming is None:
            return
        try:
            incoming.handle.close()
        except OSError as close_error:
            # close() flushes buffered data; if it fails the file on disk
            # may be incomplete, so this must not be reported as success.
            if incoming.error is None:
                incoming.error = str(close_error)
        ok = (status == "ok") and incoming.error is None
        message = error or incoming.error
        self._fire_complete(
            transfer_id, ok, message, str(incoming.dest_path),
        )

    def _abort(self, incoming: _Incoming) -> None:
        """Drop a failed transfer and report it through ``on_complete``."""
        self._close_quietly(incoming)
        with self._lock:
            self._active.pop(incoming.transfer_id, None)
        self._fire_complete(
            incoming.transfer_id, False, incoming.error,
            str(incoming.dest_path),
        )

    @staticmethod
    def _close_quietly(incoming: _Incoming) -> None:
        """Close the transfer's file handle, ignoring ``OSError``."""
        try:
            incoming.handle.close()
        except OSError:
            # Best-effort: the transfer is already being discarded.
            pass

    def _fire_complete(self, transfer_id: str, ok: bool,
                       error: Optional[str], dest_path: str) -> None:
        """Invoke ``on_complete`` if set, logging any exception it raises."""
        if self._on_complete is None:
            return
        try:
            self._on_complete(transfer_id, ok, error, dest_path)
        except Exception:  # noqa: BLE001
            autocontrol_logger.exception(
                "remote_desktop FileReceiver.on_complete callback raised"
            )
| 216 | + |
| 217 | + |
@dataclass
class FileSendResult:
    """Summary returned by ``send_file`` for one outbound transfer."""

    transfer_id: str  # id used for this transfer
    success: bool  # True when the whole file was sent without an I/O error
    error: Optional[str] = None  # failure message, if any
    bytes_sent: int = 0  # payload bytes handed to the channel
| 226 | + |
| 227 | + |
def send_file(channel, source_path: str, dest_path: str,
              on_progress: Optional[ProgressCallback] = None,
              chunk_size: int = DEFAULT_CHUNK_SIZE,
              transfer_id: Optional[str] = None) -> FileSendResult:
    """Stream ``source_path`` to ``dest_path`` over ``channel``.

    Synchronous: the caller's thread does the I/O. Wrap in a thread for
    background uploads. ``on_progress(transfer_id, bytes_done, total)``
    fires after every chunk (and once at the start with ``bytes_done=0``).

    Args:
        channel: object exposing ``send_typed(message_type, payload)``.
        source_path: local file to read; ``~`` is expanded.
        dest_path: path announced to the receiver in ``FILE_BEGIN``.
        on_progress: optional progress callback.
        chunk_size: maximum payload bytes per ``FILE_CHUNK``; must be at
            least 1.
        transfer_id: id to use; a fresh one is generated when omitted.

    Returns:
        A ``FileSendResult``. If an ``OSError`` (which includes
        ``ConnectionError``) occurs while reading or sending chunks, an
        error ``FILE_END`` is attempted and a result with
        ``success=False`` is returned.

    Raises:
        FileTransferError: if ``chunk_size`` is below 1, the source is
            not a regular file, or ``transfer_id`` has the wrong length.

    Errors raised while sending ``FILE_BEGIN`` or the final ``FILE_END``
    propagate to the caller.
    """
    read_size = int(chunk_size)
    if read_size < 1:
        # read(0) returns b"" immediately, which would end the loop at
        # once and report an empty transfer as successful.
        raise FileTransferError("chunk_size must be a positive integer")
    transfer_id = transfer_id or new_transfer_id()
    source = Path(os.path.expanduser(source_path))
    if not source.is_file():
        raise FileTransferError(f"source not found: {source}")
    total_size = source.stat().st_size
    channel.send_typed(MessageType.FILE_BEGIN,
                       encode_begin(transfer_id, dest_path, total_size))
    if on_progress is not None:
        on_progress(transfer_id, 0, total_size)
    bytes_sent = 0
    try:
        with open(source, "rb") as handle:
            # iter() with a sentinel stops at the first empty read (EOF).
            for chunk in iter(lambda: handle.read(read_size), b""):
                channel.send_typed(
                    MessageType.FILE_CHUNK, encode_chunk(transfer_id, chunk),
                )
                bytes_sent += len(chunk)
                if on_progress is not None:
                    on_progress(transfer_id, bytes_sent, total_size)
    except (OSError, ConnectionError) as error:
        # Telling the receiver is best-effort: if the channel itself is
        # what failed, this send will likely fail too and must not mask
        # the original error or prevent returning the result.
        try:
            channel.send_typed(
                MessageType.FILE_END,
                encode_end(transfer_id, status="error", error=str(error)),
            )
        except (OSError, ConnectionError):
            autocontrol_logger.info(
                "remote_desktop could not report failure of transfer %s",
                transfer_id,
            )
        return FileSendResult(transfer_id=transfer_id, success=False,
                              error=str(error), bytes_sent=bytes_sent)
    channel.send_typed(MessageType.FILE_END, encode_end(transfer_id))
    return FileSendResult(transfer_id=transfer_id, success=True,
                          bytes_sent=bytes_sent)
0 commit comments