MCPcopy
hub / github.com/ali-vilab/AnyDoor / get_input

Method get_input

ldm/models/diffusion/ddpm.py:1663–1686  ·  view source on GitHub ↗
(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False)

Source from the content-addressed store, hash-verified

1661
@torch.no_grad()
def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
    """Build the latent input plus hybrid (concat + cross-attention) conditioning.

    The first-stage latent ``z`` and the cross-attention conditioning ``c`` come from
    the parent class's ``get_input``. Each entry of ``self.concat_keys`` is then turned
    into a tensor at ``z``'s spatial size, and all of them are concatenated along
    dim 1 to form the ``c_concat`` conditioning.

    Args:
        batch: Mapping holding the raw inputs. Every ``self.concat_keys`` entry is
            rearranged with ``'b h w c -> b c h w'``, so it must be a 4-D
            channels-last tensor.
        k: Not used. The parent is always queried with ``self.first_stage_key``;
            the parameter is kept for signature compatibility.
        cond_key: Not used; kept for signature compatibility.
        bs: If not None, the concat inputs are truncated to the first ``bs``
            samples. It is also forwarded to the parent's ``get_input``.
        return_first_stage_outputs: If True, also return ``x``, ``xrec`` and ``xc``
            exactly as produced by the parent's ``get_input``.

    Returns:
        ``(z, all_conds)``, or ``(z, all_conds, x, xrec, xc)`` when
        ``return_first_stage_outputs`` is True, where
        ``all_conds == {"c_concat": [c_cat], "c_crossattn": [c]}``.

    Raises:
        NotImplementedError: If ``self.cond_stage_trainable`` is truthy.
        ValueError: If ``self.concat_keys`` is not set.
    """
    # Explicit checks rather than ``assert`` so they survive ``python -O``.
    # Restricted to non-trainable encoders: this method runs under torch.no_grad(),
    # so a trainable cond stage could not receive gradients through this path.
    if self.cond_stage_trainable:
        raise NotImplementedError('trainable cond stages not yet supported for inpainting')
    if not exists(self.concat_keys):
        # Checked before the (expensive) parent call so misconfiguration fails fast.
        raise ValueError('concat_keys must be set to build the c_concat conditioning')

    z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
                                          force_c_encode=True, return_original_cond=True, bs=bs)

    # Loop-invariant: the spatial size every non-encoded concat input is resized to.
    latent_hw = z.shape[-2:]
    c_cat = []
    for ck in self.concat_keys:
        cc = batch[ck]
        if bs is not None:
            # Slice first so only the samples actually used are rearranged/cast/copied.
            cc = cc[:bs]
        cc = rearrange(cc, 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
        cc = cc.to(self.device)
        if ck != self.masked_image_key:
            # Resize to the latent's spatial size (interpolate's default mode, 'nearest').
            cc = torch.nn.functional.interpolate(cc, size=latent_hw)
        else:
            # The entry named by masked_image_key is encoded by the first stage
            # instead of being resized.
            cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
        c_cat.append(cc)
    c_cat = torch.cat(c_cat, dim=1)  # stack all concat conditionings along channels
    all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
    if return_first_stage_outputs:
        return z, all_conds, x, xrec, xc
    return z, all_conds
1687
1688 @torch.no_grad()
1689 def log_images(self, *args, **kwargs):

Callers

nothing calls this directly

Calls 4

exists — Function · 0.90
encode_first_stage — Method · 0.80
get_input — Method · 0.45

Tested by

no test coverage detected