(self, batch, **args)
| 263 | |
@torch.no_grad()
def get_input_withmask(self, batch, **args):
    """Return the parent class's inputs for ``batch`` with the batch mask appended.

    The first-stage inputs come from ``super().get_input`` (called with
    ``self.first_stage_key`` and any extra keyword arguments). ``batch["mask"]``
    is converted from channels-last to channels-first. It is then made
    contiguous, cast to float and appended as the final element.

    Args:
        batch: Mapping that must contain a ``"mask"`` entry; it is also passed
            unchanged to the parent ``get_input``.
        **args: Forwarded verbatim to the parent ``get_input``.

    Returns:
        The same container returned by the parent ``get_input``, extended
        in place with the processed mask.
        NOTE(review): the in-place ``+=`` with a list assumes the parent
        returns a list. A tuple would raise TypeError. Confirm against the
        parent class.
    """
    inputs = super().get_input(batch, self.first_stage_key, **args)

    raw_mask = batch["mask"]
    # Give a channel-less mask a trailing singleton channel. The 4-axis
    # channels-last pattern below can then handle both layouts.
    raw_mask = raw_mask[..., None] if len(raw_mask.shape) == 3 else raw_mask

    mask_bchw = rearrange(raw_mask, 'b h w c -> b c h w')
    mask_bchw = mask_bchw.to(memory_format=torch.contiguous_format).float()

    inputs += [mask_bchw]
    return inputs
| 274 | |
| 275 | def training_step(self, batch, batch_idx): |
| 276 | if isinstance(batch, list): |