| 90 | return self.dataset[idx] |
| 91 | |
class STL10Dataset:
    """Thin wrapper around ``torchvision.datasets.STL10`` with default transforms.

    Args:
        root: Directory passed to torchvision as the dataset root.
        split: Split name forwarded unchanged to ``datasets.STL10``.
        transform: Transform applied to each image. When ``None``, a default
            pipeline is built:
            - training splits (see ``_TRAIN_SPLITS``): random crop with
              padding plus random horizontal flip, then ToTensor and
              Normalize;
            - all other splits: ToTensor and Normalize only.
        download: Forwarded to ``datasets.STL10``.

    Attributes:
        dataset: The underlying torchvision STL10 dataset object.
        classes: List of the ten class names, in label order.
    """

    # Per-channel normalization statistics shared by both default pipelines.
    # NOTE(review): presumably the RGB mean/std of the STL10 images — confirm.
    _MEAN = (0.4467, 0.4398, 0.4066)
    _STD = (0.2603, 0.2566, 0.2713)

    # Splits that contain training images and therefore get augmentation.
    # 'train+unlabeled' is included so it is not silently treated as an
    # evaluation split (the original code only checked for 'train').
    _TRAIN_SPLITS = ('train', 'train+unlabeled')

    def __init__(self, root='./data', split='train', transform=None, download=True):
        if transform is None:
            transform = self._default_transform(split)

        self.dataset = datasets.STL10(root=root, split=split, transform=transform, download=download)
        self.classes = ['airplane', 'bird', 'car', 'cat', 'deer', 'dog', 'horse', 'monkey', 'ship', 'truck']

    @classmethod
    def _is_train_split(cls, split):
        """Return True if *split* should receive training-time augmentation."""
        return split in cls._TRAIN_SPLITS

    @classmethod
    def _default_transform(cls, split):
        """Build the default transform pipeline for *split*."""
        steps = []
        if cls._is_train_split(split):
            # Pad 12 px per side, then randomly crop back to 96x96: translation jitter.
            steps += [
                transforms.RandomCrop(96, padding=12),
                transforms.RandomHorizontalFlip(),
            ]
        steps += [
            transforms.ToTensor(),
            transforms.Normalize(cls._MEAN, cls._STD),
        ]
        return transforms.Compose(steps)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]
| 116 | |
| 117 | def get_dataset(name='cifar10', root='./data', train=True, download=True): |
| 118 | name = name.lower() |