MCPcopy
hub / github.com/ali-vilab/AnyDoor / get_input

Method get_input

ldm/models/diffusion/ddpm.py:1707–1737  ·  view source on GitHub ↗
(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False)

Sourced from the content-addressed store; hash-verified.

1705
@torch.no_grad()
def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
    """Assemble the latent input plus depth and cross-attention conditioning.

    Delegates to the parent ``get_input`` (always with ``self.first_stage_key``,
    forced conditioning encoding and original-cond return), then builds a
    depth-map conditioning tensor from the single configured concat key.

    Args:
        batch: mapping holding the first-stage input and the entry named by
            ``self.concat_keys[0]``.
        k: accepted for signature compatibility; ignored (the parent call
            always uses ``self.first_stage_key``).
        cond_key: accepted for signature compatibility; unused.
        bs: optional batch-size cap, applied to the concat input via slicing
            and forwarded to the parent call.
        return_first_stage_outputs: when true, also return ``x``, ``xrec``
            and ``xc`` from the parent call.

    Returns:
        ``(z, conds)`` or ``(z, conds, x, xrec, xc)``. Here ``conds`` is
        ``{"c_concat": [depth], "c_crossattn": [c]}``. ``depth`` is the
        depth-model output resized to ``z.shape[2:]`` and min-max scaled per
        sample toward [-1, 1].
    """
    # Only frozen (non-trainable) conditioning encoders are handled here.
    assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for depth2img'
    z, c, x, xrec, xc = super().get_input(
        batch,
        self.first_stage_key,
        return_first_stage_outputs=True,
        force_c_encode=True,
        return_original_cond=True,
        bs=bs,
    )

    assert exists(self.concat_keys)
    assert len(self.concat_keys) == 1

    depth_maps = []
    for key in self.concat_keys:
        depth_in = batch[key]
        if bs is not None:
            depth_in = depth_in[:bs]
        depth = self.depth_model(depth_in.to(self.device))
        # Resize the depth prediction to the latent's trailing (spatial) dims.
        # NOTE(review): assumes depth_model returns a 4-D (B, C, H, W) tensor — confirm.
        depth = torch.nn.functional.interpolate(
            depth,
            size=z.shape[2:],
            mode="bicubic",
            align_corners=False,
        )
        # Per-sample min-max scaling over dims 1..3. The 0.001 epsilon guards
        # against a flat map (max == min) causing division by zero.
        reduce_dims = [1, 2, 3]
        lo = torch.amin(depth, dim=reduce_dims, keepdim=True)
        hi = torch.amax(depth, dim=reduce_dims, keepdim=True)
        depth_maps.append(2. * (depth - lo) / (hi - lo + 0.001) - 1.)

    conditioning = {
        "c_concat": [torch.cat(depth_maps, dim=1)],
        "c_crossattn": [c],
    }
    if not return_first_stage_outputs:
        return z, conditioning
    return z, conditioning, x, xrec, xc
1738
1739 @torch.no_grad()
1740 def log_images(self, *args, **kwargs):

Callers

nothing calls this directly

Calls 2

exists — Function · 0.90
get_input — Method · 0.45

Tested by

no test coverage detected