MCPcopy
hub / github.com/ali-vilab/AnyDoor / p_sample_loop

Method p_sample_loop

ldm/models/diffusion/ddpm.py:1053–1101  ·  view source on GitHub ↗
(self, cond, shape, return_intermediates=False,
                      x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, start_T=None,
                      log_every_t=None)

Source from the content-addressed store, hash-verified

1051
@torch.no_grad()
def p_sample_loop(self, cond, shape, return_intermediates=False,
                  x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
                  mask=None, x0=None, img_callback=None, start_T=None,
                  log_every_t=None):
    """Run the reverse sampling chain from step ``timesteps - 1`` down to 0.

    Starting from ``x_T`` (or fresh Gaussian noise of ``shape``), this calls
    ``self.p_sample`` once per step and returns the final sample.

    Args:
        cond: Conditioning forwarded to ``self.p_sample`` at every step.
        shape: Shape of the initial noise. ``shape[0]`` is used as the batch
            size of the per-step timestep tensor.
        return_intermediates: If true, also return the list of logged samples.
        x_T: Optional starting tensor. When ``None`` it is drawn with
            ``torch.randn(shape)`` on the device of ``self.betas``.
        verbose: If true, wrap the step iterator in a ``tqdm`` progress bar.
        callback: Optional ``callback(i)`` invoked after every step.
        timesteps: Number of steps to run. Defaults to ``self.num_timesteps``.
        quantize_denoised: Forwarded to ``self.p_sample``.
        mask: Optional blending mask. After each step the sample becomes
            ``q_sample(x0, t) * mask + (1 - mask) * sample``. Requires ``x0``.
        x0: Reference tensor used together with ``mask``. Its dimensions from
            index 2 onward (presumably the spatial H, W — confirm) must equal
            the mask's.
        img_callback: Optional ``img_callback(img, i)`` invoked after every step.
        start_T: Optional upper bound on the number of steps.
        log_every_t: Store the current sample every ``log_every_t`` steps. A
            falsy value falls back to ``self.log_every_t``.

    Returns:
        The final sample, or ``(sample, intermediates)`` when
        ``return_intermediates`` is true. ``intermediates`` begins with the
        initial tensor and also contains the result of the first step run.

    Raises:
        AssertionError: If ``mask`` is given without ``x0``, or their shapes
            disagree from dimension 2 onward.
    """
    if not log_every_t:
        log_every_t = self.log_every_t
    device = self.betas.device
    b = shape[0]
    img = torch.randn(shape, device=device) if x_T is None else x_T

    intermediates = [img]
    if timesteps is None:
        timesteps = self.num_timesteps
    if start_T is not None:
        timesteps = min(timesteps, start_T)

    steps = reversed(range(0, timesteps))
    # tqdm is only referenced when a progress bar is requested.
    iterator = tqdm(steps, desc='Sampling t', total=timesteps) if verbose else steps

    if mask is not None:
        assert x0 is not None
        # Compare every dimension from index 2 on. The former slice ``[2:3]``
        # checked only the first spatial dimension. A mismatch in the next one
        # slipped through and only surfaced later as a broadcasting error
        # inside the loop.
        assert x0.shape[2:] == mask.shape[2:]  # spatial size has to match

    for i in iterator:
        # One timestep value per batch element.
        ts = torch.full((b,), i, device=device, dtype=torch.long)
        if self.shorten_cond_schedule:
            assert self.model.conditioning_key != 'hybrid'
            tc = self.cond_ids[ts].to(cond.device)
            # NOTE(review): ``cond`` is overwritten here, so noise from earlier
            # steps compounds on each iteration — confirm this is intended.
            cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))

        img = self.p_sample(img, cond, ts,
                            clip_denoised=self.clip_denoised,
                            quantize_denoised=quantize_denoised)
        if mask is not None:
            # Where mask is 1, pin the sample to x0 noised to the current step.
            img_orig = self.q_sample(x0, ts)
            img = img_orig * mask + (1. - mask) * img

        if i % log_every_t == 0 or i == timesteps - 1:
            intermediates.append(img)
        if callback:
            callback(i)
        if img_callback:
            img_callback(img, i)

    if return_intermediates:
        return img, intermediates
    return img
1102
1103 @torch.no_grad()
1104 def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,

Callers 1

sample · Method · 0.95

Calls 2

p_sample · Method · 0.95
q_sample · Method · 0.45

Tested by

no test coverage detected