From d047beac1d813e343bb7077e46d633c5e887ec76 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Mon, 14 Apr 2025 21:56:52 +0900 Subject: [PATCH 1/6] bugfix: fix the behavior of ImageDataset --- bdpy/dl/torch/dataset.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/bdpy/dl/torch/dataset.py b/bdpy/dl/torch/dataset.py index c821592f..d5fdf98f 100644 --- a/bdpy/dl/torch/dataset.py +++ b/bdpy/dl/torch/dataset.py @@ -170,10 +170,7 @@ def __init__( ): self.root_path = root_path if stimulus_names is None: - stimulus_names = [ - _removesuffix(path.name, "." + extension) - for path in Path(root_path).glob(f"*{extension}") - ] + stimulus_names = [path.stem for path in Path(root_path).glob(f"*{extension}")] self._stimulus_names = stimulus_names self._extension = extension @@ -184,7 +181,7 @@ def __getitem__(self, index: int): stimulus_name = self._stimulus_names[index] image = Image.open(Path(self.root_path) / f"{stimulus_name}.{self._extension}") image = image.convert("RGB") - return np.array(image) / 255.0, stimulus_name + return np.array(image).transpose(0, 3, 1, 2) / 255.0, stimulus_name class RenameFeatureKeys: From d1c0b5c024d2c56d4e723699357d1e9377653951 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Tue, 15 Apr 2025 21:50:45 +0900 Subject: [PATCH 2/6] Bugfix: bdpy/dl/torch/dataset.py --- bdpy/dl/torch/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bdpy/dl/torch/dataset.py b/bdpy/dl/torch/dataset.py index d5fdf98f..bd237e24 100644 --- a/bdpy/dl/torch/dataset.py +++ b/bdpy/dl/torch/dataset.py @@ -181,7 +181,7 @@ def __getitem__(self, index: int): stimulus_name = self._stimulus_names[index] image = Image.open(Path(self.root_path) / f"{stimulus_name}.{self._extension}") image = image.convert("RGB") - return np.array(image).transpose(0, 3, 1, 2) / 255.0, stimulus_name + return np.array(image).transpose(2, 0, 1) / 255.0, stimulus_name class RenameFeatureKeys: From 63d1443d4bbc903a12e3066e6b399b2353811033 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Wed, 13 May 2026 14:47:20 +0900 Subject: [PATCH 3/6] test: add ImageDataset tests and fix auto-detection order Add tests/dl/torch/test_dataset.py with 9 test cases covering CHW axis order, per-channel values, DataLoader integration, value normalization, length, explicit stimulus ordering, and auto-detection via Path.stem. Also sort auto-detected stimulus names for deterministic ordering, and remove the empty TestImageDataset stub from test_torch.py. Co-Authored-By: Claude Opus 4.7 --- bdpy/dl/torch/dataset.py | 2 +- tests/dl/torch/test_dataset.py | 90 ++++++++++++++++++++++++++++++++++ tests/dl/torch/test_torch.py | 4 -- 3 files changed, 91 insertions(+), 5 deletions(-) create mode 100644 tests/dl/torch/test_dataset.py diff --git a/bdpy/dl/torch/dataset.py b/bdpy/dl/torch/dataset.py index bd237e24..0b00d8d3 100644 --- a/bdpy/dl/torch/dataset.py +++ b/bdpy/dl/torch/dataset.py @@ -170,7 +170,7 @@ def __init__( ): self.root_path = root_path if stimulus_names is None: - stimulus_names = [path.stem for path in Path(root_path).glob(f"*{extension}")] + stimulus_names = sorted(path.stem for path in Path(root_path).glob(f"*{extension}")) self._stimulus_names = stimulus_names self._extension = extension diff --git a/tests/dl/torch/test_dataset.py b/tests/dl/torch/test_dataset.py new file mode 100644 index 00000000..88b58585 --- /dev/null +++ b/tests/dl/torch/test_dataset.py @@ -0,0 +1,90 @@ +import tempfile +import unittest +from pathlib import Path + +import numpy as np +import torch +from PIL import Image +from torch.utils.data import DataLoader + +from bdpy.dl.torch.dataset import ImageDataset + +# Use non-square images (H=4, W=6, C=3) to fully discriminate every axis. +# A square image (H=W) cannot distinguish (3, H, H) from (H, H, 3). +_H, _W = 4, 6 + + +def _save_image(path: Path, r: int, g: int, b: int, h: int = _H, w: int = _W) -> None: + data = np.zeros((h, w, 3), dtype=np.uint8) + data[:, :, 0] = r + data[:, :, 1] = g + data[:, :, 2] = b + Image.fromarray(data).save(path) + + +class TestImageDataset(unittest.TestCase): + def setUp(self) -> None: + self.tmpdir = tempfile.TemporaryDirectory() + root = Path(self.tmpdir.name) + _save_image(root / "a.jpg", r=200, g=100, b=50) + _save_image(root / "b.jpg", r=10, g=20, b=30) + _save_image(root / "c.jpg", r=0, g=128, b=255) + self.root = root + + def tearDown(self) -> None: + self.tmpdir.cleanup() + + def test_getitem_returns_chw_shape(self): + dataset = ImageDataset(self.root, stimulus_names=["a"]) + arr, _ = dataset[0] + self.assertEqual(arr.shape, (3, _H, _W)) + + def test_getitem_preserves_channels(self): + # R=200 G=100 B=50 — verifies C axis maps to the correct channel. + dataset = ImageDataset(self.root, stimulus_names=["a"]) + arr, _ = dataset[0] + self.assertTrue(np.allclose(arr[0], 200 / 255.0)) + self.assertTrue(np.allclose(arr[1], 100 / 255.0)) + self.assertTrue(np.allclose(arr[2], 50 / 255.0)) + + def test_dataloader_integration_batch_shape(self): + dataset = ImageDataset(self.root, stimulus_names=["a", "b"]) + loader = DataLoader(dataset, batch_size=2) + batch_images, _ = next(iter(loader)) + self.assertEqual(tuple(batch_images.shape), (2, 3, _H, _W)) + + def test_value_range_normalized_to_unit_interval(self): + dataset = ImageDataset(self.root, stimulus_names=["a", "b", "c"]) + for i in range(len(dataset)): + arr, _ = dataset[i] + self.assertGreaterEqual(float(arr.min()), 0.0) + self.assertLessEqual(float(arr.max()), 1.0) + + def test_len_matches_stimulus_names(self): + dataset = ImageDataset(self.root, stimulus_names=["a", "b", "c"]) + self.assertEqual(len(dataset), 3) + + def test_explicit_stimulus_names_respected(self): + dataset = ImageDataset(self.root, stimulus_names=["a", "c"]) + self.assertEqual(len(dataset), 2) + _, label0 = dataset[0] + _, label1 = dataset[1] + self.assertEqual(label0, "a") + self.assertEqual(label1, "c") + + def test_auto_detected_stimulus_names_use_stem(self): + dataset = ImageDataset(self.root) + self.assertEqual(set(dataset._stimulus_names), {"a", "b", "c"}) + + def test_explicit_stimulus_names_preserve_input_order(self): + dataset = ImageDataset(self.root, stimulus_names=["c", "a", "b"]) + labels = [dataset[i][1] for i in range(len(dataset))] + self.assertEqual(labels, ["c", "a", "b"]) + + def test_auto_detected_stimulus_names_are_sorted(self): + dataset = ImageDataset(self.root) + self.assertEqual(dataset._stimulus_names, ["a", "b", "c"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/dl/torch/test_torch.py b/tests/dl/torch/test_torch.py index c5bcc2d6..bcaf51ec 100644 --- a/tests/dl/torch/test_torch.py +++ b/tests/dl/torch/test_torch.py @@ -69,9 +69,5 @@ def test_run_with_layer_map(self): self.assertEqual(features[layer].shape, shape) -class TestImageDataset(unittest.TestCase): - ... - - if __name__ == '__main__': unittest.main() \ No newline at end of file From c82dfac40d4e0de1d86c5cbd1371113ff4ddd540 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Wed, 13 May 2026 23:10:34 +0900 Subject: [PATCH 4/6] refactor: remove unused import of torch in test_dataset.py --- tests/dl/torch/test_dataset.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/dl/torch/test_dataset.py b/tests/dl/torch/test_dataset.py index 88b58585..847b20f7 100644 --- a/tests/dl/torch/test_dataset.py +++ b/tests/dl/torch/test_dataset.py @@ -3,7 +3,6 @@ from pathlib import Path import numpy as np -import torch from PIL import Image from torch.utils.data import DataLoader From 57451e770bb3fa5df3fdccb352264932ba41e9bc Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Wed, 13 May 2026 23:14:03 +0900 Subject: [PATCH 5/6] test: update ImageDataset tests to use PNG format for images --- tests/dl/torch/test_dataset.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/dl/torch/test_dataset.py b/tests/dl/torch/test_dataset.py index 847b20f7..ff0be98b 100644 --- a/tests/dl/torch/test_dataset.py +++ b/tests/dl/torch/test_dataset.py @@ -25,46 +25,46 @@ class TestImageDataset(unittest.TestCase): def setUp(self) -> None: self.tmpdir = tempfile.TemporaryDirectory() root = Path(self.tmpdir.name) - _save_image(root / "a.jpg", r=200, g=100, b=50) - _save_image(root / "b.jpg", r=10, g=20, b=30) - _save_image(root / "c.jpg", r=0, g=128, b=255) + _save_image(root / "a.png", r=200, g=100, b=50) + _save_image(root / "b.png", r=10, g=20, b=30) + _save_image(root / "c.png", r=0, g=128, b=255) self.root = root def tearDown(self) -> None: self.tmpdir.cleanup() def test_getitem_returns_chw_shape(self): - dataset = ImageDataset(self.root, stimulus_names=["a"]) + dataset = ImageDataset(self.root, stimulus_names=["a"], extension="png") arr, _ = dataset[0] self.assertEqual(arr.shape, (3, _H, _W)) def test_getitem_preserves_channels(self): # R=200 G=100 B=50 — verifies C axis maps to the correct channel. - dataset = ImageDataset(self.root, stimulus_names=["a"]) + dataset = ImageDataset(self.root, stimulus_names=["a"], extension="png") arr, _ = dataset[0] self.assertTrue(np.allclose(arr[0], 200 / 255.0)) self.assertTrue(np.allclose(arr[1], 100 / 255.0)) self.assertTrue(np.allclose(arr[2], 50 / 255.0)) def test_dataloader_integration_batch_shape(self): - dataset = ImageDataset(self.root, stimulus_names=["a", "b"]) + dataset = ImageDataset(self.root, stimulus_names=["a", "b"], extension="png" ) loader = DataLoader(dataset, batch_size=2) batch_images, _ = next(iter(loader)) self.assertEqual(tuple(batch_images.shape), (2, 3, _H, _W)) def test_value_range_normalized_to_unit_interval(self): - dataset = ImageDataset(self.root, stimulus_names=["a", "b", "c"]) + dataset = ImageDataset(self.root, stimulus_names=["a", "b", "c"], extension="png") for i in range(len(dataset)): arr, _ = dataset[i] self.assertGreaterEqual(float(arr.min()), 0.0) self.assertLessEqual(float(arr.max()), 1.0) def test_len_matches_stimulus_names(self): - dataset = ImageDataset(self.root, stimulus_names=["a", "b", "c"]) + dataset = ImageDataset(self.root, stimulus_names=["a", "b", "c"], extension="png") self.assertEqual(len(dataset), 3) def test_explicit_stimulus_names_respected(self): - dataset = ImageDataset(self.root, stimulus_names=["a", "c"]) + dataset = ImageDataset(self.root, stimulus_names=["a", "c"], extension="png") self.assertEqual(len(dataset), 2) _, label0 = dataset[0] _, label1 = dataset[1] @@ -72,16 +72,16 @@ def test_explicit_stimulus_names_respected(self): self.assertEqual(label1, "c") def test_auto_detected_stimulus_names_use_stem(self): - dataset = ImageDataset(self.root) + dataset = ImageDataset(self.root, extension="png") self.assertEqual(set(dataset._stimulus_names), {"a", "b", "c"}) def test_explicit_stimulus_names_preserve_input_order(self): - dataset = ImageDataset(self.root, stimulus_names=["c", "a", "b"]) + dataset = ImageDataset(self.root, stimulus_names=["c", "a", "b"], extension="png") labels = [dataset[i][1] for i in range(len(dataset))] self.assertEqual(labels, ["c", "a", "b"]) def test_auto_detected_stimulus_names_are_sorted(self): - dataset = ImageDataset(self.root) + dataset = ImageDataset(self.root, extension="png") self.assertEqual(dataset._stimulus_names, ["a", "b", "c"]) From 662f8186de1c323db387da95dd095dfb8b8d6db7 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Thu, 14 May 2026 10:18:27 +0900 Subject: [PATCH 6/6] doc: update ImageDataset docstring to clarify stimulus_names behavior and add notes on image format --- bdpy/dl/torch/dataset.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/bdpy/dl/torch/dataset.py b/bdpy/dl/torch/dataset.py index 0b00d8d3..c45f6162 100644 --- a/bdpy/dl/torch/dataset.py +++ b/bdpy/dl/torch/dataset.py @@ -157,9 +157,16 @@ class ImageDataset(Dataset): root_path : str | Path Path to the root directory of images. stimulus_names : list[str], optional - List of stimulus names. If None, all stimulus names are used. + List of stimulus names. If provided, images are returned in the given + order. If None, all images found under ``root_path`` are used in + alphabetical order. extension : str, optional Extension of the image files. + + Notes + ----- + Images are returned as float64 arrays in CHW (channels, height, width) + format with pixel values normalized to [0, 1]. """ def __init__(