| 14 | |
class Flowers(ImageFolder):
    """Flower image dataset described by two MATLAB files under ``root``.

    Expected layout (as read by the code below):

    * ``root/imagelabels.mat`` -- key ``'labels'``; row 0 holds one label per
      image, indexed by ``image id - 1``.
    * ``root/setid.mat`` -- keys ``'trnid'``, ``'valid'`` and ``'tstid'``; row 0
      of each holds the image ids belonging to that split.
    * ``root/jpg/image_XXXXX.jpg`` -- the images, named by zero-padded id.

    NOTE(review): this matches the file layout of the Oxford 102 Flowers
    release -- confirm against the data actually used.

    Args:
        root: Directory containing the two ``.mat`` files and the ``jpg`` folder.
        train: If true, use the union of the ``trnid`` and ``valid`` ids;
            otherwise use the ``tstid`` ids.
        transform: Stored as ``self.transform``; not applied in ``__init__``.
        **kwargs: Accepted for call-site compatibility and ignored.

    Attributes set here:
        samples: List of ``(image_path, label)`` tuples, where ``label`` is the
            file's label minus one, as a plain Python ``int``.
        targets: The labels of ``samples`` in the same order.
        imgs: The selected image ids (not the samples list).
        img_to_label: Maps zero-based image index to the label as stored in
            the file (i.e. before the minus-one shift).
    """

    def __init__(self, root, train=True, transform=None, **kwargs):
        # ImageFolder.__init__ is not invoked; every attribute this class
        # relies on is assigned by hand below (presumably because the images
        # are not organised one-directory-per-class -- confirm).
        self.dataset_root = root
        self.loader = default_loader
        self.target_transform = None
        self.transform = transform
        label_path = os.path.join(root, 'imagelabels.mat')
        split_path = os.path.join(root, 'setid.mat')

        print('Dataset Flowers is trained with resolution 224!')

        # labels
        # .tolist() turns the NumPy scalars into Python ints, so the later
        # "- 1" is ordinary integer arithmetic and the targets collate like
        # any other integer class index instead of keeping the file's dtype.
        labels = sio.loadmat(label_path)['labels'][0].tolist()
        self.img_to_label = dict(enumerate(labels))

        splits = sio.loadmat(split_path)

        def _sorted_ids(key):
            # Row 0 of each split entry holds the image ids of that split.
            return sorted(splits[key][0].tolist())

        self.trnid = _sorted_ids('trnid')
        self.valid = _sorted_ids('valid')
        self.tstid = _sorted_ids('tstid')

        if train:
            self.imgs = self.trnid + self.valid
        else:
            self.imgs = self.tstid

        image_dir = os.path.join(root, 'jpg')
        # Image ids are shifted by one to index img_to_label, and the stored
        # label is shifted by one to give the class index used as the target.
        self.samples = [
            (
                os.path.join(image_dir, "image_{:05d}.jpg".format(item)),
                self.img_to_label[item - 1] - 1,
            )
            for item in self.imgs
        ]
        # Per-sample class indices, in the same order as ``samples``.
        self.targets = [label for _, label in self.samples]
| 44 | |
| 45 | class Cars196(ImageFolder, datasets.CIFAR10): |
| 46 | base_folder_devkit = 'devkit' |