| 1705 | |
@torch.no_grad()
def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
    """Encode a batch via the parent class and attach a depth map as concat conditioning.

    Only non-trainable conditioning encoders are handled (enforced by the
    first assertion below).

    Args:
        batch: mapping from keys to batch tensors; must contain every key in
            ``self.concat_keys``.
        k: accepted for interface compatibility; the parent ``get_input`` is
            always called with ``self.first_stage_key`` instead.
        cond_key: accepted for interface compatibility; unused here.
        bs: optional limit on how many leading samples are taken.
        return_first_stage_outputs: when True, additionally return ``x``,
            ``xrec`` and ``xc`` as produced by the parent ``get_input``.

    Returns:
        ``(z, cond)`` or, if requested, ``(z, cond, x, xrec, xc)``, where
        ``cond`` is ``{"c_concat": [depth], "c_crossattn": [c]}``.
    """
    assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for depth2img'
    z, c, x, xrec, xc = super().get_input(
        batch,
        self.first_stage_key,
        return_first_stage_outputs=True,
        force_c_encode=True,
        return_original_cond=True,
        bs=bs,
    )

    assert exists(self.concat_keys)
    assert len(self.concat_keys) == 1

    def _depth_condition(key):
        # Prepare one concat-conditioning tensor from batch[key].
        raw = batch[key]
        if bs is not None:
            raw = raw[:bs]
        depth = self.depth_model(raw.to(self.device))
        # Resize the depth_model output to z's trailing (spatial) dimensions.
        depth = torch.nn.functional.interpolate(
            depth,
            size=z.shape[2:],
            mode="bicubic",
            align_corners=False,
        )
        # Min/max over every dim except dim 0 (the batch dim, as sliced by bs above).
        lo = torch.amin(depth, dim=[1, 2, 3], keepdim=True)
        hi = torch.amax(depth, dim=[1, 2, 3], keepdim=True)
        # Rescale each sample into [-1, 1); the 0.001 avoids dividing by zero on a flat map.
        return 2. * (depth - lo) / (hi - lo + 0.001) - 1.

    depth_cond = torch.cat([_depth_condition(key) for key in self.concat_keys], dim=1)
    cond = {"c_concat": [depth_cond], "c_crossattn": [c]}
    if not return_first_stage_outputs:
        return z, cond
    return z, cond, x, xrec, xc
| 1738 | |
| 1739 | @torch.no_grad() |
| 1740 | def log_images(self, *args, **kwargs): |