MCPcopy
hub / github.com/whai362/PVT / INatDataset

Class INatDataset

classification/datasets.py:14–54  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

12
13
class INatDataset(ImageFolder):
    """iNaturalist classification dataset built from JSON annotation files.

    Reads ``train{year}.json`` or ``val{year}.json`` plus ``categories.json``
    from ``root`` and fills ``self.samples`` with ``(image_path, class_index)``
    pairs. ``ImageFolder.__init__`` is deliberately not called (no directory
    scan), so every attribute the inherited ``__getitem__`` / ``__len__`` /
    ``__repr__`` need is set here directly.

    Class indices are always derived from the *training* annotations, for both
    splits, so train and val share one label space. Indices are assigned in
    order of first appearance in ``train{year}.json``.

    Args:
        root: Directory holding the annotation JSON files and the image folders.
        train: If True read ``train{year}.json``, otherwise ``val{year}.json``.
        year: Dataset year; only used to build the annotation file names.
        transform: Stored as ``self.transform`` (applied to loaded images by
            the inherited ``__getitem__``).
        target_transform: Stored as ``self.target_transform``.
        category: Key looked up in each ``categories.json`` entry to define the
            label, e.g. ``'name'`` (default). Presumably one of 'kingdom',
            'phylum', 'class', 'order', 'supercategory', 'family', 'genus',
            'name'. It must exist in every entry, otherwise ``KeyError``.
        loader: Stored as ``self.loader``; used to load each sample's file.

    Attributes:
        nb_classes: Number of distinct ``category`` values in the training
            annotations.
        samples / imgs: List of ``(path, class_index)`` pairs.
        targets: Class index of each sample, in ``samples`` order.
        class_to_idx: Mapping from label value to class index.
        classes: Label values ordered by class index.

    Raises:
        FileNotFoundError: If any required JSON file is missing.
    """

    def __init__(self, root, train=True, year=2018, transform=None, target_transform=None,
                 category='name', loader=default_loader):
        # Normally set by VisionDataset.__init__; the inherited __repr__ reads it.
        self.root = root
        self.transform = transform
        self.loader = loader
        self.target_transform = target_transform
        self.year = year

        split = "train" if train else "val"
        data = self._load_json(os.path.join(root, f'{split}{year}.json'))
        data_catg = self._load_json(os.path.join(root, 'categories.json'))

        # The label space always comes from the training annotations. When the
        # training file is the one just loaded, reuse it instead of parsing the
        # same JSON a second time.
        if train:
            data_for_targeter = data
        else:
            data_for_targeter = self._load_json(os.path.join(root, f"train{year}.json"))

        # categories.json is indexed positionally by category id.
        # NOTE(review): assumes its entries are ordered by id; confirm.
        # First-appearance order defines which output index means which class,
        # so it must stay stable for trained checkpoints to remain valid.
        targeter = {}
        for elem in data_for_targeter['annotations']:
            label = data_catg[int(elem['category_id'])][category]
            if label not in targeter:
                targeter[label] = len(targeter)
        self.nb_classes = len(targeter)

        self.samples = []
        for elem in data['images']:
            # The category id is taken from the third component of file_name,
            # not from the annotations. The second component (cut[1]) is
            # skipped, so images are expected at root/<cut[0]>/<cut[2]>/<cut[3]>.
            # NOTE(review): confirm this matches the on-disk layout.
            cut = elem['file_name'].split('/')
            target_current = int(cut[2])
            path_current = os.path.join(root, cut[0], cut[2], cut[3])
            target_current_true = targeter[data_catg[target_current][category]]
            self.samples.append((path_current, target_current_true))

        # ImageFolder-compatible attributes that ImageFolder.__init__ would
        # normally provide.
        self.imgs = self.samples
        self.targets = [target for _, target in self.samples]
        self.class_to_idx = targeter
        self.classes = list(targeter)

    @staticmethod
    def _load_json(path):
        """Parse and return the JSON document at ``path``, read as UTF-8."""
        with open(path, encoding='utf-8') as json_file:
            return json.load(json_file)

    # __getitem__ and __len__ inherited from ImageFolder
55
56
57def build_dataset(is_train, args):

Callers 1

build_dataset — Function · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected