MCPcopy
hub / github.com/ali-vilab/AnyDoor / progressive_denoising

Method progressive_denoising

ldm/models/diffusion/ddpm.py:997–1050  ·  view source on GitHub ↗
(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
                              img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
                              score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
                              log_every_t=None)

Source from the content-addressed store, hash-verified

995
@torch.no_grad()
def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
                          img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
                          score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
                          log_every_t=None):
    """Run the reverse sampling loop and record intermediate ``x0`` predictions.

    Starting from ``x_T`` (or standard normal noise of ``shape``), calls
    ``self.p_sample`` once per timestep from ``timesteps - 1`` down to 0.
    The predicted ``x0`` of a step is recorded on the first step and whenever
    ``i % log_every_t == 0``.

    Args:
        cond: Conditioning forwarded to ``p_sample``. If it is a dict (whose
            values may be lists), a list, or any sliceable object, each entry is
            truncated with ``[:batch_size]``. ``None`` is passed through unchanged.
        shape: Sample shape. When ``batch_size`` is given, ``shape`` excludes the
            batch dimension and ``[batch_size] + list(shape)`` is used. Otherwise
            ``shape[0]`` is used as the batch size.
        verbose: If true, iterate through a ``tqdm`` progress bar.
        callback: Optional ``callback(i)`` called after every step.
        quantize_denoised: Forwarded to ``p_sample``.
        img_callback: Optional ``img_callback(img, i)`` called after every step.
        mask: Optional blend mask. After each step the sample becomes
            ``q_sample(x0, t) * mask + (1 - mask) * img``. Requires ``x0``.
        x0: Reference sample used together with ``mask``.
        temperature: Either a scalar (int or float), which is expanded to one
            value per step, or a sequence indexed by timestep. Forwarded to
            ``p_sample``.
        noise_dropout: Forwarded to ``p_sample``.
        score_corrector: Forwarded to ``p_sample``.
        corrector_kwargs: Forwarded to ``p_sample``.
        batch_size: Optional explicit batch size (see ``shape``).
        x_T: Optional starting sample. If omitted, ``torch.randn(shape)`` is used.
        start_T: Optional cap on the number of timesteps that are run.
        log_every_t: Interval for recording intermediates. Falsy values
            (``None`` or 0) fall back to ``self.log_every_t``.

    Returns:
        ``(img, intermediates)``: the final sample and the list of recorded
        ``x0`` predictions, in the order they were produced.
    """
    if not log_every_t:
        # Falsy values (including 0, which would break ``i % log_every_t``)
        # fall back to the model default.
        log_every_t = self.log_every_t
    # Loop-invariant precondition: checked once, before any sampling work.
    assert mask is None or x0 is not None, "x0 must be provided when mask is given"

    timesteps = self.num_timesteps
    if batch_size is not None:
        # Here ``shape`` excludes the batch dimension.
        b = batch_size
        shape = [batch_size] + list(shape)
    else:
        b = batch_size = shape[0]

    img = torch.randn(shape, device=self.device) if x_T is None else x_T
    intermediates = []

    if cond is not None:
        # Truncate every conditioning tensor to the batch being sampled.
        if isinstance(cond, dict):
            cond = {key: [c[:batch_size] for c in value] if isinstance(value, list) else value[:batch_size]
                    for key, value in cond.items()}
        elif isinstance(cond, list):
            cond = [c[:batch_size] for c in cond]
        else:
            cond = cond[:batch_size]

    if start_T is not None:
        timesteps = min(timesteps, start_T)
    steps = reversed(range(0, timesteps))
    iterator = tqdm(steps, desc='Progressive Generation', total=timesteps) if verbose else steps

    if isinstance(temperature, (int, float)):
        # Expand a scalar temperature to one value per timestep. The previous
        # ``type(...) == float`` check missed ints and float subclasses, which
        # then crashed on ``temperature[i]``.
        temperature = [temperature] * timesteps

    for i in iterator:
        ts = torch.full((b,), i, device=self.device, dtype=torch.long)
        if self.shorten_cond_schedule:
            assert self.model.conditioning_key != 'hybrid'
            tc = self.cond_ids[ts].to(cond.device)
            # NOTE(review): ``cond`` is reassigned here, so later steps noise the
            # already-noised conditioning. Confirm this is intended.
            cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))

        img, x0_partial = self.p_sample(img, cond, ts,
                                        clip_denoised=self.clip_denoised,
                                        quantize_denoised=quantize_denoised, return_x0=True,
                                        temperature=temperature[i], noise_dropout=noise_dropout,
                                        score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
        if mask is not None:
            # Where mask == 1, keep the (noised) reference. Elsewhere, keep the sample.
            img_orig = self.q_sample(x0, ts)
            img = img_orig * mask + (1. - mask) * img

        if i % log_every_t == 0 or i == timesteps - 1:
            intermediates.append(x0_partial)
        if callback:
            callback(i)
        if img_callback:
            img_callback(img, i)
    return img, intermediates
1051
1052 @torch.no_grad()
1053 def p_sample_loop(self, cond, shape, return_intermediates=False,

Callers 2

log_images · Method · 0.95
log_images · Method · 0.80

Calls 2

p_sample · Method · 0.95
q_sample · Method · 0.45

Tested by

no test coverage detected