| 30 | return self.dataset[idx] |
| 31 | |
class CIFAR100Dataset:
    """Thin wrapper around ``datasets.CIFAR100`` that supplies default transforms.

    When no ``transform`` is given, a default pipeline is built:

    * Training split: random crop (32 px, 4 px padding), random horizontal
      flip and random rotation (15 degrees), followed by the shared tail.
    * Test split: only the shared tail.

    The shared tail is ``ToTensor`` then ``Normalize(MEAN, STD)``.
    """

    # Per-channel normalization statistics used by the default transforms.
    # Presumably CIFAR-100 training-set mean/std (R, G, B) -- confirm source.
    # Defined once so the train and test pipelines cannot drift apart.
    MEAN = (0.5071, 0.4867, 0.4408)
    STD = (0.2675, 0.2565, 0.2761)

    def __init__(self, root='./data', train=True, transform=None, download=True):
        """Construct the wrapped ``datasets.CIFAR100`` split.

        Args:
            root: Storage directory, forwarded to ``datasets.CIFAR100``.
            train: True for the training split, False for the test split.
                It also selects which default transform is built.
            transform: Optional transform applied to each image. If None,
                the default pipeline described on the class is used.
            download: Forwarded unchanged to ``datasets.CIFAR100``.
        """
        if transform is None:
            steps = []
            if train:
                # Data augmentation is applied to the training split only.
                steps += [
                    transforms.RandomCrop(32, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.RandomRotation(15),
                ]
            # Common tail for both splits: tensor conversion + normalization.
            steps += [
                transforms.ToTensor(),
                transforms.Normalize(self.MEAN, self.STD),
            ]
            transform = transforms.Compose(steps)

        self.dataset = datasets.CIFAR100(root=root, train=train, transform=transform, download=download)

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return sample ``idx`` exactly as produced by the wrapped dataset."""
        return self.dataset[idx]
| 56 | |
| 57 | class MNISTDataset: |
| 58 | def __init__(self, root='./data', train=True, transform=None, download=True): |