For a dataset, create a generator over (images, kwargs) pairs. Each images value is an NCHW float tensor, and the kwargs dict contains zero or more keys, each of which maps to a batched Tensor of its own. The kwargs dict can be used for class labels, in which case the key is "y" and the values are integer tensors of class labels.
(
*, data_dir, batch_size, image_size, class_cond=False, deterministic=False
)
| 6 | |
| 7 | |
def load_data(
    *, data_dir, batch_size, image_size, class_cond=False, deterministic=False
):
    """
    For a dataset, create an endless generator over (images, kwargs) pairs.

    Each images is an NCHW float tensor, and the kwargs dict contains zero or
    more keys, each of which maps to a batched Tensor of its own.
    The kwargs dict can be used for class labels, in which case the key is "y"
    and the values are integer tensors of class labels.

    Because this is a generator function, nothing below runs (including the
    data_dir validation) until the first item is requested.

    :param data_dir: a dataset directory.
    :param batch_size: the batch size of each returned pair.
    :param image_size: the size to which images are resized.
    :param class_cond: if True, include a "y" key in returned dicts for class
                       label. The class name of a file is the part of its
                       basename before the first underscore; class names are
                       sorted and mapped to consecutive integers from 0.
                       NOTE(review): earlier documentation said an exception
                       is raised when classes are not available; no such check
                       is visible in this function -- confirm where it lives.
    :param deterministic: if True, yield results in a deterministic order
                          (the loader is created with shuffle disabled).
    :raises ValueError: if data_dir is empty or None.
    :raises RuntimeError: if a complete pass over the loader yields no batch
                          at all, e.g. because this shard holds fewer than
                          batch_size samples while drop_last is enabled.
    """
    if not data_dir:
        raise ValueError("unspecified data directory")
    all_files = _list_image_files_recursively(data_dir)
    classes = None
    if class_cond:
        # Assume classes are the first part of the filename,
        # before an underscore.
        class_names = [bf.basename(path).split("_")[0] for path in all_files]
        sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))}
        classes = [sorted_classes[x] for x in class_names]
    dataset = ImageDataset(
        image_size,
        all_files,
        classes=classes,
        # Each MPI process passes its own rank and the world size as the
        # shard index / shard count.
        shard=MPI.COMM_WORLD.Get_rank(),
        num_shards=MPI.COMM_WORLD.Get_size(),
    )
    # The deterministic and non-deterministic loaders differ only in whether
    # shuffling is enabled, so a single construction covers both.
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=not deterministic,
        num_workers=1,
        drop_last=True,
    )
    while True:
        # Restart the loader after every epoch so the stream never ends.
        produced_batch = False
        for batch in loader:
            produced_batch = True
            yield batch
        if not produced_batch:
            # Without this guard an empty epoch would make the outer loop
            # spin forever without ever yielding, hanging the consumer.
            raise RuntimeError(
                "data loader produced no batches; check that data_dir "
                "contains at least batch_size images per shard"
            )
| 54 | |
| 55 | |
| 56 | def _list_image_files_recursively(data_dir): |
no test coverage detected