| 1051 | |
@torch.no_grad()
def p_sample_loop(self, cond, shape, return_intermediates=False,
                  x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
                  mask=None, x0=None, img_callback=None, start_T=None,
                  log_every_t=None):
    """Run the reverse sampling loop, one ``self.p_sample`` step per timestep.

    Starting from ``x_T`` (or fresh Gaussian noise of ``shape`` on
    ``self.betas.device``), iterate ``i = timesteps-1, ..., 0`` and replace the
    running sample with ``self.p_sample(img, cond, t=i, ...)``.

    Args:
        cond: Conditioning forwarded to ``self.p_sample`` at every step. When
            ``self.shorten_cond_schedule`` is set it must be a tensor; at each
            step the *original* ``cond`` is noised to level
            ``self.cond_ids[t]`` via ``self.q_sample``.
        shape: Shape of the sample to draw; ``shape[0]`` is the batch size.
        return_intermediates: If True, also return the logged samples.
        x_T: Optional starting sample; random normal noise otherwise.
        verbose: Wrap the timestep iterator in a ``tqdm`` progress bar.
        callback: Optional ``callback(i)`` invoked after every step.
        timesteps: Number of steps to run; defaults to ``self.num_timesteps``.
        quantize_denoised: Forwarded to ``self.p_sample``.
        mask: Optional blending mask. After each step the sample becomes
            ``self.q_sample(x0, t) * mask + (1 - mask) * img``.
        x0: Required when ``mask`` is given. Its dims from index 2 onward
            (presumably spatial) must equal those of ``mask``.
        img_callback: Optional ``img_callback(img, i)`` invoked after every step.
        start_T: If given, caps the number of steps at ``start_T``.
        log_every_t: Interval at which intermediates are recorded; falls back
            to ``self.log_every_t`` when falsy.

    Returns:
        The final sample, or ``(sample, intermediates)`` when
        ``return_intermediates`` is True. ``intermediates`` starts with the
        initial sample, then holds the samples after the first step and after
        every step ``i`` with ``i % log_every_t == 0``.
    """
    if not log_every_t:
        log_every_t = self.log_every_t
    device = self.betas.device
    b = shape[0]
    img = torch.randn(shape, device=device) if x_T is None else x_T

    intermediates = [img]
    if timesteps is None:
        timesteps = self.num_timesteps
    if start_T is not None:
        timesteps = min(timesteps, start_T)

    steps = reversed(range(0, timesteps))
    iterator = tqdm(steps, desc='Sampling t', total=timesteps) if verbose else steps

    if mask is not None:
        assert x0 is not None
        # Compare every dim from index 2 on. A ``[2:3]`` slice would only check
        # the first of them and let e.g. a width mismatch through.
        assert x0.shape[2:] == mask.shape[2:]  # spatial size has to match

    # Keep the clean conditioning. Each step must noise the *original* cond,
    # not the already-noised cond from the previous step (noise would compound).
    cond_start = cond

    for i in iterator:
        ts = torch.full((b,), i, device=device, dtype=torch.long)
        if self.shorten_cond_schedule:
            assert self.model.conditioning_key != 'hybrid'
            tc = self.cond_ids[ts].to(cond_start.device)
            cond = self.q_sample(x_start=cond_start, t=tc, noise=torch.randn_like(cond_start))

        img = self.p_sample(img, cond, ts,
                            clip_denoised=self.clip_denoised,
                            quantize_denoised=quantize_denoised)
        if mask is not None:
            # Blend the noised known content (weight ``mask``) back into the sample.
            img_orig = self.q_sample(x0, ts)
            img = img_orig * mask + (1. - mask) * img

        # Only hold on to intermediate tensors when the caller asked for them.
        if return_intermediates and (i % log_every_t == 0 or i == timesteps - 1):
            intermediates.append(img)
        if callback:
            callback(i)
        if img_callback:
            img_callback(img, i)

    if return_intermediates:
        return img, intermediates
    return img
| 1102 | |
| 1103 | @torch.no_grad() |
| 1104 | def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None, |