| 231 | |
| 232 | |
def collate_fn(examples, with_prior_preservation):
    """Collate a list of dataset examples into a single training batch.

    Instance entries always come first. When ``with_prior_preservation`` is
    true, the matching class entries are appended after them. This puts both
    groups in one batch, so the model needs only a single forward pass.

    Args:
        examples: Sequence of per-example dicts. Each is read for
            ``"instance_prompt_ids"``, ``"instance_images"`` and ``"mask"``.
            With prior preservation, it is also read for ``"class_prompt_ids"``,
            ``"class_images"`` and ``"class_mask"``.
        with_prior_preservation: Whether to append the class entries after
            the instance entries.

    Returns:
        A dict with these keys:

        - ``"input_ids"``: prompt-id tensors joined with ``torch.cat`` along
          dim 0.
        - ``"pixel_values"``: images stacked along a new leading dim, made
          contiguous and cast to float32.
        - ``"mask"``: masks stacked the same way, then given an extra
          singleton dim at index 1. Presumably each mask is (H, W), so the
          result would be (B, 1, H, W); TODO confirm against the dataset.
    """

    def _contiguous_float(tensor):
        # Force standard contiguous layout, then cast to float32.
        return tensor.to(memory_format=torch.contiguous_format).float()

    # Each group is (prompt-id key, image key, mask key). Groups are read in
    # order, so instance rows precede class rows in every output tensor.
    key_groups = [("instance_prompt_ids", "instance_images", "mask")]
    if with_prior_preservation:
        key_groups.append(("class_prompt_ids", "class_images", "class_mask"))

    prompt_ids, images, masks = [], [], []
    for ids_key, image_key, mask_key in key_groups:
        prompt_ids.extend(ex[ids_key] for ex in examples)
        images.extend(ex[image_key] for ex in examples)
        masks.extend(ex[mask_key] for ex in examples)

    ids_tensor = torch.cat(prompt_ids, dim=0)
    image_tensor = torch.stack(images)
    mask_tensor = torch.stack(masks)

    return {
        "input_ids": ids_tensor,
        "pixel_values": _contiguous_float(image_tensor),
        "mask": _contiguous_float(mask_tensor).unsqueeze(1),
    }
| 256 | |
| 257 | |
| 258 | class PromptDataset(Dataset): |