diff --git a/CHANGELOG.md b/CHANGELOG.md index e3d291c..823beb0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -61,6 +61,10 @@ rather than message. Each subclass also derives from `ValueError`, so existing - **CLI option names aligned with the library**: `--palette-size` (canonical; `--n` kept as an alias) and `--max-workers` (canonical; `--num-threads` kept as an alias). +- **Explicit sample size**: `resize` now takes an `int` sample size or `None` + (no resize) instead of a bare bool, surfacing the previously hidden 256x256 + downscale. The default is `resize=256` (unchanged behavior). The + `ExtractionParams` metadata `resize` field is now `int | None`. ### Deprecated @@ -71,6 +75,9 @@ rather than message. Each subclass also derives from `ValueError`, so existing `--max-workers`. - **`Color.get_colors(...)`**: replaced by `Color.to(...)`. It still works for one release and now emits a `DeprecationWarning`. +- **Bare bool `resize`**: pass an int sample size or `None` instead. `resize=True` + (→ `256`) and `resize=False` (→ `None`) still work for one release and now emit + a `DeprecationWarning`. ### Removed diff --git a/pylette/src/color_extraction.py b/pylette/src/color_extraction.py index c735f33..ba5eaeb 100644 --- a/pylette/src/color_extraction.py +++ b/pylette/src/color_extraction.py @@ -1,5 +1,6 @@ import time import urllib.parse +import warnings from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime from io import BytesIO @@ -95,10 +96,31 @@ def _get_descriptive_image_source(image: ImageInput, pil_image: PILImage) -> str return f"" +def _resolve_resize(resize: int | bool | None) -> int | None: + """Normalize the ``resize`` argument to a pixel sample size or ``None``. + + Accepts an ``int`` sample size (the image is resized to ``(resize, resize)`` + before sampling), ``None`` (no resize, sample the full image), or a + deprecated ``bool`` (``True`` -> 256, ``False`` -> ``None``). + """ + if isinstance(resize, bool): + warnings.warn( + "Passing a bool for `resize` is deprecated and will be removed; pass an int " + "sample size (e.g. resize=256) or None to disable resizing (True maps to 256, " + "False to None).", + DeprecationWarning, + stacklevel=3, + ) + return 256 if resize else None + if resize is not None and resize < 1: + raise ValueError(f"resize must be a positive int or None, got {resize!r}.") + return resize + + def batch_extract_colors( images: Sequence[ImageInput], palette_size: int = 5, - resize: bool = True, + resize: int | bool | None = 256, mode: ExtractionMethod | str = ExtractionMethod.KM, sort_mode: Literal["luminance", "frequency"] | None = None, alpha_mask_threshold: int | None = None, @@ -112,6 +134,8 @@ def batch_extract_colors( Receives (task_number, result) as arguments. """ + resize = _resolve_resize(resize) + def thread_fn(image: ImageInput): return extract_colors( image=image, @@ -151,7 +175,7 @@ def thread_fn(image: ImageInput): def extract_colors( image: ImageInput, palette_size: int = 5, - resize: bool = True, + resize: int | bool | None = 256, mode: ExtractionMethod | str = ExtractionMethod.KM, sort_mode: Literal["luminance", "frequency"] | None = None, alpha_mask_threshold: int | None = None, @@ -162,7 +186,12 @@ def extract_colors( Parameters: image: The input image. palette_size: The number of colors to extract. - resize: Whether to resize the image before processing. + resize: The sample size. The image is downscaled to ``(resize, resize)`` + before colors are extracted, which bounds runtime; pass ``None`` to + sample the image at full resolution instead. Smaller values are + faster but coarser; larger values are slower but capture more detail. + Defaults to ``256``. (Passing a ``bool`` is deprecated: ``True`` maps + to ``256`` and ``False`` to ``None``.) mode: The color quantization algorithm to use. sort_mode: The mode to sort colors. alpha_mask_threshold: Optional integer between 0, 255. @@ -196,13 +225,14 @@ def extract_colors( Examples: Colors can be extracted from a variety of sources, including local files, byte streams, URLs, and numpy arrays. - >>> extract_colors("path/to/image.jpg", palette_size=5, resize=True, mode="KM", sort_mode="luminance") - >>> extract_colors(b"image_bytes", palette_size=5, resize=True, mode="KM", sort_mode="luminance") + >>> extract_colors("path/to/image.jpg", palette_size=5, resize=256, mode="KM", sort_mode="luminance") + >>> extract_colors(b"image_bytes", palette_size=5, resize=None, mode="KM", sort_mode="luminance") """ start_time = time.time() mode = coerce_to_enum(mode, ExtractionMethod, error_cls=UnknownExtractionMethodError) + resize = _resolve_resize(resize) source_type = _get_source_type_from_image_input(image) # Normalize input to PIL Image and convert to RGBA @@ -213,15 +243,15 @@ def extract_colors( # Store original image info image_info = ImageInfo( original_size=original_size, - processed_size=img.size if not resize else (256, 256), + processed_size=(resize, resize) if resize is not None else img.size, format=getattr(img_obj, "format", None), mode=img.mode, has_alpha=img.mode in ("RGBA", "LA") or "transparency" in img_obj.info, ) - if resize: - img = img.resize((256, 256)) - image_info["processed_size"] = (256, 256) + if resize is not None: + img = img.resize((resize, resize)) + image_info["processed_size"] = (resize, resize) width, height = img.size arr = np.asarray(img) diff --git a/pylette/src/types.py b/pylette/src/types.py index c52103f..ab1208f 100644 --- a/pylette/src/types.py +++ b/pylette/src/types.py @@ -118,7 +118,7 @@ class ExtractionParams(TypedDict): palette_size: int mode: ExtractionMethod sort_mode: str | None - resize: bool + resize: int | None alpha_mask_threshold: int | None diff --git a/tests/integration/test_alpha_masking.py b/tests/integration/test_alpha_masking.py index 47184ae..70fd006 100644 --- a/tests/integration/test_alpha_masking.py +++ b/tests/integration/test_alpha_masking.py @@ -33,7 +33,7 @@ def test_extraction_survives_alpha_masking(half_transparent_image: Image.Image, half_transparent_image, palette_size=5, mode=mode, - resize=False, + resize=None, alpha_mask_threshold=0, ) assert len(palette) <= 5 diff --git a/tests/integration/test_colorspaces.py b/tests/integration/test_colorspaces.py index d09ea08..66159b9 100644 --- a/tests/integration/test_colorspaces.py +++ b/tests/integration/test_colorspaces.py @@ -27,7 +27,7 @@ def test_image_from_PIL() -> PILImage: @pytest.fixture() def test_kmean_extracted_palette(test_image_path_as_str: str): - return extract_colors(image=test_image_path_as_str, palette_size=10, resize=True, mode=ExtractionMethod.KM) + return extract_colors(image=test_image_path_as_str, palette_size=10, resize=256, mode=ExtractionMethod.KM) @pytest.mark.parametrize("palette_size", [1, 5, 10, 100]) @@ -41,7 +41,7 @@ def test_palette_invariants_with_image_path( palette = extract_colors( image=test_image_path_as_str, palette_size=palette_size, - resize=True, + resize=256, mode=extraction_method, ) @@ -76,7 +76,7 @@ def test_palette_invariants_with_image_pathlike( palette = extract_colors( image=test_image_path_as_pathlike, palette_size=palette_size, - resize=True, + resize=256, mode=extraction_method, ) @@ -111,7 +111,7 @@ def test_palette_invariants_with_image_bytes( palette = extract_colors( image=test_image_as_bytes, palette_size=palette_size, - resize=True, + resize=256, mode=extraction_method, ) @@ -143,7 +143,7 @@ def test_palette_invariants_with_PIL_image( palette = extract_colors( image=test_image_from_PIL, palette_size=palette_size, - resize=True, + resize=256, mode=extraction_method, ) @@ -175,7 +175,7 @@ def test_palette_invariants_with_opencv( palette = extract_colors( image=test_image_from_opencv, palette_size=palette_size, - resize=True, + resize=256, mode=extraction_method, ) @@ -207,7 +207,7 @@ def test_palette_invariants_with_image_url( palette = extract_colors( image=test_image_as_url, palette_size=palette_size, - resize=True, + resize=256, mode=extraction_method, ) @@ -255,9 +255,9 @@ def test_colorspace_invariants_rgb(test_kmean_extracted_palette: Palette): assert 0 <= b <= 255, f"Expected 0 <= b <= 255, got {b}" -@pytest.mark.parametrize("resize, sort_mode", [(True, "luminance"), (False, "frequency")]) +@pytest.mark.parametrize("resize, sort_mode", [(256, "luminance"), (None, "frequency")]) def test_color_extraction_deterministic_kmeans( - test_image_path_as_str: PathLikeImage, resize: bool, sort_mode: Literal["luminance", "frequency"] + test_image_path_as_str: PathLikeImage, resize: int | None, sort_mode: Literal["luminance", "frequency"] ): palette1 = extract_colors( image=test_image_path_as_str, diff --git a/tests/integration/test_exceptions.py b/tests/integration/test_exceptions.py index 76d0386..80e78c7 100644 --- a/tests/integration/test_exceptions.py +++ b/tests/integration/test_exceptions.py @@ -60,7 +60,7 @@ def test_invalid_url_image_raises_invalid_image_error(requests_mock) -> None: # def test_fully_masked_image_raises_no_valid_pixels_error(fully_transparent_image: Image.Image) -> None: """The all-masked #76 case stays pinned to a typed error.""" with pytest.raises(NoValidPixelsError): - extract_colors(fully_transparent_image, alpha_mask_threshold=0, resize=False) + extract_colors(fully_transparent_image, alpha_mask_threshold=0, resize=None) def test_unknown_mode_raises_unknown_extraction_method_error(opaque_image: Image.Image) -> None: diff --git a/tests/integration/test_invariants.py b/tests/integration/test_invariants.py index 3bc8c4e..38f8e92 100644 --- a/tests/integration/test_invariants.py +++ b/tests/integration/test_invariants.py @@ -37,8 +37,8 @@ def _assert_palette_invariants(palette: Palette, palette_size: int) -> None: @pytest.mark.parametrize("mode", METHODS) @pytest.mark.parametrize("palette_size", [1, 3, 5]) -@pytest.mark.parametrize("resize", [True, False]) -def test_solid_image_is_handled(mode: str, palette_size: int, resize: bool) -> None: +@pytest.mark.parametrize("resize", [256, None]) +def test_solid_image_is_handled(mode: str, palette_size: int, resize: int | None) -> None: img = Image.new("RGB", (8, 8), (12, 200, 75)) palette = extract_colors(img, palette_size=palette_size, mode=mode, resize=resize) _assert_palette_invariants(palette, palette_size) @@ -48,7 +48,7 @@ def test_solid_image_is_handled(mode: str, palette_size: int, resize: bool) -> N @pytest.mark.parametrize("mode", METHODS) def test_one_by_one_image_is_handled(mode: str) -> None: img = Image.fromarray(np.array([[[10, 20, 30]]], dtype=np.uint8), "RGB") - palette = extract_colors(img, palette_size=5, mode=mode, resize=False) + palette = extract_colors(img, palette_size=5, mode=mode, resize=None) _assert_palette_invariants(palette, 5) assert len(palette) >= 1 @@ -57,7 +57,7 @@ def test_one_by_one_image_is_handled(mode: str) -> None: def test_palette_size_exceeds_distinct_colors(mode: str) -> None: arr = np.array([[[0, 0, 0], [255, 255, 255]], [[255, 0, 0], [0, 0, 255]]], dtype=np.uint8) img = Image.fromarray(arr, "RGB") - palette = extract_colors(img, palette_size=10, mode=mode, resize=False) + palette = extract_colors(img, palette_size=10, mode=mode, resize=None) _assert_palette_invariants(palette, 10) @@ -67,7 +67,7 @@ def test_partial_alpha_mask_is_handled(mode: str) -> None: arr[..., :3] = np.random.default_rng(0).integers(0, 256, (16, 16, 3)) arr[::2, :, 3] = 255 # half opaque, half transparent img = Image.fromarray(arr, "RGBA") - palette = extract_colors(img, palette_size=5, mode=mode, resize=False, alpha_mask_threshold=0) + palette = extract_colors(img, palette_size=5, mode=mode, resize=None, alpha_mask_threshold=0) _assert_palette_invariants(palette, 5) @@ -76,7 +76,7 @@ def test_total_alpha_mask_raises_typed_error(mode: str) -> None: arr = np.zeros((16, 16, 4), dtype=np.uint8) # alpha = 0 everywhere img = Image.fromarray(arr, "RGBA") with pytest.raises(NoValidPixelsError): - extract_colors(img, palette_size=5, mode=mode, resize=False, alpha_mask_threshold=0) + extract_colors(img, palette_size=5, mode=mode, resize=None, alpha_mask_threshold=0) @pytest.mark.parametrize("mode", METHODS) @@ -119,7 +119,7 @@ def test_sort_order_is_stable_and_idempotent(mode, sort_mode, key, reverse) -> N palette_size=st.integers(1, 8), mode=st.sampled_from(METHODS), sort_mode=st.sampled_from([None, "luminance", "frequency"]), - resize=st.booleans(), + resize=st.sampled_from([None, 64]), ) def test_property_invariants_hold_for_arbitrary_images(arr, palette_size, mode, sort_mode, resize) -> None: # type: ignore[no-untyped-def] mode_str = "RGB" if arr.shape[-1] == 3 else "RGBA" diff --git a/tests/integration/test_metadata.py b/tests/integration/test_metadata.py index d2ffd96..ba9b56d 100644 --- a/tests/integration/test_metadata.py +++ b/tests/integration/test_metadata.py @@ -6,11 +6,11 @@ @pytest.mark.parametrize("mode", [ExtractionMethod.KM, ExtractionMethod.MC]) -@pytest.mark.parametrize("resize", [True, False]) +@pytest.mark.parametrize("resize", [256, None]) class TestMetadata: """Test metadata verification for different extraction configurations.""" - def test_metadata_file_path_as_str(self, test_image_path_as_str: str, mode: ExtractionMethod, resize: bool): + def test_metadata_file_path_as_str(self, test_image_path_as_str: str, mode: ExtractionMethod, resize: int | None): """Test metadata for file path as string input.""" palette = extract_colors(test_image_path_as_str, mode=mode, resize=resize) @@ -31,13 +31,13 @@ def test_metadata_file_path_as_str(self, test_image_path_as_str: str, mode: Extr assert image_info["mode"] == "RGBA" assert image_info["has_alpha"] is True - if resize: + if resize is not None: assert image_info["processed_size"] == (256, 256) else: assert image_info["processed_size"] == (1202, 1276) def test_metadata_file_path_as_pathlike( - self, test_image_path_as_pathlike: PathLikeImage, mode: ExtractionMethod, resize: bool + self, test_image_path_as_pathlike: PathLikeImage, mode: ExtractionMethod, resize: int | None ): """Test metadata for file path as Path object input.""" palette = extract_colors(test_image_path_as_pathlike, mode=mode, resize=resize) @@ -59,12 +59,12 @@ def test_metadata_file_path_as_pathlike( assert image_info["mode"] == "RGBA" assert image_info["has_alpha"] is True - if resize: + if resize is not None: assert image_info["processed_size"] == (256, 256) else: assert image_info["processed_size"] == (1202, 1276) - def test_metadata_url(self, test_image_as_url: URLImage, mode: ExtractionMethod, resize: bool): + def test_metadata_url(self, test_image_as_url: URLImage, mode: ExtractionMethod, resize: int | None): """Test metadata for URL input.""" palette = extract_colors(test_image_as_url, mode=mode, resize=resize) @@ -85,12 +85,12 @@ def test_metadata_url(self, test_image_as_url: URLImage, mode: ExtractionMethod, assert image_info["mode"] == "RGBA" assert image_info["has_alpha"] is True - if resize: + if resize is not None: assert image_info["processed_size"] == (256, 256) else: assert image_info["processed_size"] == (1202, 1276) - def test_metadata_bytes(self, test_image_as_bytes: BytesImage, mode: ExtractionMethod, resize: bool): + def test_metadata_bytes(self, test_image_as_bytes: BytesImage, mode: ExtractionMethod, resize: int | None): """Test metadata for bytes input.""" palette = extract_colors(test_image_as_bytes, mode=mode, resize=resize) @@ -110,7 +110,7 @@ def test_metadata_bytes(self, test_image_as_bytes: BytesImage, mode: ExtractionM assert image_info["mode"] == "RGBA" assert image_info["has_alpha"] is True - if resize: + if resize is not None: assert image_info["processed_size"] == (256, 256) else: assert image_info["processed_size"] == (1202, 1276) @@ -135,7 +135,7 @@ def test_metadata_processing_stats(test_image_path_as_str: str): def test_metadata_processing_stats_no_resize(test_image_path_as_str: str): """Test processing statistics for non-resized image.""" - palette = extract_colors(test_image_path_as_str, resize=False) + palette = extract_colors(test_image_path_as_str, resize=None) assert palette.metadata stats = palette.metadata["processing_stats"] diff --git a/tests/integration/test_resize.py b/tests/integration/test_resize.py new file mode 100644 index 0000000..31f9bbe --- /dev/null +++ b/tests/integration/test_resize.py @@ -0,0 +1,78 @@ +""" +`resize` accepts an int sample size or None (no resize) +""" + +import numpy as np +import pytest +from PIL import Image + +from pylette import batch_extract_colors, extract_colors + +pytestmark = pytest.mark.filterwarnings("ignore::UserWarning") + + +@pytest.fixture +def image() -> Image.Image: + # Distinctive, non-square original size so "no resize" is detectable. + arr = np.random.default_rng(0).integers(0, 256, (30, 40, 3), dtype=np.uint8) + return Image.fromarray(arr, "RGB") + + +def _processed_size(palette) -> tuple[int, int]: # type: ignore[no-untyped-def] + assert palette.metadata + return palette.metadata["image_info"]["processed_size"] + + +def test_default_resize_is_256(image: Image.Image) -> None: + palette = extract_colors(image) + assert _processed_size(palette) == (256, 256) + assert palette.metadata["extraction_params"]["resize"] == 256 + + +def test_explicit_int_resize(image: Image.Image) -> None: + palette = extract_colors(image, resize=64) + assert _processed_size(palette) == (64, 64) + assert palette.metadata["extraction_params"]["resize"] == 64 + + +def test_none_disables_resize(image: Image.Image) -> None: + palette = extract_colors(image, resize=None) + # PIL size is (width, height) == (40, 30) for a (30, 40, 3) array. + assert _processed_size(palette) == (40, 30) + assert palette.metadata["extraction_params"]["resize"] is None + + +def test_invalid_resize_raises(image: Image.Image) -> None: + with pytest.raises(ValueError): + extract_colors(image, resize=0) + with pytest.raises(ValueError): + extract_colors(image, resize=-5) + + +def test_default_call_emits_no_deprecation_warning(image: Image.Image) -> None: + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("error", DeprecationWarning) + extract_colors(image) # default resize=256, must not warn + + +@pytest.mark.parametrize("flag, expected_size", [(True, (256, 256)), (False, (40, 30))]) +def test_bool_resize_is_deprecated_but_works(image: Image.Image, flag: bool, expected_size: tuple[int, int]) -> None: + with pytest.warns(DeprecationWarning): + palette = extract_colors(image, resize=flag) + assert _processed_size(palette) == expected_size + + +def test_batch_bool_resize_is_deprecated_but_works(image: Image.Image, tmp_path) -> None: # type: ignore[no-untyped-def] + # batch uses each source as a dict key, so sources must be hashable (paths). + paths = [] + for i in range(2): + p = tmp_path / f"img_{i}.png" + image.save(p) + paths.append(str(p)) + + with pytest.warns(DeprecationWarning): + results = batch_extract_colors(paths, resize=True) + assert all(r.success for r in results) + assert all(_processed_size(r.palette) == (256, 256) for r in results)