| 208 | return self.dataset[idx] |
| 209 | |
class Food101Dataset:
    """Wrapper that builds a ``datasets.Food101`` with a default transform.

    When no ``transform`` is supplied, a pipeline is chosen from ``split``:
    an augmenting pipeline for ``'train'`` and a deterministic
    resize/center-crop pipeline for any other value. Length and indexing
    are delegated to the wrapped dataset.
    """

    def __init__(self, root='./data', split='train', transform=None, download=True):
        """Create the wrapped dataset.

        Args:
            root: Directory forwarded to ``datasets.Food101``.
            split: Split name forwarded to ``datasets.Food101``; also selects
                the default transform when ``transform`` is ``None``.
            transform: Optional callable used instead of the defaults.
            download: Forwarded to ``datasets.Food101``.
        """
        if transform is None:
            # Split-specific front of the pipeline.
            if split == 'train':
                steps = [
                    transforms.RandomResizedCrop(224),
                    transforms.RandomHorizontalFlip(),
                    transforms.RandomRotation(15),
                    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
                ]
            else:
                steps = [
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                ]

            # Both pipelines end with the same tensor conversion + normalisation.
            steps.append(transforms.ToTensor())
            steps.append(
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                )
            )
            transform = transforms.Compose(steps)

        self.dataset = datasets.Food101(
            root=root,
            split=split,
            transform=transform,
            download=download,
        )

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        wrapped = self.dataset
        return len(wrapped)

    def __getitem__(self, idx):
        """Return whatever the wrapped dataset yields at ``idx``."""
        wrapped = self.dataset
        return wrapped[idx]
| 237 | |
| 238 | class Caltech256Dataset: |
| 239 | def __init__(self, root='./data', transform=None, download=True): |