(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None)
| 252 | |
@torch.no_grad()
def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
           unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
    """Deterministically push ``x0`` forward ``t_enc`` noise levels (DDIM inversion).

    Each step reverses the deterministic DDIM update. It moves the latent from
    cumulative alpha ``a`` (``alphas[i]``) to ``a_next`` (``alphas_next[i]``)::

        x_next = sqrt(a_next / a) * x
                 + sqrt(a_next) * (sqrt(1 / a_next - 1) - sqrt(1 / a - 1)) * eps

    ``eps`` is the model's noise prediction. When guidance is enabled it is
    combined with an unconditional prediction (classifier-free guidance).

    Args:
        x0: Starting latent. Its first dimension is used as the batch size
            for the timestep tensor.
        c: Conditioning passed to ``self.model.apply_model``.
        t_enc: Number of encoding steps. Must not exceed
            ``self.ddpm_num_timesteps`` (original steps) or the number of
            DDIM timesteps.
        use_original_steps: If True, walk the full DDPM schedule
            (``alphas_cumprod``). Otherwise walk the DDIM sub-schedule.
        return_intermediates: If truthy, a latent is recorded every
            ``max(1, t_enc // return_intermediates)`` steps, plus at the
            final two steps. The recorded latents are returned under
            ``'intermediates'``.
        unconditional_guidance_scale: Classifier-free guidance scale. A value
            of 1.0 disables guidance, giving one conditional model call per
            step.
        unconditional_conditioning: Conditioning for the unconditional branch.
            Required when ``unconditional_guidance_scale != 1``.
        callback: Optional callable invoked with the step index after every
            step.

    Returns:
        Tuple ``(x_encoded, out)``. ``out`` contains ``'x_encoded'`` and
        ``'intermediate_steps'`` (the step indices that were recorded). It also
        contains ``'intermediates'`` when ``return_intermediates`` is truthy.
    """
    num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]

    assert t_enc <= num_reference_steps
    num_steps = t_enc

    if use_original_steps:
        alphas_next = self.alphas_cumprod[:num_steps]
        alphas = self.alphas_cumprod_prev[:num_steps]
    else:
        alphas_next = self.ddim_alphas[:num_steps]
        # ddim_alphas_prev may not be a tensor (presumably a NumPy array -- confirm).
        # as_tensor avoids the copy/warning torch.tensor() incurs on tensor input
        # and keeps it on the same device as alphas_next.
        alphas = torch.as_tensor(self.ddim_alphas_prev[:num_steps], device=alphas_next.device)

    # Spacing between recorded intermediates. Clamp to 1 so that requesting more
    # intermediates than there are steps records every step instead of raising
    # ZeroDivisionError.
    inter_stride = max(1, num_steps // return_intermediates) if return_intermediates else None

    x_next = x0
    intermediates = []
    inter_steps = []
    for i in tqdm(range(num_steps), desc='Encoding Image'):
        # Condition the model on the diffusion timestep that matches alphas_next[i].
        # On the DDPM path that is simply i (alphas_cumprod[i]). On the DDIM path,
        # ddim_alphas is presumably alphas_cumprod sampled at ddim_timesteps (set up
        # by the schedule code -- confirm), so the raw loop index i would query the
        # model at the wrong noise level.
        timestep = i if use_original_steps else int(self.ddim_timesteps[i])
        t = torch.full((x0.shape[0],), timestep, device=self.model.device, dtype=torch.long)
        if unconditional_guidance_scale == 1.:
            noise_pred = self.model.apply_model(x_next, t, c)
        else:
            assert unconditional_conditioning is not None
            # One batched call: first half unconditional, second half conditional.
            e_t_uncond, noise_pred = torch.chunk(
                self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
                                       torch.cat((unconditional_conditioning, c))), 2)
            noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)

        xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
        weighted_noise_pred = alphas_next[i].sqrt() * (
            (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
        x_next = xt_weighted + weighted_noise_pred

        # Record on every stride boundary and at the last two steps. This is
        # equivalent to the original two-branch condition.
        if return_intermediates and (i % inter_stride == 0 or i >= num_steps - 2):
            intermediates.append(x_next)
            inter_steps.append(i)
        if callback:
            callback(i)

    out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
    if return_intermediates:
        out.update({'intermediates': intermediates})
    return x_next, out
| 299 | |
| 300 | @torch.no_grad() |
| 301 | def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): |
no test coverage detected