| 55 | return self.dataset[idx] |
| 56 | |
class MNISTDataset:
    """Expose torchvision's MNIST digits through a minimal dataset interface.

    Wraps ``datasets.MNIST`` and forwards length and indexing to it. When no
    transform is supplied, images are converted with ``ToTensor`` and then
    normalised with mean 0.1307 / std 0.3081 (the commonly quoted MNIST
    pixel statistics).

    Args:
        root: Directory handed to ``datasets.MNIST`` for storage/lookup.
        train: Forwarded to ``datasets.MNIST`` to pick the train or test split.
        transform: Optional image transform; ``None`` selects the default above.
        download: Forwarded to ``datasets.MNIST`` to allow fetching the data.
    """

    def __init__(self, root='./data', train=True, transform=None, download=True):
        # Build the default pipeline only when the caller did not provide one.
        image_pipeline = transform if transform is not None else transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])

        self.dataset = datasets.MNIST(
            root=root,
            train=train,
            transform=image_pipeline,
            download=download,
        )
        # Human-readable label names: the digits '0' through '9'.
        self.classes = list(map(str, range(10)))

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        wrapped = self.dataset
        return len(wrapped)

    def __getitem__(self, idx):
        """Return sample ``idx`` exactly as the wrapped dataset yields it.

        For torchvision's MNIST this is presumably an ``(image, target)``
        pair — confirm against the installed torchvision version.
        """
        wrapped = self.dataset
        return wrapped[idx]
| 73 | |
| 74 | class FashionMNISTDataset: |
| 75 | def __init__(self, root='./data', train=True, transform=None, download=True): |