| 254 | return self.dataset[idx] |
| 255 | |
class OxfordPetsDataset:
    """Thin wrapper around torchvision's ``datasets.OxfordIIITPet``.

    Length and indexing are delegated to the wrapped dataset instance. When no
    ``transform`` is supplied, a default preprocessing pipeline is built and
    used instead.

    Args:
        root: Directory forwarded to ``datasets.OxfordIIITPet`` as ``root``.
        split: Split name forwarded unchanged (default ``'trainval'``).
        transform: Image transform forwarded to the wrapped dataset. When
            ``None``, the default pipeline is used: Resize(256), CenterCrop(224),
            ToTensor and a per-channel Normalize.
        download: Forwarded to ``datasets.OxfordIIITPet`` as ``download``.
    """

    def __init__(self, root='./data', split='trainval', transform=None, download=True):
        if transform is None:
            transform = self._default_transform()
        self.dataset = datasets.OxfordIIITPet(
            root=root,
            split=split,
            transform=transform,
            download=download,
        )

    @staticmethod
    def _default_transform():
        """Build the fallback preprocessing pipeline used when no transform is given."""
        # These are the widely used ImageNet per-channel mean/std statistics.
        channel_mean = [0.485, 0.456, 0.406]
        channel_std = [0.229, 0.224, 0.225]
        steps = [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=channel_mean, std=channel_std),
        ]
        return transforms.Compose(steps)

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return whatever the wrapped dataset yields at ``idx``, unchanged."""
        return self.dataset[idx]
| 273 | |
| 274 | def get_dataset(name='cifar10', root='./data', train=True, download=True): |
| 275 | name = name.lower() |