Repository: github.com/adobe-research/custom-diffusion

Method get_input_withmask

src/model.py:265–273
(self, batch, **args)


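```python
import torch
from einops import rearrange

# Method of the model class in src/model.py (class body omitted).
@torch.no_grad()
def get_input_withmask(self, batch, **args):
    # Inputs prepared by the parent class's get_input, returned as a list.
    out = super().get_input(batch, self.first_stage_key, **args)
    mask = batch["mask"]
    # A (b, h, w) mask has no channel axis; add a singleton one.
    if len(mask.shape) == 3:
        mask = mask[..., None]
    # Channels-last -> channels-first: (b, h, w, c) -> (b, c, h, w).
    mask = rearrange(mask, 'b h w c -> b c h w')
    mask = mask.to(memory_format=torch.contiguous_format).float()
    # Append the mask as the last element of the returned list.
    out += [mask]
    return out
```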

Callers (1)

shared_step (Method) · 0.95

Calls

no outgoing calls detected

Tested by

no test coverage detected