| 5 | from typing import Optional, Callable |
| 6 | |
class CIFAR10Dataset:
    """Wrapper around ``datasets.CIFAR10`` that supplies a default transform.

    If no ``transform`` is given, a default pipeline is chosen from ``train``:
    - The training split gets random crop (32 px, padding 4) and random
      horizontal flip.
    - Both splits then get ``ToTensor`` and per-channel ``Normalize`` using
      ``MEAN`` / ``STD``.
    """

    # Per-channel normalization statistics shared by the train and test
    # pipelines. They are kept in one place so the two pipelines cannot drift apart.
    # NOTE(review): these std values are widely copied from older CIFAR-10
    # examples; other sources report roughly (0.2470, 0.2435, 0.2616) for the
    # training set. Confirm before changing -- existing checkpoints may have
    # been trained with these exact values.
    MEAN = (0.4914, 0.4822, 0.4465)
    STD = (0.2023, 0.1994, 0.2010)

    # Class names; index i is assumed to correspond to label i (confirm
    # against the underlying dataset's label order).
    CLASSES = ('airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck')

    def __init__(self, root='./data', train=True, transform: Optional[Callable] = None, download=True):
        """Create the wrapped CIFAR-10 dataset.

        Args:
            root: Dataset root directory, forwarded to ``datasets.CIFAR10``.
            train: Selects the training split when True, the test split
                otherwise. It also selects the default transform.
            transform: Callable forwarded to ``datasets.CIFAR10``. If None,
                ``_default_transform(train)`` is used.
            download: Forwarded to ``datasets.CIFAR10``.
        """
        if transform is None:
            transform = self._default_transform(train)

        self.dataset = datasets.CIFAR10(root=root, train=train, transform=transform, download=download)
        # A fresh list per instance, so mutating one instance's names cannot
        # affect another instance or the class constant.
        self.classes = list(self.CLASSES)

    @classmethod
    def _default_transform(cls, train):
        """Build the default preprocessing pipeline for the requested split."""
        steps = []
        if train:
            # Augmentation is applied to the training split only.
            steps += [
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
            ]
        steps += [
            transforms.ToTensor(),
            transforms.Normalize(cls.MEAN, cls.STD),
        ]
        return transforms.Compose(steps)

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the wrapped dataset's item at ``idx`` unchanged."""
        return self.dataset[idx]
| 31 | |
| 32 | class CIFAR100Dataset: |
| 33 | def __init__(self, root='./data', train=True, transform=None, download=True): |