MCPcopy Index your code
hub / github.com/adobe-research/custom-diffusion / CustomDiffusionDataset

Class CustomDiffusionDataset

src/diffusers_data_pipeline.py:275–405  ·  view source on GitHub ↗

A dataset to prepare the instance and class images with the prompts for fine-tuning the model. It pre-processes the images and tokenizes the prompts.

Source from the content-addressed store, hash-verified

273
274
275class CustomDiffusionDataset(Dataset):
276 """
277 A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
278 It pre-processes the images and the tokenizes prompts.
279 """
280
281 def __init__(
282 self,
283 concepts_list,
284 tokenizer,
285 size=512,
286 center_crop=False,
287 with_prior_preservation=False,
288 num_class_images=200,
289 hflip=False,
290 ):
291 self.size = size
292 self.center_crop = center_crop
293 self.tokenizer = tokenizer
294 self.interpolation = PIL.Image.BILINEAR
295
296 self.instance_images_path = []
297 self.class_images_path = []
298 self.with_prior_preservation = with_prior_preservation
299 for concept in concepts_list:
300 inst_img_path = [(x, concept["instance_prompt"]) for x in Path(concept["instance_data_dir"]).iterdir() if x.is_file()]
301 self.instance_images_path.extend(inst_img_path)
302
303 if with_prior_preservation:
304 class_data_root = Path(concept["class_data_dir"])
305 if os.path.isdir(class_data_root):
306 class_images_path = list(class_data_root.iterdir())
307 class_prompt = [concept["class_prompt"] for _ in range(len(class_images_path))]
308 else:
309 with open(class_data_root, "r") as f:
310 class_images_path = f.read().splitlines()
311 with open(concept["class_prompt"], "r") as f:
312 class_prompt = f.read().splitlines()
313
314 class_img_path = [(x, y) for (x, y) in zip(class_images_path, class_prompt)]
315 self.class_images_path.extend(class_img_path[:num_class_images])
316
317 random.shuffle(self.instance_images_path)
318 self.num_instance_images = len(self.instance_images_path)
319 self.num_class_images = len(self.class_images_path)
320 self._length = max(self.num_class_images, self.num_instance_images)
321 self.flip = transforms.RandomHorizontalFlip(0.5 * hflip)
322
323 self.image_transforms = transforms.Compose(
324 [
325 self.flip,
326 transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
327 transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
328 transforms.ToTensor(),
329 transforms.Normalize([0.5], [0.5]),
330 ]
331 )
332

Callers 2

mainFunction · 0.90
mainFunction · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected