A simple dataset to prepare the prompts to generate class images on multiple GPUs.
| 256 | |
| 257 | |
class PromptDataset(Dataset):
    """A simple dataset to prepare the prompts to generate class images on multiple GPUs.

    Every sample pairs the same ``prompt`` with the sample's position, so the
    ``"index"`` value identifies which sample produced a given output.

    Args:
        prompt: The prompt returned unchanged for every sample.
        num_samples: Number of samples; this is the dataset length.
    """

    def __init__(self, prompt, num_samples: int):
        self.prompt = prompt
        self.num_samples = num_samples

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, index: int) -> dict:
        """Return ``{"prompt": prompt, "index": position}`` for sample ``index``.

        Negative indices count from the end, as for built-in sequences; the
        returned ``"index"`` is always the non-negative position.

        Raises:
            IndexError: If ``index`` is outside ``[-len(self), len(self))``.
        """
        position = index + self.num_samples if index < 0 else index
        # Python's fallback iteration protocol (``for x in dataset`` /
        # ``list(dataset)``) calls __getitem__ with 0, 1, 2, ... until an
        # IndexError is raised; without this check such iteration never ends.
        if not 0 <= position < self.num_samples:
            raise IndexError(
                f"index {index} is out of range for PromptDataset of length {self.num_samples}"
            )
        return {"prompt": self.prompt, "index": position}
| 273 | |
| 274 | |
| 275 | class CustomDiffusionDataset(Dataset): |