Skip to content

Commit 6a7922e

Browse files
authored
feat: add fixed dataloader (#87)
* Create dataloder.py * add unit tests * update unit test
1 parent e789ea6 commit 6a7922e

2 files changed

Lines changed: 53 additions & 0 deletions

File tree

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import torch
2+
import pytest
3+
import numpy as np
4+
from torch.utils.data import TensorDataset, DataLoader
5+
6+
from torchensemble.utils.dataloder import FixedDataLoader
7+
8+
9+
# Data
10+
X = torch.Tensor(np.array(([0.1, 0.1], [0.2, 0.2], [0.3, 0.3], [0.4, 0.4])))
11+
y = torch.LongTensor(np.array(([0, 0, 1, 1])))
12+
13+
data = TensorDataset(X, y)
14+
dataloder = DataLoader(data, batch_size=2, shuffle=False)
15+
16+
17+
def test_fixed_dataloder():
18+
fixed_dataloader = FixedDataLoader(dataloder)
19+
for _, (fixed_elem, elem) in enumerate(zip(fixed_dataloader, dataloder)):
20+
# Check same elements
21+
for elem_1, elem_2 in zip(fixed_elem, elem):
22+
assert torch.equal(elem_1, elem_2)
23+
24+
# Check dataloder length
25+
assert len(fixed_dataloader) == 2
26+
27+
28+
def test_fixed_dataloader_invalid_type():
29+
with pytest.raises(ValueError) as excinfo:
30+
FixedDataLoader((X, y))
31+
assert "input used to instantiate FixedDataLoader" in str(excinfo.value)

torchensemble/utils/dataloder.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from torch.utils.data import DataLoader
2+
3+
4+
class FixedDataLoader(object):
5+
def __init__(self, dataloader):
6+
# Check input
7+
if not isinstance(dataloader, DataLoader):
8+
msg = (
9+
"The input used to instantiate FixedDataLoader should be a"
10+
" DataLoader from `torch.utils.data`."
11+
)
12+
raise ValueError(msg)
13+
14+
self.elem_list = []
15+
for _, elem in enumerate(dataloader):
16+
self.elem_list.append(elem)
17+
18+
def __getitem__(self, index):
19+
return self.elem_list[index]
20+
21+
def __len__(self):
22+
return len(self.elem_list)

0 commit comments

Comments
 (0)