For a dataset, create a generator over (images, kwargs) pairs. Each images is an NCHW float tensor, and the kwargs dict contains zero or more keys, each of which maps to a batched Tensor of its own. The kwargs dict can be used for class labels, in which case the key is "y" and the values are integer tensors of class labels.
(
*,
data_dir,
batch_size,
image_size,
class_cond=False,
deterministic=False,
random_crop=False,
random_flip=True,
)
| 9 | |
| 10 | |
def load_data(
    *,
    data_dir,
    batch_size,
    image_size,
    class_cond=False,
    deterministic=False,
    random_crop=False,
    random_flip=True,
):
    """
    For a dataset, create a generator over (images, kwargs) pairs.

    Each images is an NCHW float tensor, and the kwargs dict contains zero or
    more keys, each of which maps to a batched Tensor of its own.
    The kwargs dict can be used for class labels, in which case the key is "y"
    and the values are integer tensors of class labels.

    The generator is endless: every time the underlying loader is exhausted
    it is iterated again from the start.

    :param data_dir: a dataset directory.
    :param batch_size: the batch size of each returned pair.
    :param image_size: the size to which images are resized.
    :param class_cond: if True, derive an integer class label for every file
                       and hand the labels to the dataset. The class name is
                       the part of the file's basename before the first
                       underscore; names are numbered in sorted order.
    :param deterministic: if True, yield results in a deterministic order.
    :param random_crop: if True, randomly crop the images for augmentation.
    :param random_flip: if True, randomly flip the images for augmentation.
    :raises ValueError: if data_dir is empty or None, or if a complete pass
                        over the loader produces no batch at all.
    """
    if not data_dir:
        raise ValueError("unspecified data directory")
    all_files = _list_image_files_recursively(data_dir)
    classes = None
    if class_cond:
        # Assume classes are the first part of the filename,
        # before an underscore.
        class_names = [bf.basename(path).split("_")[0] for path in all_files]
        sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))}
        classes = [sorted_classes[x] for x in class_names]
    dataset = ImageDataset(
        image_size,
        all_files,
        classes=classes,
        shard=MPI.COMM_WORLD.Get_rank(),
        num_shards=MPI.COMM_WORLD.Get_size(),
        random_crop=random_crop,
        random_flip=random_flip,
    )
    # The deterministic and shuffled loaders differ only in `shuffle`.
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=not deterministic,
        num_workers=1,
        drop_last=True,
    )
    while True:
        produced_batch = False
        for batch in loader:
            produced_batch = True
            yield batch
        if not produced_batch:
            # With drop_last=True a loader holding fewer than batch_size
            # samples yields nothing; looping again would spin forever
            # without ever handing a batch to the caller.
            raise ValueError(
                f"no full batch of size {batch_size} could be loaded "
                f"from {data_dir!r}"
            )
| 68 |
no test coverage detected