| 18 | from timm.data import ImageDataset as TimmDatasetTar |
| 19 | |
class INatDataset(ImageFolder):
    """iNaturalist dataset built from the JSON annotation files in ``root``.

    ``ImageFolder.__init__`` is deliberately not called (no directory scan).
    Instead, ``self.samples`` is built directly from ``{train,val}{year}.json``
    and ``categories.json``. Class indices are always derived from the *train*
    annotations, so the train and val splits share one label space.

    Args:
        root: Directory containing ``train{year}.json``, ``val{year}.json``,
            ``categories.json`` and the image directories.
        train: Index the train split if True, otherwise the val split.
        year: Year suffix used in the JSON file names (default 2018).
        transform: Optional transform applied to the loaded image.
        target_transform: Optional transform applied to the target index.
        category: Key looked up in each ``categories.json`` entry to form the
            label, such as ``'kingdom'``, ``'phylum'``, ``'class'``,
            ``'order'``, ``'supercategory'``, ``'family'``, ``'genus'`` or
            ``'name'``. Entries sharing the same value collapse into one class.
        loader: Callable that loads an image given its file path.

    Attributes:
        nb_classes: Number of distinct ``category`` values among the train
            annotations.
        samples: List of ``(image_path, class_index)`` tuples.
        targets: Class index of each sample, in ``samples`` order.

    Raises:
        KeyError: If an image of the indexed split maps to a ``category``
            value that never occurs in the train annotations.
    """

    def __init__(self, root, train=True, year=2018, transform=None, target_transform=None,
                 category='name', loader=default_loader):
        self.transform = transform
        self.loader = loader
        self.target_transform = target_transform
        self.year = year

        split = 'train' if train else 'val'
        # JSON is UTF-8 by spec; don't depend on the platform's locale codec.
        with open(os.path.join(root, f'{split}{year}.json'), encoding='utf-8') as json_file:
            data = json.load(json_file)

        with open(os.path.join(root, 'categories.json'), encoding='utf-8') as json_file:
            data_catg = json.load(json_file)

        if train:
            # The split being indexed *is* the train split: reuse the parsed
            # object instead of loading the same (large) file a second time.
            data_for_targeter = data
        else:
            with open(os.path.join(root, f'train{year}.json'), encoding='utf-8') as json_file:
                data_for_targeter = json.load(json_file)

        # Map each distinct label to a contiguous index, in order of first
        # appearance in the train annotations. setdefault evaluates
        # len(targeter) before inserting, so a new label gets the next index.
        # NOTE(review): indexing data_catg by category_id assumes entry
        # position == category id in categories.json — confirm.
        targeter = {}
        for annotation in data_for_targeter['annotations']:
            label = data_catg[int(annotation['category_id'])][category]
            targeter.setdefault(label, len(targeter))
        self.nb_classes = len(targeter)

        self.samples = []
        for image in data['images']:
            # file_name is split on '/'. Component 2 is the category id, and the
            # on-disk path is rebuilt from components 0, 2 and 3 (component 1
            # is dropped). Assumes at least four components — TODO confirm the
            # layout against the extracted dataset.
            parts = image['file_name'].split('/')
            category_id = int(parts[2])
            image_path = os.path.join(root, parts[0], parts[2], parts[3])
            label = data_catg[category_id][category]
            self.samples.append((image_path, targeter[label]))

        # Mirror the attribute DatasetFolder.__init__ would normally provide.
        self.targets = [target for _, target in self.samples]

    # __getitem__ and __len__ inherited from ImageFolder
| 62 | |
| 63 | |
| 64 | def build_dataset(is_train, args): |