(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False)
| 1661 | |
@torch.no_grad()
def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
    """Build the inpainting conditioning (concat + cross-attention) for a batch.

    Calls the parent ``get_input`` on ``self.first_stage_key`` with
    ``return_first_stage_outputs=True``, ``force_c_encode=True`` and
    ``return_original_cond=True``. It then builds a channel-wise concatenation
    of every batch entry named in ``self.concat_keys``:

    * the entry under ``self.masked_image_key`` is passed through
      ``encode_first_stage`` and ``get_first_stage_encoding``;
    * every other entry is resized to the spatial size of ``z`` with
      ``torch.nn.functional.interpolate`` (default, i.e. nearest, mode).

    Runs under ``torch.no_grad()``, so no gradients are tracked here.

    Args:
        batch: Mapping from key to tensor. Entries read for ``concat_keys``
            must be 4-D channels-last (``b h w c``); they are rearranged to
            ``b c h w`` before use.
        k: Unused. The parent is always called with ``self.first_stage_key``.
            Kept for signature compatibility.
        cond_key: Unused; kept for signature compatibility.
        bs: If not None, only the first ``bs`` items of each concat entry are
            used (and ``bs`` is forwarded to the parent ``get_input``).
        return_first_stage_outputs: If True, also return ``x``, ``xrec`` and
            ``xc`` exactly as produced by the parent ``get_input``.

    Returns:
        ``(z, all_conds)``, or ``(z, all_conds, x, xrec, xc)`` when
        ``return_first_stage_outputs`` is True. ``all_conds`` is
        ``{"c_concat": [c_cat], "c_crossattn": [c]}``, where ``c_cat`` is the
        concatenation along dim 1 and ``c`` is the parent's conditioning.

    Raises:
        NotImplementedError: if ``self.cond_stage_trainable`` is truthy.
        ValueError: if ``self.concat_keys`` is unset or empty.
    """
    # Explicit raises rather than ``assert``: asserts are stripped under
    # ``python -O``, which would silently disable these guards.
    # Restricted to non-trainable encoders: the conditioning is force-encoded
    # under no_grad, so a trainable cond stage would receive no gradients.
    if self.cond_stage_trainable:
        raise NotImplementedError('trainable cond stages not yet supported for inpainting')
    # Checked before the (expensive) parent encode so misconfiguration fails fast.
    if not exists(self.concat_keys):
        raise ValueError('concat_keys must be set to build the inpainting c_concat conditioning')

    z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
                                          force_c_encode=True, return_original_cond=True, bs=bs)

    # Spatial size (last two dims of z) that non-encoded inputs are resized to.
    target_hw = z.shape[-2:]
    concat_parts = []
    for ck in self.concat_keys:
        cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
        if bs is not None:
            cc = cc[:bs]
        cc = cc.to(self.device)
        if ck != self.masked_image_key:
            # Non-image inputs are resized to z's spatial size so they
            # can be concatenated with it along dim 1.
            cc = torch.nn.functional.interpolate(cc, size=target_hw)
        else:
            cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
        concat_parts.append(cc)

    # torch.cat on an empty list fails with an unhelpful message; say why.
    if not concat_parts:
        raise ValueError('concat_keys is empty; nothing to concatenate for c_concat')
    c_cat = torch.cat(concat_parts, dim=1)

    all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
    if return_first_stage_outputs:
        return z, all_conds, x, xrec, xc
    return z, all_conds
| 1687 | |
| 1688 | @torch.no_grad() |
| 1689 | def log_images(self, *args, **kwargs): |
nothing calls this directly
no test coverage detected