diff --git a/bdpy/dl/torch/dataset.py b/bdpy/dl/torch/dataset.py index c821592..c45f616 100644 --- a/bdpy/dl/torch/dataset.py +++ b/bdpy/dl/torch/dataset.py @@ -157,9 +157,16 @@ class ImageDataset(Dataset): root_path : str | Path Path to the root directory of images. stimulus_names : list[str], optional - List of stimulus names. If None, all stimulus names are used. + List of stimulus names. If provided, images are returned in the given + order. If None, all images found under ``root_path`` are used in + alphabetical order. extension : str, optional Extension of the image files. + + Notes + ----- + Images are returned as float64 arrays in CHW (channels, height, width) + format with pixel values normalized to [0, 1]. """ def __init__( @@ -170,10 +177,7 @@ def __init__( ): self.root_path = root_path if stimulus_names is None: - stimulus_names = [ - _removesuffix(path.name, "." + extension) - for path in Path(root_path).glob(f"*{extension}") - ] + stimulus_names = sorted(path.stem for path in Path(root_path).glob(f"*{extension}")) self._stimulus_names = stimulus_names self._extension = extension @@ -184,7 +188,7 @@ def __getitem__(self, index: int): stimulus_name = self._stimulus_names[index] image = Image.open(Path(self.root_path) / f"{stimulus_name}.{self._extension}") image = image.convert("RGB") - return np.array(image) / 255.0, stimulus_name + return np.array(image).transpose(2, 0, 1) / 255.0, stimulus_name class RenameFeatureKeys: diff --git a/tests/dl/torch/test_dataset.py b/tests/dl/torch/test_dataset.py new file mode 100644 index 0000000..ff0be98 --- /dev/null +++ b/tests/dl/torch/test_dataset.py @@ -0,0 +1,89 @@ +import tempfile +import unittest +from pathlib import Path + +import numpy as np +from PIL import Image +from torch.utils.data import DataLoader + +from bdpy.dl.torch.dataset import ImageDataset + +# Use non-square images (H=4, W=6, C=3) to fully discriminate every axis. +# A square image (H=W) cannot distinguish (3, H, H) from (H, H, 3). +_H, _W = 4, 6 + + +def _save_image(path: Path, r: int, g: int, b: int, h: int = _H, w: int = _W) -> None: + data = np.zeros((h, w, 3), dtype=np.uint8) + data[:, :, 0] = r + data[:, :, 1] = g + data[:, :, 2] = b + Image.fromarray(data).save(path) + + +class TestImageDataset(unittest.TestCase): + def setUp(self) -> None: + self.tmpdir = tempfile.TemporaryDirectory() + root = Path(self.tmpdir.name) + _save_image(root / "a.png", r=200, g=100, b=50) + _save_image(root / "b.png", r=10, g=20, b=30) + _save_image(root / "c.png", r=0, g=128, b=255) + self.root = root + + def tearDown(self) -> None: + self.tmpdir.cleanup() + + def test_getitem_returns_chw_shape(self): + dataset = ImageDataset(self.root, stimulus_names=["a"], extension="png") + arr, _ = dataset[0] + self.assertEqual(arr.shape, (3, _H, _W)) + + def test_getitem_preserves_channels(self): + # R=200 G=100 B=50 — verifies C axis maps to the correct channel. + dataset = ImageDataset(self.root, stimulus_names=["a"], extension="png") + arr, _ = dataset[0] + self.assertTrue(np.allclose(arr[0], 200 / 255.0)) + self.assertTrue(np.allclose(arr[1], 100 / 255.0)) + self.assertTrue(np.allclose(arr[2], 50 / 255.0)) + + def test_dataloader_integration_batch_shape(self): + dataset = ImageDataset(self.root, stimulus_names=["a", "b"], extension="png" ) + loader = DataLoader(dataset, batch_size=2) + batch_images, _ = next(iter(loader)) + self.assertEqual(tuple(batch_images.shape), (2, 3, _H, _W)) + + def test_value_range_normalized_to_unit_interval(self): + dataset = ImageDataset(self.root, stimulus_names=["a", "b", "c"], extension="png") + for i in range(len(dataset)): + arr, _ = dataset[i] + self.assertGreaterEqual(float(arr.min()), 0.0) + self.assertLessEqual(float(arr.max()), 1.0) + + def test_len_matches_stimulus_names(self): + dataset = ImageDataset(self.root, stimulus_names=["a", "b", "c"], extension="png") + self.assertEqual(len(dataset), 3) + + def test_explicit_stimulus_names_respected(self): + dataset = ImageDataset(self.root, stimulus_names=["a", "c"], extension="png") + self.assertEqual(len(dataset), 2) + _, label0 = dataset[0] + _, label1 = dataset[1] + self.assertEqual(label0, "a") + self.assertEqual(label1, "c") + + def test_auto_detected_stimulus_names_use_stem(self): + dataset = ImageDataset(self.root, extension="png") + self.assertEqual(set(dataset._stimulus_names), {"a", "b", "c"}) + + def test_explicit_stimulus_names_preserve_input_order(self): + dataset = ImageDataset(self.root, stimulus_names=["c", "a", "b"], extension="png") + labels = [dataset[i][1] for i in range(len(dataset))] + self.assertEqual(labels, ["c", "a", "b"]) + + def test_auto_detected_stimulus_names_are_sorted(self): + dataset = ImageDataset(self.root, extension="png") + self.assertEqual(dataset._stimulus_names, ["a", "b", "c"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/dl/torch/test_torch.py b/tests/dl/torch/test_torch.py index c5bcc2d..bcaf51e 100644 --- a/tests/dl/torch/test_torch.py +++ b/tests/dl/torch/test_torch.py @@ -69,9 +69,5 @@ def test_run_with_layer_map(self): self.assertEqual(features[layer].shape, shape) -class TestImageDataset(unittest.TestCase): - ... - - if __name__ == '__main__': unittest.main() \ No newline at end of file