A dataset to prepare the instance and class images with the prompts for fine-tuning the model. It pre-processes the images and tokenizes the prompts.
| 273 | |
| 274 | |
| 275 | class CustomDiffusionDataset(Dataset): |
| 276 | """ |
| 277 | A dataset to prepare the instance and class images with the prompts for fine-tuning the model. |
| 278 | It pre-processes the images and the tokenizes prompts. |
| 279 | """ |
| 280 | |
def __init__(
    self,
    concepts_list,
    tokenizer,
    size=512,
    center_crop=False,
    with_prior_preservation=False,
    num_class_images=200,
    hflip=False,
):
    """Collect (image path, prompt) pairs for every concept and build the image transform.

    Args:
        concepts_list: Iterable of concept dicts. Each dict must provide
            ``"instance_prompt"`` and ``"instance_data_dir"``. When
            ``with_prior_preservation`` is true, each must also provide
            ``"class_prompt"`` and ``"class_data_dir"``:
              * if ``class_data_dir`` is a directory, every file in it is a
                class image paired with the single ``class_prompt`` string;
              * otherwise ``class_data_dir`` is read as a text file with one
                image path per line and ``class_prompt`` as a text file with
                one prompt per line, paired line by line (``zip`` truncates to
                the shorter list).
        tokenizer: Stored as ``self.tokenizer``; not used in this method.
        size: Side length passed to ``Resize`` and to the crop transform.
        center_crop: Use ``CenterCrop`` when true, otherwise ``RandomCrop``.
        with_prior_preservation: Also collect class images for each concept.
        num_class_images: Maximum number of class images kept *per concept*.
        hflip: When true, images are flipped horizontally with probability 0.5.
    """
    self.size = size
    self.center_crop = center_crop
    self.tokenizer = tokenizer
    self.interpolation = PIL.Image.BILINEAR

    self.instance_images_path = []
    self.class_images_path = []
    self.with_prior_preservation = with_prior_preservation
    for concept in concepts_list:
        instance_prompt = concept["instance_prompt"]
        self.instance_images_path.extend(
            (x, instance_prompt)
            for x in Path(concept["instance_data_dir"]).iterdir()
            if x.is_file()
        )

        if with_prior_preservation:
            class_data_root = Path(concept["class_data_dir"])
            if class_data_root.is_dir():
                # Only regular files are images; skip sub-directories exactly as
                # the instance-image scan above does.
                class_images_path = [x for x in class_data_root.iterdir() if x.is_file()]
                class_prompt = [concept["class_prompt"]] * len(class_images_path)
            else:
                # Explicit UTF-8 so prompt/path lists decode the same on every
                # platform instead of depending on the locale's default encoding.
                with open(class_data_root, "r", encoding="utf-8") as f:
                    class_images_path = f.read().splitlines()
                with open(concept["class_prompt"], "r", encoding="utf-8") as f:
                    class_prompt = f.read().splitlines()

            class_img_path = list(zip(class_images_path, class_prompt))
            # The cap is applied per concept, not to the combined list.
            self.class_images_path.extend(class_img_path[:num_class_images])

    random.shuffle(self.instance_images_path)
    self.num_instance_images = len(self.instance_images_path)
    self.num_class_images = len(self.class_images_path)
    # Length covers the larger of the two pools.
    self._length = max(self.num_class_images, self.num_instance_images)
    # Probability 0.5 when hflip is true, 0.0 (never flip) when false.
    self.flip = transforms.RandomHorizontalFlip(0.5 * hflip)

    self.image_transforms = transforms.Compose(
        [
            self.flip,
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
            transforms.ToTensor(),
            # Maps ToTensor's [0, 1] range to [-1, 1].
            transforms.Normalize([0.5], [0.5]),
        ]
    )
| 332 |