Dataset for classification.
| 4 | |
| 5 | |
class ClassificationDataset(Dataset):
    """Dataset for classification backed by an ``ImageNet`` data source.

    ``split == 'train'`` reads the ``train`` subset; any other value of
    ``split`` falls back to the ``val`` subset.

    Args:
        split (str): ``'train'`` selects the training subset; every other
            value selects the validation subset. Defaults to ``'train'``.
        pipeline (callable, optional): Transform applied to each image
            returned by the data source. ``None`` leaves images untouched.
        data_root (str): Directory holding the ``train``/``val`` image
            folders and the ``meta/<subset>.txt`` list files.
            Defaults to ``'data/imagenet'``.
        memcached (bool): Forwarded unchanged to ``ImageNet``.
            Defaults to ``True``.
        mclient_path (str): Forwarded unchanged to ``ImageNet``.
            Defaults to ``'/mnt/lustre/share/memcached_client'``.
    """

    def __init__(self, split='train', pipeline=None,
                 data_root='data/imagenet', memcached=True,
                 mclient_path='/mnt/lustre/share/memcached_client'):
        # Keep the original contract: only the literal 'train' selects the
        # training subset; anything else (e.g. 'val', 'test') maps to 'val'.
        subset = 'train' if split == 'train' else 'val'
        # With the default data_root these resolve to exactly the original
        # hard-coded paths ('data/imagenet/train', 'data/imagenet/meta/...').
        self.data_source = ImageNet(
            root='{}/{}'.format(data_root, subset),
            list_file='{}/meta/{}.txt'.format(data_root, subset),
            memcached=memcached,
            mclient_path=mclient_path)
        self.pipeline = pipeline

    def __len__(self):
        """Return the number of samples reported by the data source."""
        return self.data_source.get_length()

    def __getitem__(self, idx):
        """Return ``(img, target)`` for sample ``idx``.

        ``img`` is passed through ``self.pipeline`` when one was given;
        ``target`` is returned exactly as the data source provides it.
        """
        img, target = self.data_source.get_sample(idx)
        if self.pipeline is not None:
            img = self.pipeline(img)
        return img, target