diff --git a/robustbench/loaders.py b/robustbench/loaders.py index e050b14..9aa39ad 100644 --- a/robustbench/loaders.py +++ b/robustbench/loaders.py @@ -1,7 +1,7 @@ """ This file is based on the code from https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py. """ -import pkg_resources +from importlib.resources import files, as_file from torchvision.datasets.vision import VisionDataset @@ -17,8 +17,10 @@ def make_custom_dataset(root, path_imgs, class_to_idx): - with open(pkg_resources.resource_filename(__name__, path_imgs), 'r') as f: - fnames = f.readlines() + resource = files(__package__) / path_imgs + with as_file(resource) as file_path: + with open(file_path, 'r') as f: + fnames = f.readlines() images = [(os.path.join(root, c.split('\n')[0]), class_to_idx[c.split('/')[0]]) for c in fnames] diff --git a/setup.py b/setup.py index 5d80a97..03159a4 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name="robustbench", - version="1.1", + version="1.1.1", author="Francesco Croce, Maksym Andriushchenko, Vikash Sehwag, Edoardo Debenedetti", author_email="adversarial.benchmark@gmail.com", description="This package provides the data for RobustBench together with the model zoo.", diff --git a/tests/custom_loader_test.py b/tests/custom_loader_test.py new file mode 100644 index 0000000..0aced8b --- /dev/null +++ b/tests/custom_loader_test.py @@ -0,0 +1,36 @@ +import robustbench +from torchvision.datasets.vision import VisionDataset + +import torch +import torch.utils.data as data +import torchvision.transforms as transforms + +from PIL import Image + +import os +import os.path +import sys + +from robustbench import data +from robustbench import loaders + +data_dir = '~/imagenet/val' + +imagenet = loaders.CustomDatasetFolder(data_dir, robustbench.loaders.default_loader, transform= + transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor()])) + +torch.manual_seed(0) + +test_loader = data.data.DataLoader(imagenet, + batch_size=50, + shuffle=True, + num_workers=3) + +x, y, path = next(iter(test_loader)) + +with open('path_imgs_2.txt', 'w') as f: + f.write('\n'.join(path)) + f.flush()