Load a subset of the COCO dataset. dataset_dir: The root directory of the COCO dataset; subset: What to load (train, val, minival, valminusminival); year: What dataset year to load (2014, 2017) as a string, not an integer; class_ids: If provided, only loads images that have the given classes.
(self, dataset_dir, subset, year=DEFAULT_DATASET_YEAR, class_ids=None,
class_map=None, return_coco=False, auto_download=False)
| 90 | |
| 91 | class CocoDataset(utils.Dataset): |
def load_coco(self, dataset_dir, subset, year=DEFAULT_DATASET_YEAR, class_ids=None,
              class_map=None, return_coco=False, auto_download=False):
    """Load a subset of the COCO dataset.

    dataset_dir: The root directory of the COCO dataset.
    subset: What to load (train, val, minival, valminusminival)
    year: What dataset year to load (2014, 2017) as a string, not an integer
    class_ids: If provided, only loads images that have the given classes.
        If omitted (or empty), all categories and all images listed in the
        annotation file are loaded; images with no matching annotations
        are registered with an empty annotation list.
    class_map: TODO: Not implemented yet. Supports mapping classes from
        different datasets to the same class ID.
    return_coco: If True, returns the COCO object.
    auto_download: Automatically download and unzip MS-COCO images and annotations

    Returns the COCO object when return_coco is truthy, otherwise None.
    """
    # Only the literal True triggers a download (identity check kept as-is).
    if auto_download is True:
        self.auto_download(dataset_dir, subset, year)

    coco = COCO("{}/annotations/instances_{}{}.json".format(dataset_dir, subset, year))
    # minival / valminusminival have their own annotation files (loaded
    # above) but their images are read from the "val" image directory.
    if subset == "minival" or subset == "valminusminival":
        subset = "val"
    image_dir = "{}/{}{}".format(dataset_dir, subset, year)

    # Decide on the image set BEFORE defaulting class_ids. Previously
    # class_ids was filled with every category first, so the "all images"
    # branch could not be reached for a dataset that has categories, and
    # images without annotations were dropped even with no class filter.
    if class_ids:
        # Subset of classes: keep images containing at least one of them.
        image_ids = []
        for class_id in class_ids:
            image_ids.extend(coco.getImgIds(catIds=[class_id]))
        # Remove duplicates (an image may contain several requested classes)
        image_ids = list(set(image_ids))
    else:
        # All classes, all images
        class_ids = sorted(coco.getCatIds())
        image_ids = list(coco.imgs.keys())

    # Add classes
    for class_id in class_ids:
        self.add_class("coco", class_id, coco.loadCats(class_id)[0]["name"])

    # Add images, attaching only annotations of the selected classes
    for image_id in image_ids:
        image_info = coco.imgs[image_id]
        annotation_ids = coco.getAnnIds(
            imgIds=[image_id], catIds=class_ids, iscrowd=None)
        self.add_image(
            "coco", image_id=image_id,
            path=os.path.join(image_dir, image_info['file_name']),
            width=image_info["width"],
            height=image_info["height"],
            annotations=coco.loadAnns(annotation_ids))
    if return_coco:
        return coco
| 144 | |
| 145 | def auto_download(self, dataDir, dataType, dataYear): |
| 146 | """Download the COCO dataset/annotations if requested. |
no test coverage detected