Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions bdpy/dl/torch/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,16 @@ class ImageDataset(Dataset):
root_path : str | Path
Path to the root directory of images.
stimulus_names : list[str], optional
List of stimulus names. If None, all stimulus names are used.
List of stimulus names. If provided, images are returned in the given
order. If None, all images found under ``root_path`` are used in
alphabetical order.
extension : str, optional
Extension of the image files.

Notes
-----
Images are returned as float64 arrays in CHW (channels, height, width)
format with pixel values normalized to [0, 1].
"""

def __init__(
Expand All @@ -170,10 +177,7 @@ def __init__(
):
self.root_path = root_path
if stimulus_names is None:
stimulus_names = [
_removesuffix(path.name, "." + extension)
for path in Path(root_path).glob(f"*{extension}")
]
stimulus_names = sorted(path.stem for path in Path(root_path).glob(f"*{extension}"))
self._stimulus_names = stimulus_names
self._extension = extension

Expand All @@ -184,7 +188,7 @@ def __getitem__(self, index: int):
stimulus_name = self._stimulus_names[index]
image = Image.open(Path(self.root_path) / f"{stimulus_name}.{self._extension}")
image = image.convert("RGB")
return np.array(image) / 255.0, stimulus_name
return np.array(image).transpose(2, 0, 1) / 255.0, stimulus_name


class RenameFeatureKeys:
Expand Down
89 changes: 89 additions & 0 deletions tests/dl/torch/test_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import tempfile
import unittest
from pathlib import Path

import numpy as np
from PIL import Image
from torch.utils.data import DataLoader

from bdpy.dl.torch.dataset import ImageDataset

# Use non-square images (H=4, W=6, C=3) to fully discriminate every axis.
# A square image (H=W) cannot distinguish (3, H, H) from (H, H, 3).
_H, _W = 4, 6


def _save_image(path: Path, r: int, g: int, b: int, h: int = _H, w: int = _W) -> None:
data = np.zeros((h, w, 3), dtype=np.uint8)
data[:, :, 0] = r
data[:, :, 1] = g
data[:, :, 2] = b
Image.fromarray(data).save(path)


class TestImageDataset(unittest.TestCase):
def setUp(self) -> None:
self.tmpdir = tempfile.TemporaryDirectory()
root = Path(self.tmpdir.name)
_save_image(root / "a.png", r=200, g=100, b=50)
_save_image(root / "b.png", r=10, g=20, b=30)
_save_image(root / "c.png", r=0, g=128, b=255)
self.root = root

def tearDown(self) -> None:
self.tmpdir.cleanup()

def test_getitem_returns_chw_shape(self):
dataset = ImageDataset(self.root, stimulus_names=["a"], extension="png")
arr, _ = dataset[0]
self.assertEqual(arr.shape, (3, _H, _W))

def test_getitem_preserves_channels(self):
# R=200 G=100 B=50 — verifies C axis maps to the correct channel.
dataset = ImageDataset(self.root, stimulus_names=["a"], extension="png")
arr, _ = dataset[0]
self.assertTrue(np.allclose(arr[0], 200 / 255.0))
self.assertTrue(np.allclose(arr[1], 100 / 255.0))
self.assertTrue(np.allclose(arr[2], 50 / 255.0))

def test_dataloader_integration_batch_shape(self):
dataset = ImageDataset(self.root, stimulus_names=["a", "b"], extension="png" )
loader = DataLoader(dataset, batch_size=2)
batch_images, _ = next(iter(loader))
self.assertEqual(tuple(batch_images.shape), (2, 3, _H, _W))

def test_value_range_normalized_to_unit_interval(self):
dataset = ImageDataset(self.root, stimulus_names=["a", "b", "c"], extension="png")
for i in range(len(dataset)):
arr, _ = dataset[i]
self.assertGreaterEqual(float(arr.min()), 0.0)
self.assertLessEqual(float(arr.max()), 1.0)

def test_len_matches_stimulus_names(self):
dataset = ImageDataset(self.root, stimulus_names=["a", "b", "c"], extension="png")
self.assertEqual(len(dataset), 3)

def test_explicit_stimulus_names_respected(self):
dataset = ImageDataset(self.root, stimulus_names=["a", "c"], extension="png")
self.assertEqual(len(dataset), 2)
_, label0 = dataset[0]
_, label1 = dataset[1]
self.assertEqual(label0, "a")
self.assertEqual(label1, "c")

def test_auto_detected_stimulus_names_use_stem(self):
dataset = ImageDataset(self.root, extension="png")
self.assertEqual(set(dataset._stimulus_names), {"a", "b", "c"})

def test_explicit_stimulus_names_preserve_input_order(self):
dataset = ImageDataset(self.root, stimulus_names=["c", "a", "b"], extension="png")
labels = [dataset[i][1] for i in range(len(dataset))]
self.assertEqual(labels, ["c", "a", "b"])

def test_auto_detected_stimulus_names_are_sorted(self):
dataset = ImageDataset(self.root, extension="png")
self.assertEqual(dataset._stimulus_names, ["a", "b", "c"])


if __name__ == "__main__":
unittest.main()
4 changes: 0 additions & 4 deletions tests/dl/torch/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,5 @@ def test_run_with_layer_map(self):
self.assertEqual(features[layer].shape, shape)


class TestImageDataset(unittest.TestCase):
...


if __name__ == '__main__':
unittest.main()
Loading