MCPcopy Index your code
hub / github.com/openai/guided-diffusion / ImageDataset

Class ImageDataset

guided_diffusion/image_datasets.py:82–123  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

80
81
class ImageDataset(Dataset):
    """Dataset that loads images from paths and yields channel-first arrays.

    The full list of paths is partitioned across ``num_shards`` workers:
    this instance keeps the items at indices ``shard``,
    ``shard + num_shards``, ``shard + 2 * num_shards``, ... of
    ``image_paths`` (and of ``classes`` when labels are given).

    Args:
        resolution: Target size forwarded to the crop helpers.
        image_paths: Sequence of paths readable through ``bf.BlobFile``.
        classes: Optional sequence of labels parallel to ``image_paths``.
        shard: Index of the partition this instance serves.
        num_shards: Total number of partitions.
        random_crop: Use ``random_crop_arr`` instead of ``center_crop_arr``.
        random_flip: Reverse axis 1 of the cropped array with probability 0.5.

    Raises:
        ValueError: If ``num_shards`` is not positive, ``shard`` is outside
            ``[0, num_shards)``, or ``classes`` and ``image_paths`` differ
            in length.
    """

    def __init__(
        self,
        resolution,
        image_paths,
        classes=None,
        shard=0,
        num_shards=1,
        random_crop=False,
        random_flip=True,
    ):
        super().__init__()
        if num_shards < 1:
            raise ValueError(f"num_shards must be >= 1, got {num_shards}")
        if not 0 <= shard < num_shards:
            raise ValueError(
                f"shard must satisfy 0 <= shard < num_shards, "
                f"got shard={shard}, num_shards={num_shards}"
            )
        if classes is not None and len(classes) != len(image_paths):
            # A length mismatch would otherwise pair images with the wrong
            # labels or fail much later inside __getitem__.
            raise ValueError(
                f"classes and image_paths must have the same length, "
                f"got {len(classes)} and {len(image_paths)}"
            )

        self.resolution = resolution
        # Strided slice: every num_shards-th item, starting at this shard.
        self.local_images = image_paths[shard::num_shards]
        self.local_classes = None if classes is None else classes[shard::num_shards]
        self.random_crop = random_crop
        self.random_flip = random_flip

    def __len__(self):
        """Return the number of images assigned to this shard."""
        return len(self.local_images)

    def __getitem__(self, idx):
        """Load, crop, optionally flip and rescale the image at ``idx``.

        Returns:
            A ``(array, out_dict)`` pair. ``array`` is the float32 image with
            its last axis moved to the front (``np.transpose(arr, [2, 0, 1])``)
            and values mapped through ``x / 127.5 - 1``. ``out_dict`` holds the
            label under ``"y"`` as an int64 array when classes were provided,
            and is empty otherwise.
        """
        path = self.local_images[idx]
        with bf.BlobFile(path, "rb") as f:
            pil_image = Image.open(f)
            # Image.open is lazy; force the decode while the file is open.
            pil_image.load()
        pil_image = pil_image.convert("RGB")

        if self.random_crop:
            arr = random_crop_arr(pil_image, self.resolution)
        else:
            arr = center_crop_arr(pil_image, self.resolution)

        if self.random_flip and random.random() < 0.5:
            # Reverses axis 1 -- the width axis, assuming the crop helpers
            # return height x width x channel arrays (TODO confirm).
            arr = arr[:, ::-1]

        # 0 maps to -1 and 255 maps to +1.
        arr = arr.astype(np.float32) / 127.5 - 1

        out_dict = {}
        if self.local_classes is not None:
            out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
        return np.transpose(arr, [2, 0, 1]), out_dict
124
125
126def center_crop_arr(pil_image, image_size):

Callers 1

load_data · Function · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected