diff --git a/run_benchmarks.py b/run_benchmarks.py index 9c58a99..2801791 100644 --- a/run_benchmarks.py +++ b/run_benchmarks.py @@ -85,11 +85,14 @@ def write_shapefile_with_PyShp(target: str | PathLike): for file_path in SHAPEFILES.values(): file_path.read_bytes() +COLS_WIDTHS = (22, 10) + reader_benchmarks = [ functools.partial( benchmark, name=f"Read {test_name}", func=functools.partial(open_shapefile_with_PyShp, target=target), + col_widths=COLS_WIDTHS, ) for test_name, target in SHAPEFILES.items() ] @@ -101,13 +104,17 @@ def write_shapefile_with_PyShp(target: str | PathLike): benchmark, name=f"Write {test_name}", func=functools.partial(write_shapefile_with_PyShp, target=target), + col_widths=COLS_WIDTHS, ) for test_name, target in SHAPEFILES.items() ] -def run(run_count: int, benchmarks: list[Callable[[], None]]) -> None: - col_widths = (22, 10) +def run( + run_count: int, + benchmarks: list[Callable[[], None]], + col_widths: tuple[int, int] = COLS_WIDTHS, +) -> None: col_head = ("parser", "exec time", "performance (more is better)") print(f"Running benchmarks {run_count} times:") print("-" * col_widths[0] + "---" + "-" * col_widths[1]) @@ -116,7 +123,6 @@ def run(run_count: int, benchmarks: list[Callable[[], None]]) -> None: for benchmark in benchmarks: benchmark( # type: ignore [call-arg] run_count=run_count, - col_widths=col_widths, ) diff --git a/src/shapefile.py b/src/shapefile.py index c4b4f74..24dd409 100644 --- a/src/shapefile.py +++ b/src/shapefile.py @@ -10,6 +10,7 @@ __version__ = "3.0.9.dev" +import abc import array import doctest import functools @@ -42,6 +43,7 @@ Union, cast, overload, + runtime_checkable, ) from urllib.error import HTTPError from urllib.parse import ParseResult, urlparse, urlunparse @@ -139,12 +141,14 @@ class ReadableBinStream(Protocol): def read(self, size: int = -1) -> bytes: ... +@runtime_checkable class WriteSeekableBinStream(Protocol): def write(self, b: bytes) -> int: ... def seek(self, offset: int, whence: int = 0) -> int: ... def tell(self) -> int: ... +@runtime_checkable class ReadSeekableBinStream(Protocol): def seek(self, offset: int, whence: int = 0) -> int: ... def tell(self) -> int: ... @@ -2354,43 +2358,184 @@ def ensure_within_bounds(i: int, num_records: int) -> int: return i -class DbfReader: - """Reads a dbf file. You can instantiate a DbfReader without specifying a shapefile - and then specify one later with the load() method. - """ +class ShapefileException(Exception): + """An exception to handle shapefile specific problems.""" + + +class dbfFileException(ShapefileException): + """Indicates a problem with the .dbf file.""" + + +# Use ExitStack to Support not closing opened file objects passed in e.g.(handled by some +# external context manager, or the caller manually calling .close). +# +# This will only ever hold at most one context manager. +# But an ExitStack is the right tool for the job +# when the number of context manager(s) depends on user input. +class _HasExitStack(AbstractContextManager["_HasExitStack", None]): + def __init__(self) -> None: + self.exit_stack = ExitStack() + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + self.close() + + def close(self) -> None: + self.exit_stack.close() + + +FileProtoT = TypeVar("FileProtoT") + + +class _FileChecker(_HasExitStack, Generic[FileProtoT]): + @property + @abc.abstractmethod + def FileProto(self) -> type[FileProtoT]: ... + + @property + @abc.abstractmethod + def new_file_mode(self) -> Literal["rb", "w+b"]: ... + + @property + @abc.abstractmethod + def ext(self) -> Literal[".shp", ".shx", ".dbf"]: ... + + ExceptionClass = ShapefileException def __init__( self, + file: str | PathLike[Any] | FileProtoT, + encoding: str = "utf-8", + encodingErrors: str = "strict", + ): + super().__init__() + + file = fsdecode_if_pathlike(file) + self._file: str | FileProtoT + if isinstance(file, str): + self._file = f"{os.path.splitext(file)[0]}{self.ext}" + elif file: + self._file = file + else: + raise TypeError( + f"file must be set to a str, Path or file-like object. Got: {file}" + ) + + # Encoding + self.encoding = encoding + self.encodingErrors = encodingErrors + + @functools.cached_property + def file(self) -> FileProtoT: + return self._ensure_file_obj() + # f=self._file, + # FileProto=self.FileProto, + # exit_stack=self.exit_stack, + # new_file_mode="rb", + # ExceptionClass=dbfFileException, + # ) + + def _ensure_file_obj( + self, + f: str | FileProtoT | None = None, + # FileProto: type[FileProtoT], + # exit_stack: ExitStack, + # new_file_mode: Literal["rb", "w+b"] = "w+b", + # ExceptionClass: type[ShapefileException] = ShapefileException, + ) -> FileProtoT: + """Safety handler to verify file-like objects""" + + f = f or self._file + exit_stack = self.exit_stack + FileProto = self.FileProto + new_file_mode = self.new_file_mode + ExceptionClass = self.ExceptionClass + + if not f: + raise ExceptionClass(f"No file-like object received. Got: {f}") + if isinstance(f, str): + pth = os.path.split(f)[0] + if pth and not os.path.exists(pth): + os.makedirs(pth) + fp = open(f, new_file_mode) + + # Only push files created here to the exit stack. + # The user must close their own file objects. + exit_stack.enter_context(fp) + return cast(FileProtoT, fp) + + # See Minor hack below. + if isinstance(f, FileProto): + return f + + # Ugly, but perhaps needed to avoid weird Python 3.14 + # specific bug in run_benchmarks.py on Windows + # if ( + # (attrs := getattr(FileProto, "__protocol_attrs__", None)) + # and all(hasattr(f, attr) for attr in attrs) + # ): + # return f + raise ExceptionClass( + f"Unsupported file-like object: {f}. Must satisfy: {FileProto}" + ) + + +# Minor hack. Relies on Protocols being ABC subclasses. +# They are currently ABCs anyway, so why not use that +# for something useful? +# +# tempfile.NamedTemporaryFile is a dynamic wrapper +# https://github.com/python/cpython/blob/2dd91d2b92a6c74d78cd3385ede328190cd8eaa9/Lib/tempfile.py#L510 +# so normal (naive) isinstance checks of tempfile.NamedTemporaryFiles +# against @runtime_checkable Protocols are not possible. +ReadSeekableBinStream.register(tempfile._TemporaryFileWrapper) +WriteSeekableBinStream.register(tempfile._TemporaryFileWrapper) + + +class DbfReader(_FileChecker[ReadSeekableBinStream]): + """Reads a dbf file. You can instantiate a DbfReader without specifying a shapefile.""" + + FileProto = ReadSeekableBinStream + new_file_mode = "rb" + ext = ".dbf" + ExceptionClass = dbfFileException + + def __init__( + self, + dbf: str | PathLike[Any] | ReadSeekableBinStream, *, - file_obj: IO[bytes], encoding: str = "utf-8", encodingErrors: str = "strict", ): - self._file = file_obj + super().__init__(file=dbf, encoding=encoding, encodingErrors=encodingErrors) + self.fields: list[Field] = [] self.__fieldLookup: dict[str, int] = {} - self.encoding = encoding - self.encodingErrors = encodingErrors self._dbfHeader() - @property - def dbf(self) -> IO[bytes]: - if not self._file: - raise dbfFileException( - f"DbfReader requires a .dbf file or file-like object. Got: {self._file}" - ) - return self._file + # @functools.cached_property + # def dbf(self) -> ReadSeekableBinStream: + # return self._ensure_file_obj( + # # f=self._file, + # # FileProto=self.FileProto, + # # exit_stack=self.exit_stack, + # # new_file_mode="rb", + # # ExceptionClass=dbfFileException, + # ) def __len__(self) -> int: """Returns the number of records in the .dbf file.""" - return self.numRecords def _dbfHeader(self) -> None: """Reads a dbf header. Xbase-related code borrows heavily from ActiveState Python Cookbook Recipe 362715 by Raymond Hettinger""" - dbf = self.dbf + dbf = self.file # read relevant header parts dbf.seek(0) self.numRecords, self.__dbfHdrLength, self._record_length = cast( @@ -2506,7 +2651,7 @@ def _record( a list of field info Field namedtuples 'fieldTuples', a record name-index dict 'recLookup', and a Struct instance 'recStruct' for unpacking these fields. """ - f = self.dbf + f = self.file # The only format chars in from self._record_fmt, in recStruct from _record_fields, # are s and x (ascii encoded str and pad byte) so everything in recordContents is bytes @@ -2600,7 +2745,7 @@ def record(self, i: int = 0, fields: list[str] | None = None) -> _Record | None: To only read some of the fields, specify the 'fields' arg as a list of one or more fieldnames. """ - f = self.dbf + f = self.file i = ensure_within_bounds(i, self.numRecords) recSize = self._record_length @@ -2616,9 +2761,10 @@ def records(self, fields: list[str] | None = None) -> list[_Record]: To only read some of the fields, specify the 'fields' arg as a list of one or more fieldnames. """ + f = self.file records = [] - self.dbf.seek(self.__dbfHdrLength) + f.seek(self.__dbfHdrLength) fieldTuples, recLookup, recStruct = self._record_fields(fields) for i in range(self.numRecords): @@ -2646,6 +2792,7 @@ def iterRecords( start <= i < number_of_records + stop if stop < 0). """ + f = self.file if not isinstance(self.numRecords, int): raise ShapefileException( @@ -2661,7 +2808,7 @@ def iterRecords( elif stop < 0: stop = range(self.numRecords)[stop] recSize = self._record_length - self.dbf.seek(self.__dbfHdrLength + (start * recSize)) + f.seek(self.__dbfHdrLength + (start * recSize)) fieldTuples, recLookup, recStruct = self._record_fields(fields) for i in range(start, stop): r = self._record( @@ -2671,14 +2818,6 @@ def iterRecords( yield r -class ShapefileException(Exception): - """An exception to handle shapefile specific problems.""" - - -class dbfFileException(ShapefileException): - """Indicates a problem with the .dbf file.""" - - class _NoShpSentinel: """For use as a default value for shp to preserve the behaviour (from when all keyword args were gathered @@ -2690,7 +2829,7 @@ class _NoShpSentinel: _NO_SHP_SENTINEL = _NoShpSentinel() -class Reader: +class Reader(_FileChecker[ReadSeekableBinStream]): """Reads the three files of a shapefile as a unit or separately. If one of the three files (.shp, .shx, .dbf) is missing no exception is thrown until you try @@ -2711,6 +2850,11 @@ class Reader: but they can be. """ + FileProto = ReadSeekableBinStream + new_file_mode = "rb" + ext = ".shp" + ExceptionClass = ShapefileException + def __init__( self, shapefile_path: str | PathLike[Any] = "", @@ -2733,7 +2877,7 @@ def __init__( self._offsets: list[int] = [] self.shpLength: int | None = None self.numShapes: int | None = None - self._exit_stack = ExitStack() + self.exit_stack = ExitStack() # See if a shapefile name was passed as the first argument if shapefile_path: path = fsdecode_if_pathlike(shapefile_path) @@ -2778,7 +2922,7 @@ def _get_dbf_reader(self) -> DbfReader: "Shapefile DbfReader requires a .dbf file or file-like object." ) return DbfReader( - file_obj=self._dbf, + dbf=self._dbf, encoding=self.encoding, encodingErrors=self.encodingErrors, ) @@ -2788,24 +2932,26 @@ def dbf_reader(self) -> DbfReader: return self._get_dbf_reader() @functools.cached_property - def shp(self) -> IO[bytes]: - if self._shp is None: - raise ShapefileException( - "Shapefile Reader requires a .shp shapefile or file-like object." - ) - return self._shp + def shp(self) -> ReadSeekableBinStream: + return self._ensure_file_obj( + f=self._shp, + # FileProto=ReadSeekableBinStream, + # exit_stack=self.exit_stack, + # new_file_mode="rb", + ) @functools.cached_property - def shx(self) -> IO[bytes]: - if self._shx is None: - raise ShapefileException( - "Shapefile Reader shx use requires a .shx shapefile or file-like object." - ) - return self._shx + def shx(self) -> ReadSeekableBinStream: + return self._ensure_file_obj( + f=self._shx, + # FileProto=ReadSeekableBinStream, + # exit_stack=self.exit_stack, + # new_file_mode="rb", + ) @property - def dbf(self) -> IO[bytes]: - return self.dbf_reader.dbf + def dbf(self) -> ReadableBinStream: + return self.dbf_reader.file @property def numRecords(self) -> int | None: @@ -2843,7 +2989,7 @@ def _seek_0_on_file_obj_wrap_or_open_from_name( baseName, __ = os.path.splitext(file_) file_obj = _try_get_open_constituent_file(baseName, ext) if file_obj is not None: - self._exit_stack.enter_context(file_obj) + self.exit_stack.enter_context(file_obj) return file_obj if hasattr(file_, "read"): @@ -2876,7 +3022,7 @@ def _load_from_url(self, url: str) -> None: # Use tempfile as source for url data. fileobj = _save_to_named_tmp_file(resp, initial_bytes=sniffed_bytes) setattr(self, f"_{ext}", fileobj) - self._exit_stack.enter_context(fileobj) + self.exit_stack.enter_context(fileobj) if not shp_or_dbf_downloaded: raise ShapefileException(f"Failed to download .shp or .dbf from: {url}") @@ -2945,7 +3091,7 @@ def _load_from_zip(self, path: str) -> None: # Use read+write tempfile as source for member data. fileobj = _save_to_named_tmp_file(member) setattr(self, f"_{ext.lower()}", fileobj) - self._exit_stack.enter_context(fileobj) + self.exit_stack.enter_context(fileobj) except (OSError, AttributeError, KeyError): pass # Close and delete the temporary zipfile @@ -2979,7 +3125,7 @@ def load_shp(self, shapefile_name: str) -> None: """ self._shp = _try_get_open_constituent_file(shapefile_name, "shp") if self._shp: - self._exit_stack.enter_context(self._shp) + self.exit_stack.enter_context(self._shp) self._shpHeader() def load_shx(self, shapefile_name: str) -> None: @@ -2988,7 +3134,7 @@ def load_shx(self, shapefile_name: str) -> None: """ self._shx = _try_get_open_constituent_file(shapefile_name, "shx") if self._shx: - self._exit_stack.enter_context(self._shx) + self.exit_stack.enter_context(self._shx) self._shxHeader() def load_dbf(self, shapefile_name: str) -> None: @@ -2997,7 +3143,7 @@ def load_dbf(self, shapefile_name: str) -> None: """ self._dbf = _try_get_open_constituent_file(shapefile_name, "dbf") if self._dbf: - self._exit_stack.enter_context(self._dbf) + self.exit_stack.enter_context(self._dbf) self._get_dbf_reader() def __len__(self) -> int: @@ -3142,7 +3288,7 @@ def __del__(self) -> None: self.close() def close(self) -> None: - self._exit_stack.close() + self.exit_stack.close() # Close any files that the reader opened (but not those given by user) # for file_ in [self._shp, self._dbf, self._shx]: # if file_ is None: @@ -3167,18 +3313,13 @@ def __str__(self) -> str: info.append(f" {len(self)} records ({len(self.fields)} fields)") return "\n".join(info) - def __enter__(self) -> Reader: - self._exit_stack.__enter__() - return self - def __exit__( self, - exc_type: BaseException | None, + exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, - ) -> bool | None: + ) -> None: self.close() - return None def _shape(self, oid: int | None = None, bbox: BBox | None = None) -> Shape | None: """Returns the header info and geometry for a single shape.""" @@ -3370,28 +3511,6 @@ def iterShapeRecords( yield ShapeRecord(shape=shape, record=record) -def _ensure_file_obj( - f: str | WriteSeekableBinStream | None, - exit_stack: ExitStack, - file_mode: str = "wb+", - ExceptionClass: type[ShapefileException] = ShapefileException, -) -> WriteSeekableBinStream: - """Safety handler to verify file-like objects""" - if not f: - raise ExceptionClass("No file-like object available.") - if isinstance(f, str): - pth = os.path.split(f)[0] - if pth and not os.path.exists(pth): - os.makedirs(pth) - fp = open(f, file_mode) - exit_stack.enter_context(fp) - return fp - - if hasattr(f, "write"): - return f - raise ExceptionClass(f"Unsupported file-like object: {f}") - - def _is_file_obj_open(f: WriteSeekableBinStream | str | None) -> bool: if not f: return False @@ -3409,13 +3528,16 @@ def _try_to_flush_file_obj(f: WriteSeekableBinStream | str | None) -> None: pass -class DbfWriter(AbstractContextManager["DbfWriter", None]): +class DbfWriter(_FileChecker[WriteSeekableBinStream]): """Writes .dbf files (dBASE database files), in particular those of Shapefiles.""" + FileProto = WriteSeekableBinStream + new_file_mode = "w+b" + ext = ".dbf" + def __init__( self, - path: str | PathLike[Any] | None = None, - dbf: str | WriteSeekableBinStream | None = None, + dbf: str | PathLike[Any] | WriteSeekableBinStream, *, encoding: str = "utf-8", encodingErrors: str = "strict", @@ -3423,28 +3545,21 @@ def __init__( # Keep kwargs even though unused, to preserve PyShp 2.4 API **kwargs: Any, ): - self.path = fsdecode_if_pathlike(path) - self._dbf: str | WriteSeekableBinStream - self.fields: list[Field] = [] - self.max_num_fields = max_num_fields - # Encoding - self.encoding = encoding - self.encodingErrors = encodingErrors - if self.path: - if not isinstance(self.path, str): - raise TypeError( - f"Path {self.path!r} must be of type str or path-like, not {type(self.path)}." - ) - self._dbf = os.path.splitext(self.path)[0] + ".dbf" - elif dbf: - self._dbf = dbf - else: - raise TypeError( - "Either the target filepath, or dbf must be set to create a .dbf file." - ) - - self.recNum = 0 - self.deletionFlag = 0 + super().__init__(file=dbf, encoding=encoding, encodingErrors=encodingErrors) + + # dbf = fsdecode_if_pathlike(dbf) + # self._dbf: str | WriteSeekableBinStream + # # Encoding + # self.encoding = encoding + # self.encodingErrors = encodingErrors + # if isinstance(dbf, str): + # self._dbf = os.path.splitext(dbf)[0] + ".dbf" + # elif dbf: + # self._dbf = self.file + # else: + # raise TypeError( + # f"dbf must be set to a str, Path or file-like object. Got: {dbf}" + # ) # Support not closing opened file objects passed in e.g.(handled by some # external context manager, or the caller manually calling .close). @@ -3452,28 +3567,23 @@ def __init__( # This will only ever hold at most one context manager. # But an ExitStack is the right tool for the job # when the number of context manager(s) depends on user input. - self._exit_stack = ExitStack() + # self.exit_stack = ExitStack() + + self.fields: list[Field] = [] + self.max_num_fields = max_num_fields + self.recNum = 0 + self.deletionFlag = 0 @functools.cached_property def dbf(self) -> WriteSeekableBinStream: - return _ensure_file_obj( - self._dbf, - exit_stack=self._exit_stack, - ExceptionClass=dbfFileException, + return self._ensure_file_obj( + # f=self._dbf, + # FileProto=WriteSeekableBinStream, + # exit_stack=self.exit_stack, + # new_file_mode="w+b", + # ExceptionClass=dbfFileException, ) - def __enter__(self) -> DbfWriter: - return self - - def __exit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> None: - self.close() - return None - def close(self) -> None: """ Write final dbf header, close opened files. @@ -3485,7 +3595,7 @@ def close(self) -> None: _try_to_flush_file_obj(self.dbf) - self._exit_stack.close() + super().close() def field( # Types of args should match *Field @@ -3683,10 +3793,15 @@ def __dbfRecord(self, record: list[RecordValue]) -> None: f.write(encoded) -class Writer: +class Writer(_FileChecker[WriteSeekableBinStream]): """Provides write support for ESRI Shapefiles.""" - W = TypeVar("W", bound=WriteSeekableBinStream) + # W = TypeVar("W", bound=WriteSeekableBinStream) + + FileProto = WriteSeekableBinStream + new_file_mode = "w+b" + ext = ".shp" + ExceptionClass = ShapefileException def __init__( self, @@ -3718,7 +3833,7 @@ def __init__( self._shx: str | WriteSeekableBinStream | None = shx self._dbf: str | WriteSeekableBinStream | None = dbf self._dbf_writer: DbfWriter | None = None - self._exit_stack = ExitStack() + self.exit_stack = ExitStack() if target: if not isinstance(target, str): raise TypeError( @@ -3753,16 +3868,20 @@ def __init__( @functools.cached_property def shp(self) -> WriteSeekableBinStream: - return _ensure_file_obj( - self._shp, - exit_stack=self._exit_stack, + return self._ensure_file_obj( + f=self._shp, + # FileProto=WriteSeekableBinStream, + # exit_stack=self.exit_stack, + # new_file_mode="w+b", ) @functools.cached_property def shx(self) -> WriteSeekableBinStream: - return _ensure_file_obj( - self._shx, - exit_stack=self._exit_stack, + return self._ensure_file_obj( + f=self._shx, + # FileProto=WriteSeekableBinStream, + # exit_stack=self.exit_stack, + # new_file_mode="w+b", ) @functools.cached_property @@ -3771,7 +3890,7 @@ def dbf_writer(self) -> DbfWriter: raise dbfFileException( f"No dbf file. Got target: {self.target} & dbf: {self._dbf}" ) - self._exit_stack.enter_context(self._dbf_writer) + self.exit_stack.enter_context(self._dbf_writer) return self._dbf_writer @property @@ -3805,12 +3924,11 @@ def __enter__(self) -> Writer: def __exit__( self, - exc_type: BaseException | None, + exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, - ) -> bool | None: + ) -> None: self.close() - return None def __del__(self) -> None: self.close() @@ -3861,7 +3979,7 @@ def close(self) -> None: # user-supplied, already opened file objects that # might be closed by an outer context manager). # Idempotent. - self._exit_stack.close() + self.exit_stack.close() def _shp_file_length_B(self) -> int: """Calculates the file length of the shp file."""