hub / github.com/adobe-research/custom-diffusion / collate_fn

Function collate_fn

src/diffusers_data_pipeline.py:233–255
(examples, with_prior_preservation)

Source from the content-addressed store, hash-verified

def collate_fn(examples, with_prior_preservation):
    input_ids = [example["instance_prompt_ids"] for example in examples]
    pixel_values = [example["instance_images"] for example in examples]
    mask = [example["mask"] for example in examples]
    # Concat class and instance examples for prior preservation.
    # We do this to avoid doing two forward passes.
    if with_prior_preservation:
        input_ids += [example["class_prompt_ids"] for example in examples]
        pixel_values += [example["class_images"] for example in examples]
        mask += [example["class_mask"] for example in examples]

    input_ids = torch.cat(input_ids, dim=0)
    pixel_values = torch.stack(pixel_values)
    mask = torch.stack(mask)
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
    mask = mask.to(memory_format=torch.contiguous_format).float()

    batch = {
        "input_ids": input_ids,
        "pixel_values": pixel_values,
        "mask": mask.unsqueeze(1)
    }
    return batch
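The listing above can be exercised in isolation to see what shapes it produces. The sketch below is illustrative only: it repeats collate_fn so that it runs on its own, and `make_fake_example` is a hypothetical helper that is not part of the repository. The tensor sizes (77 tokens, 64×64 images, 8×8 masks) are assumptions chosen for the demonstration, not values confirmed by the source.

```python
import torch


def collate_fn(examples, with_prior_preservation):
    # Repeated from the listing above so this sketch is self-contained.
    input_ids = [example["instance_prompt_ids"] for example in examples]
    pixel_values = [example["instance_images"] for example in examples]
    mask = [example["mask"] for example in examples]
    if with_prior_preservation:
        input_ids += [example["class_prompt_ids"] for example in examples]
        pixel_values += [example["class_images"] for example in examples]
        mask += [example["class_mask"] for example in examples]

    input_ids = torch.cat(input_ids, dim=0)
    pixel_values = torch.stack(pixel_values)
    mask = torch.stack(mask)
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
    mask = mask.to(memory_format=torch.contiguous_format).float()
    return {
        "input_ids": input_ids,
        "pixel_values": pixel_values,
        "mask": mask.unsqueeze(1),
    }


def make_fake_example(seq_len=77, image_size=64, mask_size=8):
    # Hypothetical stand-in for one dataset item; every shape here is an assumption.
    return {
        "instance_prompt_ids": torch.zeros(1, seq_len, dtype=torch.long),
        "instance_images": torch.rand(3, image_size, image_size),
        "mask": torch.ones(mask_size, mask_size),
        "class_prompt_ids": torch.zeros(1, seq_len, dtype=torch.long),
        "class_images": torch.rand(3, image_size, image_size),
        "class_mask": torch.ones(mask_size, mask_size),
    }


examples = [make_fake_example() for _ in range(2)]
batch_plain = collate_fn(examples, with_prior_preservation=False)
batch_prior = collate_fn(examples, with_prior_preservation=True)

print(batch_plain["pixel_values"].shape)  # torch.Size([2, 3, 64, 64])
print(batch_prior["pixel_values"].shape)  # torch.Size([4, 3, 64, 64])
print(batch_prior["input_ids"].shape)     # torch.Size([4, 77])
print(batch_prior["mask"].shape)          # torch.Size([4, 1, 8, 8])
```

With prior preservation enabled, the batch dimension doubles, and the instance examples occupy the first half of the batch with the class examples in the second half. A training loop could therefore separate the two halves again along dim 0, for example with `torch.chunk(tensor, 2, dim=0)`. Because collate_fn takes a second argument, it would typically need to be wrapped (for instance in a lambda or `functools.partial`) before being passed as the `collate_fn` of a `torch.utils.data.DataLoader`.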

Callers 2

main · Function · 0.90
main · Function · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected