Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,20 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
k-means in the perceptual [OKLab](https://bottosson.github.io/posts/oklab/)
color space, so clusters are grouped by perceived color difference. Pixels are
linearized before conversion.
- **`Color.rgb_float`**: New property returning the canonical color as float
sRGB components in `[0, 1]`, plus a `Color.from_srgb_float(...)` constructor
for building colors from continuous (non-quantized) centroids.

### Changed

- **Import package renamed `Pylette` → `pylette`**
Update imports: `from Pylette import x` → `from pylette import x`.
- **`Color` stores float sRGB canonically**: colors are now kept as float sRGB
in `[0, 1]` internally.
`Color.rgb` now always returns a `tuple[int, int, int]` of plain Python ints
(previously it could return NumPy integers); `.hex`, `.hsv`, `.hls`, and
`.luminance` are derived from the internal float representation. The OKLab extractor keeps its
centroids pre-quantization for extra precision.
- `extract_colors` now resolves the extraction algorithm through the registry
instead of dispatching on the extraction method directly.
- **Extractor `extract()` signature**: Dropped the unused `height` and `width`
Expand Down
128 changes: 112 additions & 16 deletions pylette/src/color.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import colorsys
from typing import cast

import numpy as np

Expand All @@ -9,22 +8,118 @@
luminance_weights = np.array([0.2126, 0.7152, 0.0722])


def _clamp_unit(value: float) -> float:
"""Clamp a float into the closed unit interval [0, 1]."""
return min(1.0, max(0.0, float(value)))


class Color(object):
"""A single palette color.

The canonical representation is float sRGB in ``[0, 1]`` (plus a float
alpha). 8-bit quantization happens only at the output boundaries
(:attr:`rgb`, :attr:`rgba`, :attr:`a`, :attr:`hex`), so colors constructed
from continuous centroids keep their precision until they are read out.
"""

def __init__(self, rgba: tuple[int, ...], frequency: float):
"""
Initializes a Color object with RGBA values and frequency.
Initializes a Color object from 8-bit RGBA values.

The 8-bit input is the quantized view of the color; it is converted to
the canonical float store on construction.

Parameters:
rgba (tuple[int, ...]): A tuple of RGBA values.
rgba (tuple[int, ...]): A tuple of RGBA values, each in [0, 255].
frequency (float): The frequency of the color.
"""
assert len(rgba) == 4, "RGBA values must be a tuple of length 4"
*rgb, alpha = rgba
self.rgb = cast(tuple[int, int, int], rgb)
self.rgba = rgba
self.a = alpha
r, g, b, alpha = (int(round(float(v))) for v in rgba)
self._srgb: tuple[float, float, float] = (r / 255.0, g / 255.0, b / 255.0)
self._alpha: float = alpha / 255.0
self.freq: float = frequency
self.weight = alpha / 255.0

@classmethod
def from_srgb_float(
cls,
srgb: tuple[float, float, float],
frequency: float,
alpha: float = 1.0,
) -> "Color":
"""
Constructs a Color from float sRGB components in ``[0, 1]``.

This is the precision-preserving entry point for extractors whose
centroids live in continuous space (e.g. OKLab); it avoids the round
trip through 8-bit that :meth:`__init__` performs. Components are
clamped into ``[0, 1]`` so out-of-gamut centroids are handled gracefully.

Parameters:
srgb (tuple[float, float, float]): Gamma-encoded sRGB components.
frequency (float): The frequency of the color.
alpha (float): Alpha in ``[0, 1]`` (default fully opaque).

Returns:
Color: A color whose canonical store holds the given floats.
"""
obj = cls.__new__(cls)
r, g, b = srgb
obj._srgb = (_clamp_unit(r), _clamp_unit(g), _clamp_unit(b))
obj._alpha = _clamp_unit(alpha)
obj.freq = frequency
return obj

@property
def rgb_float(self) -> tuple[float, float, float]:
"""
The canonical color as float sRGB components in ``[0, 1]``.

Returns:
tuple[float, float, float]: The (r, g, b) components.
"""
return self._srgb

@property
def rgb(self) -> tuple[int, int, int]:
"""
The color as 8-bit sRGB.

Returns:
tuple[int, int, int]: (r, g, b) as plain Python ints in [0, 255].
"""
r, g, b = self._srgb
return (int(round(r * 255.0)), int(round(g * 255.0)), int(round(b * 255.0)))

@property
def a(self) -> int:
"""
The alpha channel as an 8-bit value.

Returns:
int: Alpha as a plain Python int in [0, 255].
"""
return int(round(self._alpha * 255.0))

@property
def rgba(self) -> tuple[int, int, int, int]:
"""
The color as 8-bit RGBA.

Returns:
tuple[int, int, int, int]: (r, g, b, a) as plain Python ints in [0, 255].
"""
r, g, b = self.rgb
return (r, g, b, self.a)

@property
def weight(self) -> float:
"""
The alpha channel as a fraction in ``[0, 1]``.

Returns:
float: Alpha in [0, 1].
"""
return self._alpha

def display(self, w: int = 50, h: int = 50) -> None:
"""
Expand Down Expand Up @@ -68,22 +163,22 @@ def get_colors(self, colorspace: ColorSpace = ColorSpace.RGB) -> tuple[int, ...]
@property
def hsv(self) -> tuple[float, float, float]:
"""
Converts the RGB color to HSV color space.
Converts the color to HSV color space, derived from the canonical float store.

Returns:
tuple[float, float, float]: The color values in HSV color space.
"""
return colorsys.rgb_to_hsv(r=self.rgb[0] / 255, g=self.rgb[1] / 255, b=self.rgb[2] / 255)
return colorsys.rgb_to_hsv(*self._srgb)

@property
def hls(self) -> tuple[float, float, float]:
"""
Converts the RGB color to HLS color space.
Converts the color to HLS color space, derived from the canonical float store.

Returns:
tuple[float, float, float]: The color values in HLS color space.
"""
return colorsys.rgb_to_hls(r=self.rgb[0] / 255, g=self.rgb[1] / 255, b=self.rgb[2] / 255)
return colorsys.rgb_to_hls(*self._srgb)

@property
def hex(self) -> str:
Expand All @@ -93,14 +188,15 @@ def hex(self) -> str:
Returns:
str: The color in hexadecimal format (e.g., "#FF5733").
"""
return f"#{self.rgb[0]:02X}{self.rgb[1]:02X}{self.rgb[2]:02X}"
r, g, b = self.rgb
return f"#{r:02X}{g:02X}{b:02X}"

@property
def luminance(self) -> float:
"""
Calculates the luminance of the color.
Calculates the luminance of the color, derived from the canonical float store.

Returns:
float: The luminance of the color.
float: The luminance of the color, on the same 0-255 scale as the 8-bit channels.
"""
return np.dot(luminance_weights, self.rgb)
return float(np.dot(luminance_weights, self._srgb)) * 255.0
12 changes: 6 additions & 6 deletions pylette/src/extractors/oklab.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,9 @@ def extract(self, arr: NDArray[NP_T], palette_size: int) -> list[Color]:
labels = model.fit_predict(lab)
centers_lab = np.asarray(model.cluster_centers_)

# OKLab centroids -> sRGB8
centers_srgb = linear_to_srgb(oklab_to_linear_srgb(centers_lab))
centers_rgb8 = np.clip(np.round(centers_srgb * 255.0), 0, 255).astype(int)
# OKLab centroids -> float sRGB in [0, 1], kept pre-quantization so the
# Color stores full precision; out-of-gamut values are clamped.
centers_srgb = np.clip(linear_to_srgb(oklab_to_linear_srgb(centers_lab)), 0.0, 1.0)

counts = np.bincount(labels, minlength=palette_size)
total = float(counts.sum())
Expand All @@ -125,7 +125,7 @@ def extract(self, arr: NDArray[NP_T], palette_size: int) -> list[Color]:
for i in range(palette_size):
if counts[i] == 0:
continue
mean_alpha = int(round(float(alpha[labels == i].mean())))
r, g, b = (int(c) for c in centers_rgb8[i])
colors.append(Color((r, g, b, mean_alpha), counts[i] / total))
mean_alpha = float(alpha[labels == i].mean()) / 255.0
r, g, b = (float(c) for c in centers_srgb[i])
colors.append(Color.from_srgb_float((r, g, b), counts[i] / total, alpha=mean_alpha))
return colors
84 changes: 84 additions & 0 deletions tests/integration/test_color_representation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""
The canonical store is float sRGB in [0, 1]; 8-bit quantization happens only at
output boundaries (``.rgb``, ``.rgba``, ``.hex``).
"""

import colorsys

import numpy as np
import pytest
from PIL import Image

from pylette import Color, extract_colors
from pylette.types import ExtractionMethod


@pytest.fixture
def test_image() -> Image.Image:
rng = np.random.default_rng(2024)
arr = rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8)
return Image.fromarray(arr, "RGB")


@pytest.mark.parametrize("mode", list(ExtractionMethod))
def test_rgb_is_plain_python_int_for_every_extractor(test_image: Image.Image, mode: ExtractionMethod) -> None:
"""Acceptance: ``all(isinstance(c, int) for c in color.rgb)`` for every extractor."""
palette = extract_colors(test_image, palette_size=4, mode=mode)
for color in palette.colors:
assert all(isinstance(c, int) for c in color.rgb)
assert all(not isinstance(c, np.integer) for c in color.rgb)


@pytest.mark.parametrize("mode", list(ExtractionMethod))
def test_rgb_float_in_unit_interval(test_image: Image.Image, mode: ExtractionMethod) -> None:
palette = extract_colors(test_image, palette_size=4, mode=mode)
for color in palette.colors:
assert len(color.rgb_float) == 3
assert all(isinstance(c, float) for c in color.rgb_float)
assert all(0.0 <= c <= 1.0 for c in color.rgb_float)


def test_from_srgb_float_quantizes_to_rgb() -> None:
red = Color.from_srgb_float((1.0, 0.0, 0.0), frequency=1.0)
assert red.rgb == (255, 0, 0)
assert red.hex == "#FF0000"
assert red.rgb_float == (1.0, 0.0, 0.0)


def test_from_srgb_float_clamps_out_of_gamut() -> None:
c = Color.from_srgb_float((1.5, -0.2, 0.5), frequency=1.0)
assert c.rgb_float == (1.0, 0.0, 0.5)
assert c.rgb == (255, 0, 128)


@pytest.mark.parametrize(
"rgba, expected_hex",
[
((255, 0, 0, 255), "#FF0000"),
((0, 255, 0, 255), "#00FF00"),
((142, 152, 174, 255), "#8E98AE"),
],
)
def test_hex_roundtrip_stable(rgba: tuple[int, int, int, int], expected_hex: str) -> None:
"""Round-trip ``Color -> hex -> Color`` is stable."""
color = Color(rgba=rgba, frequency=0.5)
assert color.hex == expected_hex

r, g, b = color.rgb
roundtripped = Color(rgba=(r, g, b, 255), frequency=0.5)
assert roundtripped.hex == expected_hex
assert roundtripped.rgb == color.rgb


def test_eight_bit_constructor_matches_legacy_hsv() -> None:
"""For 8-bit-constructed colors, derived spaces match the legacy formula."""
color = Color(rgba=(142, 152, 174, 255), frequency=0.5)
assert color.hsv == colorsys.rgb_to_hsv(142 / 255, 152 / 255, 174 / 255)
assert color.hls == colorsys.rgb_to_hls(142 / 255, 152 / 255, 174 / 255)


def test_rgba_and_alpha_are_plain_ints() -> None:
color = Color(rgba=(10, 20, 30, 128), frequency=0.5)
assert color.rgba == (10, 20, 30, 128)
assert isinstance(color.a, int)
assert color.weight == pytest.approx(128 / 255)