| 72 | return self.dataset[idx] |
| 73 | |
class FashionMNISTDataset:
    """Thin wrapper around ``datasets.FashionMNIST``.

    Supplies a default ToTensor + Normalize transform when none is given and
    exposes human-readable label names via ``self.classes``. ``len()`` and
    indexing are delegated directly to the wrapped dataset object.
    """

    # Label names in label-index order; each instance receives its own list copy.
    _LABEL_NAMES = (
        'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
        'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot',
    )

    def __init__(self, root: str = './data', train: bool = True,
                 transform=None, download: bool = True) -> None:
        """Construct the wrapped FashionMNIST dataset.

        Args:
            root: Forwarded as ``root`` to ``datasets.FashionMNIST``.
            train: Forwarded as ``train``; presumably selects the training
                (``True``) versus test (``False``) split.
            transform: Per-sample transform forwarded to the dataset. When
                ``None``, the pipeline from ``_default_transform`` is used.
            download: Forwarded as ``download`` to ``datasets.FashionMNIST``.
        """
        chosen_transform = (
            self._default_transform() if transform is None else transform
        )
        self.dataset = datasets.FashionMNIST(
            root=root,
            train=train,
            transform=chosen_transform,
            download=download,
        )
        self.classes = list(self._LABEL_NAMES)

    @staticmethod
    def _default_transform():
        """Build the fallback pipeline: tensor conversion, then normalisation."""
        # (0.2860,) / (0.3530,) are the single-channel mean / std given to
        # Normalize (commonly cited FashionMNIST training-set statistics).
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.2860,), (0.3530,)),
        ])

    def __len__(self) -> int:
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return whatever the wrapped dataset yields for ``idx``."""
        sample = self.dataset[idx]
        return sample
| 91 | |
| 92 | class STL10Dataset: |
| 93 | def __init__(self, root='./data', split='train', transform=None, download=True): |