(
self,
root: Union[str, Path],
target_type: Union[list[str], str] = "category",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
)
| 39 | """ |
| 40 | |
def __init__(
    self,
    root: Union[str, Path],
    target_type: Union[list[str], str] = "category",
    transform: Optional[Callable] = None,
    target_transform: Optional[Callable] = None,
    download: bool = False,
) -> None:
    """Prepare the dataset stored under ``<root>/caltech101`` and index its samples.

    Args:
        root: Parent directory. The dataset directory is the ``caltech101``
            sub-directory of it, which is created when it does not exist.
        target_type: ``"category"``, ``"annotation"``, or a list of those
            values. A single string is wrapped into a one-element list and
            every entry is validated with ``verify_str_arg``.
        transform: Forwarded unchanged to the parent class constructor.
        target_transform: Forwarded unchanged to the parent class constructor.
        download: When true, ``self.download()`` runs before the integrity
            check.

    Raises:
        RuntimeError: If ``self._check_integrity()`` returns a falsy value.
    """
    super().__init__(os.path.join(root, "caltech101"), transform=transform, target_transform=target_transform)
    os.makedirs(self.root, exist_ok=True)
    if isinstance(target_type, str):
        target_type = [target_type]
    self.target_type = [verify_str_arg(t, "target_type", ("category", "annotation")) for t in target_type]

    if download:
        self.download()

    if not self._check_integrity():
        raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

    categories_dir = os.path.join(self.root, "101_ObjectCategories")
    # Only sub-directories are categories. Stray files (e.g. OS metadata such
    # as ".DS_Store") would otherwise be registered as classes and break the
    # per-category listing below. "BACKGROUND_Google" is not a real class and
    # is excluded here, which also tolerates copies that do not ship it.
    self.categories = sorted(
        name
        for name in os.listdir(categories_dir)
        if name != "BACKGROUND_Google" and os.path.isdir(os.path.join(categories_dir, name))
    )

    # For some reason, the category names in "101_ObjectCategories" and
    # "Annotations" do not always match. This is a manual map between the
    # two. Defaults to using same name, since most names are fine.
    name_map = {
        "Faces": "Faces_2",
        "Faces_easy": "Faces_3",
        "Motorbikes": "Motorbikes_16",
        "airplanes": "Airplanes_Side_2",
    }
    self.annotation_categories = [name_map.get(category, category) for category in self.categories]

    # Parallel lists with one entry per sample: ``index`` is the 1-based
    # position of the sample inside its category directory and ``y`` is the
    # position of that category in ``self.categories``.
    self.index: list[int] = []
    self.y: list[int] = []
    for label, category in enumerate(self.categories):
        # Every directory entry is counted as one sample.
        # NOTE(review): presumably each category folder holds only
        # consecutively numbered images - confirm against __getitem__.
        n = len(os.listdir(os.path.join(categories_dir, category)))
        self.index.extend(range(1, n + 1))
        self.y.extend([label] * n)
| 82 | def __getitem__(self, index: int) -> tuple[Any, Any]: |
| 83 | """ |
no test coverage detected