(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None)
| 233 | |
@torch.no_grad()
def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
           unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
    """Encode ``x0`` by stepping it forward ``t_enc`` times using the model's noise prediction.

    Starting from ``x0``, each step ``i`` queries ``self.model.apply_model`` for a
    noise prediction at ``timesteps[i]`` and applies the deterministic update

        x_next = sqrt(a_next / a) * x + sqrt(a_next) * (sqrt(1/a_next - 1) - sqrt(1/a - 1)) * eps

    where ``a_next`` / ``a`` are the i-th entries of the "current" and "previous"
    cumulative-alpha schedules. (This looks like the DDIM inversion update;
    confirm against the sampler's reference before relying on that reading.)

    Args:
        x0: Starting tensor; ``x0.shape[0]`` is used as the batch size for the
            timestep tensor.
        c: Conditioning passed to ``apply_model``. When guidance is active it is
            concatenated with ``unconditional_conditioning`` via ``torch.cat``,
            so both must be tensors in that case.
        t_enc: Number of encoding steps to run; must not exceed the number of
            available timesteps.
        use_original_steps: If true, use ``np.arange(self.ddpm_num_timesteps)``
            with ``self.alphas_cumprod`` / ``self.alphas_cumprod_prev``;
            otherwise use ``self.ddim_timesteps`` with ``self.ddim_alphas`` /
            ``self.ddim_alphas_prev``.
        return_intermediates: If truthy, roughly how many intermediate states to
            record. A state is saved every ``t_enc // return_intermediates``
            steps, and the last two steps are always saved. Requesting more
            intermediates than there are steps records every step.
        unconditional_guidance_scale: When not equal to ``1.``, the model is run
            on a doubled batch and the prediction becomes
            ``e_uncond + scale * (e_cond - e_uncond)``.
        unconditional_conditioning: Required when guidance is active.
        callback: Optional callable invoked as ``callback(i)`` after each step.

    Returns:
        A tuple ``(x_encoded, out)``. ``out`` is a dict with ``'x_encoded'`` and
        ``'intermediate_steps'`` (the recorded step indices), plus
        ``'intermediates'`` (the recorded tensors) when ``return_intermediates``
        is truthy.

    Raises:
        AssertionError: If ``t_enc`` exceeds the number of available timesteps,
            or guidance is requested without ``unconditional_conditioning``.
    """
    timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
    num_reference_steps = timesteps.shape[0]

    assert t_enc <= num_reference_steps, \
        f't_enc ({t_enc}) exceeds the number of available timesteps ({num_reference_steps})'
    num_steps = t_enc

    if use_original_steps:
        alphas_next = self.alphas_cumprod[:num_steps]
        alphas = self.alphas_cumprod_prev[:num_steps]
    else:
        alphas_next = self.ddim_alphas[:num_steps]
        alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])

    # Spacing between recorded intermediates. The floor division yields 0 when
    # more intermediates are requested than there are steps, which previously
    # made the modulo below raise ZeroDivisionError; fall back to every step.
    save_every = None
    if return_intermediates:
        save_every = num_steps // return_intermediates
        if save_every == 0:
            save_every = 1

    x_next = x0
    intermediates = []
    inter_steps = []
    for i in tqdm(range(num_steps), desc='Encoding Image'):
        t = torch.full((x0.shape[0],), timesteps[i], device=self.model.device, dtype=torch.long)
        if unconditional_guidance_scale == 1.:
            noise_pred = self.model.apply_model(x_next, t, c)
        else:
            assert unconditional_conditioning is not None
            # One forward pass over a doubled batch: first half unconditional,
            # second half conditional.
            e_t_uncond, noise_pred = torch.chunk(
                self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
                                       torch.cat((unconditional_conditioning, c))), 2)
            noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)

        xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
        weighted_noise_pred = alphas_next[i].sqrt() * (
            (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
        x_next = xt_weighted + weighted_noise_pred

        if return_intermediates and i % save_every == 0 and i < num_steps - 1:
            intermediates.append(x_next)
            inter_steps.append(i)
        elif return_intermediates and i >= num_steps - 2:
            # Always keep the final two states, even off the sampling grid.
            intermediates.append(x_next)
            inter_steps.append(i)
        if callback:
            callback(i)

    out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
    if return_intermediates:
        out.update({'intermediates': intermediates})
    return x_next, out
| 281 | |
| 282 | @torch.no_grad() |
| 283 | def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): |
Nothing calls this directly.
No test coverage detected.