(self, x0, t, use_original_steps=False, noise=None)
| 281 | |
@torch.no_grad()
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
    """Jump ``x0`` forward to noise level ``t`` in one closed-form step.

    Computes ``sqrt_alpha[t] * x0 + sqrt_one_minus_alpha[t] * noise``. This is
    fast, but does not allow for exact reconstruction of ``x0``.

    Args:
        x0: Clean input to be noised.
        t: Index tensor used to gather the per-step coefficients from the
            selected schedule (presumably one index per batch element --
            confirm against ``extract_into_tensor``).
        use_original_steps: If truthy, read ``self.sqrt_alphas_cumprod`` /
            ``self.sqrt_one_minus_alphas_cumprod``; otherwise use the DDIM
            tables ``sqrt(self.ddim_alphas)`` / ``self.ddim_sqrt_one_minus_alphas``.
        noise: Noise to mix in. When ``None``, fresh noise is drawn with
            ``torch.randn_like(x0)``.

    Returns:
        The noised tensor. The gathered coefficients are shaped via
        ``extract_into_tensor(..., x0.shape)``, presumably so they broadcast
        against ``x0``.
    """
    # Choose the coefficient tables: the subsampled DDIM schedule by default,
    # or the original schedule when requested.
    if not use_original_steps:
        signal_table = torch.sqrt(self.ddim_alphas)
        noise_table = self.ddim_sqrt_one_minus_alphas
    else:
        signal_table = self.sqrt_alphas_cumprod
        noise_table = self.sqrt_one_minus_alphas_cumprod

    eps = torch.randn_like(x0) if noise is None else noise

    target_shape = x0.shape
    signal_coef = extract_into_tensor(signal_table, t, target_shape)
    noise_coef = extract_into_tensor(noise_table, t, target_shape)
    return signal_coef * x0 + noise_coef * eps
| 297 | |
| 298 | @torch.no_grad() |
| 299 | def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, |
nothing calls this directly
no test coverage detected