MCPcopy
hub / github.com/pytorch/vision / __init__

Method __init__

torchvision/datasets/caltech.py:41–80  ·  view source on GitHub ↗
(
        self,
        root: Union[str, Path],
        target_type: Union[list[str], str] = "category",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    )

Source from the content-addressed store, hash-verified

39 """
40
def __init__(
    self,
    root: Union[str, Path],
    target_type: Union[list[str], str] = "category",
    transform: Optional[Callable] = None,
    target_transform: Optional[Callable] = None,
    download: bool = False,
) -> None:
    """Set up the Caltech101 dataset stored under ``<root>/caltech101``.

    Args:
        root: Parent directory. The dataset lives in its ``caltech101``
            sub-directory, which is created if it does not already exist.
        target_type: A single target type or a list of them. Each entry
            is validated by ``verify_str_arg`` against
            ``("category", "annotation")``.
        transform: Forwarded to the base-class constructor.
        target_transform: Forwarded to the base-class constructor.
        download: If True, ``self.download()`` is called before the
            integrity check.

    Raises:
        RuntimeError: If ``self._check_integrity()`` reports failure after
            the optional download.
    """
    super().__init__(os.path.join(root, "caltech101"), transform=transform, target_transform=target_transform)
    os.makedirs(self.root, exist_ok=True)

    # Accept a bare string as shorthand for a one-element list.
    if isinstance(target_type, str):
        target_type = [target_type]
    self.target_type = [verify_str_arg(t, "target_type", ("category", "annotation")) for t in target_type]

    if download:
        self.download()

    if not self._check_integrity():
        raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

    images_dir = os.path.join(self.root, "101_ObjectCategories")
    # os.listdir returns entries in arbitrary order. Sorting makes the class
    # index assigned to each category deterministic.
    self.categories = sorted(os.listdir(images_dir))
    self.categories.remove("BACKGROUND_Google")  # this is not a real class

    # For some reason, the category names in "101_ObjectCategories" and
    # "Annotations" do not always match. This is a manual map between the
    # two. Defaults to using same name, since most names are fine.
    name_map = {
        "Faces": "Faces_2",
        "Faces_easy": "Faces_3",
        "Motorbikes": "Motorbikes_16",
        "airplanes": "Airplanes_Side_2",
    }
    self.annotation_categories = [name_map.get(c, c) for c in self.categories]

    # Flat per-sample tables, built in category order:
    #   index[k] is the 1-based position of sample k within its category,
    #   y[k] is the class index of sample k into self.categories.
    # NOTE(review): every entry in a category directory is counted, so stray
    # files (e.g. hidden OS metadata) would inflate n. Confirm the extracted
    # archive contains only the expected images.
    self.index: list[int] = []
    self.y: list[int] = []
    for i, c in enumerate(self.categories):
        n = len(os.listdir(os.path.join(images_dir, c)))
        self.index.extend(range(1, n + 1))
        self.y.extend(n * [i])
81
82 def __getitem__(self, index: int) -> tuple[Any, Any]:
83 """

Callers 1

__init__ Method · 0.45

Calls 3

download Method · 0.95

_check_integrity Method · 0.95

verify_str_arg Function · 0.85

Tested by

no test coverage detected