(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0)
| 92 | self.s_noise = s_noise |
| 93 | |
| 94 | def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0): |
| 95 | sigma_hat = sigma * (gamma + 1.0) |
| 96 | if gamma > 0: |
| 97 | eps = torch.randn_like(x) * self.s_noise |
| 98 | x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5 |
| 99 | |
| 100 | denoised = self.denoise(x, denoiser, sigma_hat, cond, uc) |
| 101 | d = to_d(x, sigma_hat, denoised) |
| 102 | dt = append_dims(next_sigma - sigma_hat, x.ndim) |
| 103 | |
| 104 | euler_step = self.euler_step(x, d, dt) |
| 105 | x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc) |
| 106 | return x |
| 107 | |
| 108 | def __call__(self, denoiser, x, cond, uc=None, num_steps=None): |
| 109 | x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) |
no test coverage detected