(self, num_points, transforms=None, train=True, download=True)
| 21 | class ModelNet40Cls(data.Dataset): |
| 22 | |
def __init__(self, num_points, transforms=None, train=True, download=True):
    """Load one split of the ModelNet40 HDF5 point-cloud dataset into memory.

    Args:
        num_points: Stored on ``self.num_points``; how it is consumed is
            decided by methods outside this view.
        transforms: Optional transform object, stored on ``self.transforms``.
        train: If True, read the files listed in ``train_files.txt``;
            otherwise read those in ``test_files.txt``.
        download: If True and ``BASE_DIR/modelnet40_ply_hdf5_2048`` does not
            exist, download the archive from ``self.url`` with ``curl`` and
            extract it into ``BASE_DIR`` with ``unzip``.

    Raises:
        subprocess.CalledProcessError: If ``curl`` or ``unzip`` exits with a
            non-zero status (with ``--fail``, this includes HTTP errors).
        ValueError: If the selected file list names no data files.
    """
    super().__init__()

    self.transforms = transforms

    self.folder = "modelnet40_ply_hdf5_2048"
    self.data_dir = os.path.join(BASE_DIR, self.folder)
    self.url = "https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip"

    if download and not os.path.exists(self.data_dir):
        zip_path = os.path.join(BASE_DIR, os.path.basename(self.url))
        try:
            # Pass argv lists directly instead of shlex-splitting a formatted
            # string, so paths containing spaces or quotes stay one argument.
            # --fail: exit non-zero on HTTP errors rather than saving the
            # error page as the archive; --location: follow redirects.
            subprocess.check_call(
                ["curl", "--fail", "--location", self.url, "-o", zip_path]
            )
            subprocess.check_call(["unzip", zip_path, "-d", BASE_DIR])
        finally:
            # Remove the (possibly partial) archive whether or not the
            # download/extraction succeeded.
            if os.path.exists(zip_path):
                os.remove(zip_path)

    self.train, self.num_points = train, num_points
    list_name = "train_files.txt" if self.train else "test_files.txt"
    list_path = os.path.join(self.data_dir, list_name)
    self.files = _get_data_files(list_path)

    point_list, label_list = [], []
    for f in self.files:
        # Entries are joined onto BASE_DIR, not self.data_dir.
        # NOTE(review): assumes _get_data_files returns paths relative to
        # BASE_DIR (i.e. already including the dataset folder) - confirm.
        points, labels = _load_data_file(os.path.join(BASE_DIR, f))
        point_list.append(points)
        label_list.append(labels)

    if not point_list:
        # np.concatenate would fail with an opaque message on an empty list.
        raise ValueError("no data files listed in {}".format(list_path))

    # Stack every file's samples along axis 0 into single arrays.
    self.points = np.concatenate(point_list, 0)
    self.labels = np.concatenate(label_list, 0)

    # randomize() is defined elsewhere in the class (not visible here).
    self.randomize()
| 62 | |
| 63 | def __getitem__(self, idx): |
| 64 | pt_idxs = np.arange(0, self.actual_number_of_points) |
nothing calls this directly
no test coverage detected