|
6 | 6 | import os |
7 | 7 | import re |
8 | 8 | import stat |
9 | | -import tempfile |
10 | 9 | import uuid |
11 | 10 | from collections.abc import Awaitable, Callable, Mapping |
12 | 11 | from pathlib import Path |
13 | | -from typing import TYPE_CHECKING, Literal, cast |
| 12 | +from typing import TYPE_CHECKING, Literal |
14 | 13 |
|
15 | 14 | from pydantic import Field, field_serializer, field_validator |
16 | 15 |
|
|
31 | 30 | _COMMIT_REF_RE = re.compile(r"[0-9a-fA-F]{7,40}") |
32 | 31 | _OPEN_SUPPORTS_DIR_FD = os.open in os.supports_dir_fd |
33 | 32 | _HAS_O_DIRECTORY = hasattr(os, "O_DIRECTORY") |
34 | | -_LOCAL_COPY_BUFFER_MAX_MEMORY_BYTES = 8 * 1024 * 1024 |
35 | | - |
36 | | - |
37 | | -def _buffer_handle_with_checksum(handle: io.BufferedReader) -> tuple[io.IOBase, str]: |
38 | | - digest = hashlib.sha256() |
39 | | - buffer = tempfile.SpooledTemporaryFile( |
40 | | - max_size=_LOCAL_COPY_BUFFER_MAX_MEMORY_BYTES, |
41 | | - mode="w+b", |
42 | | - ) |
43 | | - try: |
44 | | - while True: |
45 | | - chunk = handle.read(1024 * 1024) |
46 | | - if not chunk: |
47 | | - break |
48 | | - digest.update(chunk) |
49 | | - buffer.write(chunk) |
50 | | - buffer.seek(0) |
51 | | - return cast(io.IOBase, buffer), digest.hexdigest() |
52 | | - except BaseException: |
53 | | - buffer.close() |
54 | | - raise |
| 33 | + |
| 34 | + |
| 35 | +class _HashingReader(io.IOBase): |
| 36 | + def __init__(self, stream: io.BufferedReader) -> None: |
| 37 | + self._stream = stream |
| 38 | + self._digest = hashlib.sha256() |
| 39 | + self._started = False |
| 40 | + self._finished = False |
| 41 | + |
| 42 | + def readable(self) -> bool: |
| 43 | + return True |
| 44 | + |
| 45 | + def read(self, size: int = -1) -> bytes: |
| 46 | + chunk = self._stream.read(size) |
| 47 | + if chunk is None: |
| 48 | + self._finished = True |
| 49 | + return b"" |
| 50 | + if isinstance(chunk, bytearray): |
| 51 | + chunk = bytes(chunk) |
| 52 | + self._started = True |
| 53 | + if not chunk: |
| 54 | + self._finished = True |
| 55 | + return b"" |
| 56 | + self._digest.update(chunk) |
| 57 | + if size < 0 or len(chunk) < size: |
| 58 | + self._finished = True |
| 59 | + return chunk |
| 60 | + |
| 61 | + def readinto(self, b: bytearray) -> int: |
| 62 | + data = self.read(len(b)) |
| 63 | + n = len(data) |
| 64 | + b[:n] = data |
| 65 | + return n |
| 66 | + |
| 67 | + def seek(self, offset: int, whence: int = io.SEEK_SET) -> int: |
| 68 | + if self._started: |
| 69 | + raise io.UnsupportedOperation("cannot seek after reads begin") |
| 70 | + return int(self._stream.seek(offset, whence)) |
| 71 | + |
| 72 | + def tell(self) -> int: |
| 73 | + return int(self._stream.tell()) |
| 74 | + |
| 75 | + def hexdigest(self) -> str: |
| 76 | + if not self._finished: |
| 77 | + raise RuntimeError("checksum is not available until the stream is fully consumed") |
| 78 | + return self._digest.hexdigest() |
55 | 79 |
|
56 | 80 |
|
57 | 81 | class Dir(BaseEntry): |
@@ -120,21 +144,15 @@ async def apply( |
120 | 144 | base_dir: Path, |
121 | 145 | ) -> list[MaterializedFile]: |
122 | 146 | src = (base_dir / self.src).resolve() |
123 | | - try: |
124 | | - with src.open("rb") as f: |
125 | | - buffered, checksum = _buffer_handle_with_checksum(f) |
126 | | - except OSError as e: |
127 | | - raise LocalFileReadError(src=src, cause=e) from e |
128 | 147 | await session.mkdir(Path(dest).parent, parents=True) |
129 | 148 | try: |
130 | | - try: |
131 | | - await session.write(dest, buffered) |
132 | | - finally: |
133 | | - buffered.close() |
| 149 | + with src.open("rb") as f: |
| 150 | + hashing_reader = _HashingReader(f) |
| 151 | + await session.write(dest, hashing_reader) |
134 | 152 | except OSError as e: |
135 | 153 | raise LocalFileReadError(src=src, cause=e) from e |
136 | 154 | await self._apply_metadata(session, dest) |
137 | | - return [MaterializedFile(path=dest, sha256=checksum)] |
| 155 | + return [MaterializedFile(path=dest, sha256=hashing_reader.hexdigest())] |
138 | 156 |
|
139 | 157 |
|
140 | 158 | class LocalDir(BaseEntry): |
@@ -362,18 +380,15 @@ async def _copy_local_dir_file( |
362 | 380 | ) |
363 | 381 | with os.fdopen(fd, "rb") as f: |
364 | 382 | fd = None |
365 | | - buffered, checksum = _buffer_handle_with_checksum(f) |
| 383 | + hashing_reader = _HashingReader(f) |
366 | 384 | await session.mkdir(child_dest.parent, parents=True, user=user) |
367 | | - try: |
368 | | - await session.write(child_dest, buffered, user=user) |
369 | | - finally: |
370 | | - buffered.close() |
| 385 | + await session.write(child_dest, hashing_reader, user=user) |
371 | 386 | except OSError as e: |
372 | 387 | raise LocalFileReadError(src=src, cause=e) from e |
373 | 388 | finally: |
374 | 389 | if fd is not None: |
375 | 390 | os.close(fd) |
376 | | - return MaterializedFile(path=child_dest, sha256=checksum) |
| 391 | + return MaterializedFile(path=child_dest, sha256=hashing_reader.hexdigest()) |
377 | 392 |
|
378 | 393 | def _open_local_dir_file_for_copy( |
379 | 394 | self, *, base_dir: Path, src_root: Path, rel_child: Path |
|
0 commit comments