| 115 | return self.dataset[idx] |
| 116 | |
def get_dataset(name='cifar10', root='./data', train=True, download=True):
    """Build one of the supported dataset wrappers by name.

    Matching is case-insensitive and ignores surrounding whitespace, hyphens,
    underscores and spaces. So 'CIFAR-10', 'cifar_10', 'Fashion-MNIST' and
    'fashion_mnist' all resolve the same way as the canonical names.

    Args:
        name: Dataset identifier. One of cifar10, cifar100, mnist,
            fashion_mnist (fashionmnist) or stl10, in any of the spellings above.
        root: Directory passed through to the dataset constructor.
        train: Select the training split (True) or the test split (False).
        download: Passed through to the dataset constructor.

    Returns:
        An instance of CIFAR10Dataset, CIFAR100Dataset, MNISTDataset,
        FashionMNISTDataset or STL10Dataset, depending on ``name``.

    Raises:
        TypeError: If ``name`` is not a string.
        ValueError: If ``name`` does not match a supported dataset.
    """
    if not isinstance(name, str):
        raise TypeError(
            f"Dataset name must be a string, got {type(name).__name__}"
        )

    # Drop separators so spelling variants map onto one canonical key.
    key = name.strip().lower()
    for sep in ('-', '_', ' '):
        key = key.replace(sep, '')

    # The lambdas defer the class lookups: only the selected dataset's class
    # is referenced (and constructed).
    builders = {
        'cifar10': lambda: CIFAR10Dataset(
            root=root, train=train, download=download),
        'cifar100': lambda: CIFAR100Dataset(
            root=root, train=train, download=download),
        'mnist': lambda: MNISTDataset(
            root=root, train=train, download=download),
        'fashionmnist': lambda: FashionMNISTDataset(
            root=root, train=train, download=download),
        # STL10 is selected by split name rather than a boolean train flag.
        'stl10': lambda: STL10Dataset(
            root=root, split='train' if train else 'test', download=download),
    }

    try:
        build = builders[key]
    except KeyError:
        supported = ', '.join(sorted(builders))
        raise ValueError(
            f"Unknown dataset: {name}. Supported datasets: {supported}"
        ) from None
    return build()
| 133 | |
| 134 | class ImageNetDataset: |
| 135 | def __init__(self, root='./data/imagenet', split='train', transform=None, download=False): |