Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ rather than message. Each subclass also derives from `ValueError`, so existing
- **CLI option names aligned with the library**: `--palette-size` (canonical;
`--n` kept as an alias) and `--max-workers` (canonical; `--num-threads` kept
as an alias).
- **Explicit sample size**: `resize` now takes an `int` sample size or `None`
(no resize) instead of a bare bool, surfacing the previously hidden 256x256
downscale. The default is `resize=256` (unchanged behavior). The
`ExtractionParams` metadata `resize` field is now `int | None`.

### Deprecated

Expand All @@ -71,6 +75,9 @@ rather than message. Each subclass also derives from `ValueError`, so existing
`--max-workers`.
- **`Color.get_colors(...)`**: replaced by `Color.to(...)`. It still works for
one release and now emits a `DeprecationWarning`.
- **Bare bool `resize`**: pass an int sample size or `None` instead. `resize=True`
(→ `256`) and `resize=False` (→ `None`) still work for one release and now emit
a `DeprecationWarning`.

### Removed

Expand Down
48 changes: 39 additions & 9 deletions pylette/src/color_extraction.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import time
import urllib.parse
import warnings
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from io import BytesIO
Expand Down Expand Up @@ -95,10 +96,31 @@ def _get_descriptive_image_source(image: ImageInput, pil_image: PILImage) -> str
return f"<unknown: {type(image).__name__}>"


def _resolve_resize(resize: int | bool | None) -> int | None:
"""Normalize the ``resize`` argument to a pixel sample size or ``None``.

Accepts an ``int`` sample size (the image is resized to ``(resize, resize)``
before sampling), ``None`` (no resize, sample the full image), or a
deprecated ``bool`` (``True`` -> 256, ``False`` -> ``None``).
"""
if isinstance(resize, bool):
warnings.warn(
"Passing a bool for `resize` is deprecated and will be removed; pass an int "
"sample size (e.g. resize=256) or None to disable resizing (True maps to 256, "
"False to None).",
DeprecationWarning,
stacklevel=3,
)
return 256 if resize else None
if resize is not None and resize < 1:
raise ValueError(f"resize must be a positive int or None, got {resize!r}.")
return resize


def batch_extract_colors(
images: Sequence[ImageInput],
palette_size: int = 5,
resize: bool = True,
resize: int | bool | None = 256,
mode: ExtractionMethod | str = ExtractionMethod.KM,
sort_mode: Literal["luminance", "frequency"] | None = None,
alpha_mask_threshold: int | None = None,
Expand All @@ -112,6 +134,8 @@ def batch_extract_colors(
Receives (task_number, result) as arguments.
"""

resize = _resolve_resize(resize)

def thread_fn(image: ImageInput):
return extract_colors(
image=image,
Expand Down Expand Up @@ -151,7 +175,7 @@ def thread_fn(image: ImageInput):
def extract_colors(
image: ImageInput,
palette_size: int = 5,
resize: bool = True,
resize: int | bool | None = 256,
mode: ExtractionMethod | str = ExtractionMethod.KM,
sort_mode: Literal["luminance", "frequency"] | None = None,
alpha_mask_threshold: int | None = None,
Expand All @@ -162,7 +186,12 @@ def extract_colors(
Parameters:
image: The input image.
palette_size: The number of colors to extract.
resize: Whether to resize the image before processing.
resize: The sample size. The image is downscaled to ``(resize, resize)``
before colors are extracted, which bounds runtime; pass ``None`` to
sample the image at full resolution instead. Smaller values are
faster but coarser; larger values are slower but capture more detail.
Defaults to ``256``. (Passing a ``bool`` is deprecated: ``True`` maps
to ``256`` and ``False`` to ``None``.)
mode: The color quantization algorithm to use.
sort_mode: The mode to sort colors.
alpha_mask_threshold: Optional integer between 0, 255.
Expand Down Expand Up @@ -196,13 +225,14 @@ def extract_colors(
Examples:
Colors can be extracted from a variety of sources, including local files, byte streams, URLs, and numpy arrays.

>>> extract_colors("path/to/image.jpg", palette_size=5, resize=True, mode="KM", sort_mode="luminance")
>>> extract_colors(b"image_bytes", palette_size=5, resize=True, mode="KM", sort_mode="luminance")
>>> extract_colors("path/to/image.jpg", palette_size=5, resize=256, mode="KM", sort_mode="luminance")
>>> extract_colors(b"image_bytes", palette_size=5, resize=None, mode="KM", sort_mode="luminance")
"""

start_time = time.time()

mode = coerce_to_enum(mode, ExtractionMethod, error_cls=UnknownExtractionMethodError)
resize = _resolve_resize(resize)

source_type = _get_source_type_from_image_input(image)
# Normalize input to PIL Image and convert to RGBA
Expand All @@ -213,15 +243,15 @@ def extract_colors(
# Store original image info
image_info = ImageInfo(
original_size=original_size,
processed_size=img.size if not resize else (256, 256),
processed_size=(resize, resize) if resize is not None else img.size,
format=getattr(img_obj, "format", None),
mode=img.mode,
has_alpha=img.mode in ("RGBA", "LA") or "transparency" in img_obj.info,
)

if resize:
img = img.resize((256, 256))
image_info["processed_size"] = (256, 256)
if resize is not None:
img = img.resize((resize, resize))
image_info["processed_size"] = (resize, resize)

width, height = img.size
arr = np.asarray(img)
Expand Down
2 changes: 1 addition & 1 deletion pylette/src/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ class ExtractionParams(TypedDict):
palette_size: int
mode: ExtractionMethod
sort_mode: str | None
resize: bool
resize: int | None
alpha_mask_threshold: int | None


Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_alpha_masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_extraction_survives_alpha_masking(half_transparent_image: Image.Image,
half_transparent_image,
palette_size=5,
mode=mode,
resize=False,
resize=None,
alpha_mask_threshold=0,
)
assert len(palette) <= 5
Expand Down
18 changes: 9 additions & 9 deletions tests/integration/test_colorspaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_image_from_PIL() -> PILImage:

@pytest.fixture()
def test_kmean_extracted_palette(test_image_path_as_str: str):
return extract_colors(image=test_image_path_as_str, palette_size=10, resize=True, mode=ExtractionMethod.KM)
return extract_colors(image=test_image_path_as_str, palette_size=10, resize=256, mode=ExtractionMethod.KM)


@pytest.mark.parametrize("palette_size", [1, 5, 10, 100])
Expand All @@ -41,7 +41,7 @@ def test_palette_invariants_with_image_path(
palette = extract_colors(
image=test_image_path_as_str,
palette_size=palette_size,
resize=True,
resize=256,
mode=extraction_method,
)

Expand Down Expand Up @@ -76,7 +76,7 @@ def test_palette_invariants_with_image_pathlike(
palette = extract_colors(
image=test_image_path_as_pathlike,
palette_size=palette_size,
resize=True,
resize=256,
mode=extraction_method,
)

Expand Down Expand Up @@ -111,7 +111,7 @@ def test_palette_invariants_with_image_bytes(
palette = extract_colors(
image=test_image_as_bytes,
palette_size=palette_size,
resize=True,
resize=256,
mode=extraction_method,
)

Expand Down Expand Up @@ -143,7 +143,7 @@ def test_palette_invariants_with_PIL_image(
palette = extract_colors(
image=test_image_from_PIL,
palette_size=palette_size,
resize=True,
resize=256,
mode=extraction_method,
)

Expand Down Expand Up @@ -175,7 +175,7 @@ def test_palette_invariants_with_opencv(
palette = extract_colors(
image=test_image_from_opencv,
palette_size=palette_size,
resize=True,
resize=256,
mode=extraction_method,
)

Expand Down Expand Up @@ -207,7 +207,7 @@ def test_palette_invariants_with_image_url(
palette = extract_colors(
image=test_image_as_url,
palette_size=palette_size,
resize=True,
resize=256,
mode=extraction_method,
)

Expand Down Expand Up @@ -255,9 +255,9 @@ def test_colorspace_invariants_rgb(test_kmean_extracted_palette: Palette):
assert 0 <= b <= 255, f"Expected 0 <= b <= 255, got {b}"


@pytest.mark.parametrize("resize, sort_mode", [(True, "luminance"), (False, "frequency")])
@pytest.mark.parametrize("resize, sort_mode", [(256, "luminance"), (None, "frequency")])
def test_color_extraction_deterministic_kmeans(
test_image_path_as_str: PathLikeImage, resize: bool, sort_mode: Literal["luminance", "frequency"]
test_image_path_as_str: PathLikeImage, resize: int | None, sort_mode: Literal["luminance", "frequency"]
):
palette1 = extract_colors(
image=test_image_path_as_str,
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_invalid_url_image_raises_invalid_image_error(requests_mock) -> None: #
def test_fully_masked_image_raises_no_valid_pixels_error(fully_transparent_image: Image.Image) -> None:
"""The all-masked #76 case stays pinned to a typed error."""
with pytest.raises(NoValidPixelsError):
extract_colors(fully_transparent_image, alpha_mask_threshold=0, resize=False)
extract_colors(fully_transparent_image, alpha_mask_threshold=0, resize=None)


def test_unknown_mode_raises_unknown_extraction_method_error(opaque_image: Image.Image) -> None:
Expand Down
14 changes: 7 additions & 7 deletions tests/integration/test_invariants.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def _assert_palette_invariants(palette: Palette, palette_size: int) -> None:

@pytest.mark.parametrize("mode", METHODS)
@pytest.mark.parametrize("palette_size", [1, 3, 5])
@pytest.mark.parametrize("resize", [True, False])
def test_solid_image_is_handled(mode: str, palette_size: int, resize: bool) -> None:
@pytest.mark.parametrize("resize", [256, None])
def test_solid_image_is_handled(mode: str, palette_size: int, resize: int | None) -> None:
img = Image.new("RGB", (8, 8), (12, 200, 75))
palette = extract_colors(img, palette_size=palette_size, mode=mode, resize=resize)
_assert_palette_invariants(palette, palette_size)
Expand All @@ -48,7 +48,7 @@ def test_solid_image_is_handled(mode: str, palette_size: int, resize: bool) -> N
@pytest.mark.parametrize("mode", METHODS)
def test_one_by_one_image_is_handled(mode: str) -> None:
img = Image.fromarray(np.array([[[10, 20, 30]]], dtype=np.uint8), "RGB")
palette = extract_colors(img, palette_size=5, mode=mode, resize=False)
palette = extract_colors(img, palette_size=5, mode=mode, resize=None)
_assert_palette_invariants(palette, 5)
assert len(palette) >= 1

Expand All @@ -57,7 +57,7 @@ def test_one_by_one_image_is_handled(mode: str) -> None:
def test_palette_size_exceeds_distinct_colors(mode: str) -> None:
arr = np.array([[[0, 0, 0], [255, 255, 255]], [[255, 0, 0], [0, 0, 255]]], dtype=np.uint8)
img = Image.fromarray(arr, "RGB")
palette = extract_colors(img, palette_size=10, mode=mode, resize=False)
palette = extract_colors(img, palette_size=10, mode=mode, resize=None)
_assert_palette_invariants(palette, 10)


Expand All @@ -67,7 +67,7 @@ def test_partial_alpha_mask_is_handled(mode: str) -> None:
arr[..., :3] = np.random.default_rng(0).integers(0, 256, (16, 16, 3))
arr[::2, :, 3] = 255 # half opaque, half transparent
img = Image.fromarray(arr, "RGBA")
palette = extract_colors(img, palette_size=5, mode=mode, resize=False, alpha_mask_threshold=0)
palette = extract_colors(img, palette_size=5, mode=mode, resize=None, alpha_mask_threshold=0)
_assert_palette_invariants(palette, 5)


Expand All @@ -76,7 +76,7 @@ def test_total_alpha_mask_raises_typed_error(mode: str) -> None:
arr = np.zeros((16, 16, 4), dtype=np.uint8) # alpha = 0 everywhere
img = Image.fromarray(arr, "RGBA")
with pytest.raises(NoValidPixelsError):
extract_colors(img, palette_size=5, mode=mode, resize=False, alpha_mask_threshold=0)
extract_colors(img, palette_size=5, mode=mode, resize=None, alpha_mask_threshold=0)


@pytest.mark.parametrize("mode", METHODS)
Expand Down Expand Up @@ -119,7 +119,7 @@ def test_sort_order_is_stable_and_idempotent(mode, sort_mode, key, reverse) -> N
palette_size=st.integers(1, 8),
mode=st.sampled_from(METHODS),
sort_mode=st.sampled_from([None, "luminance", "frequency"]),
resize=st.booleans(),
resize=st.sampled_from([None, 64]),
)
def test_property_invariants_hold_for_arbitrary_images(arr, palette_size, mode, sort_mode, resize) -> None: # type: ignore[no-untyped-def]
mode_str = "RGB" if arr.shape[-1] == 3 else "RGBA"
Expand Down
20 changes: 10 additions & 10 deletions tests/integration/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@


@pytest.mark.parametrize("mode", [ExtractionMethod.KM, ExtractionMethod.MC])
@pytest.mark.parametrize("resize", [True, False])
@pytest.mark.parametrize("resize", [256, None])
class TestMetadata:
"""Test metadata verification for different extraction configurations."""

def test_metadata_file_path_as_str(self, test_image_path_as_str: str, mode: ExtractionMethod, resize: bool):
def test_metadata_file_path_as_str(self, test_image_path_as_str: str, mode: ExtractionMethod, resize: int | None):
"""Test metadata for file path as string input."""
palette = extract_colors(test_image_path_as_str, mode=mode, resize=resize)

Expand All @@ -31,13 +31,13 @@ def test_metadata_file_path_as_str(self, test_image_path_as_str: str, mode: Extr
assert image_info["mode"] == "RGBA"
assert image_info["has_alpha"] is True

if resize:
if resize is not None:
assert image_info["processed_size"] == (256, 256)
else:
assert image_info["processed_size"] == (1202, 1276)

def test_metadata_file_path_as_pathlike(
self, test_image_path_as_pathlike: PathLikeImage, mode: ExtractionMethod, resize: bool
self, test_image_path_as_pathlike: PathLikeImage, mode: ExtractionMethod, resize: int | None
):
"""Test metadata for file path as Path object input."""
palette = extract_colors(test_image_path_as_pathlike, mode=mode, resize=resize)
Expand All @@ -59,12 +59,12 @@ def test_metadata_file_path_as_pathlike(
assert image_info["mode"] == "RGBA"
assert image_info["has_alpha"] is True

if resize:
if resize is not None:
assert image_info["processed_size"] == (256, 256)
else:
assert image_info["processed_size"] == (1202, 1276)

def test_metadata_url(self, test_image_as_url: URLImage, mode: ExtractionMethod, resize: bool):
def test_metadata_url(self, test_image_as_url: URLImage, mode: ExtractionMethod, resize: int | None):
"""Test metadata for URL input."""
palette = extract_colors(test_image_as_url, mode=mode, resize=resize)

Expand All @@ -85,12 +85,12 @@ def test_metadata_url(self, test_image_as_url: URLImage, mode: ExtractionMethod,
assert image_info["mode"] == "RGBA"
assert image_info["has_alpha"] is True

if resize:
if resize is not None:
assert image_info["processed_size"] == (256, 256)
else:
assert image_info["processed_size"] == (1202, 1276)

def test_metadata_bytes(self, test_image_as_bytes: BytesImage, mode: ExtractionMethod, resize: bool):
def test_metadata_bytes(self, test_image_as_bytes: BytesImage, mode: ExtractionMethod, resize: int | None):
"""Test metadata for bytes input."""
palette = extract_colors(test_image_as_bytes, mode=mode, resize=resize)

Expand All @@ -110,7 +110,7 @@ def test_metadata_bytes(self, test_image_as_bytes: BytesImage, mode: ExtractionM
assert image_info["mode"] == "RGBA"
assert image_info["has_alpha"] is True

if resize:
if resize is not None:
assert image_info["processed_size"] == (256, 256)
else:
assert image_info["processed_size"] == (1202, 1276)
Expand All @@ -135,7 +135,7 @@ def test_metadata_processing_stats(test_image_path_as_str: str):

def test_metadata_processing_stats_no_resize(test_image_path_as_str: str):
"""Test processing statistics for non-resized image."""
palette = extract_colors(test_image_path_as_str, resize=False)
palette = extract_colors(test_image_path_as_str, resize=None)

assert palette.metadata
stats = palette.metadata["processing_stats"]
Expand Down
Loading