| 219 | |
| 220 | |
class HeunEDMSampler(EDMSampler):
    """EDM sampler whose Euler steps are refined by a Heun (trapezoidal) correction."""

    def possible_correction_step(
        self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
    ):
        """Refine an Euler step by averaging the start and end slopes.

        The denoiser is re-evaluated at the Euler-predicted point
        ``euler_step`` / ``next_sigma``. The slope there is averaged with the
        incoming slope ``d``, and the step from ``x`` is retaken with that mean
        slope.

        Args:
            euler_step: Result of the plain Euler step from ``x``.
            x: Sample at the start of the step.
            d: Slope at ``x`` (presumably ``to_d`` of the previous denoise;
                TODO confirm against caller).
            dt: Step size multiplied by the averaged slope.
            next_sigma: Noise level(s) at the end of the step. They are
                broadcast to ``x.ndim`` via ``append_dims``, presumably one
                value per batch element (verify).
            denoiser, cond, uc: Forwarded unchanged to ``self.denoise``.

        Returns:
            ``euler_step`` unchanged when the summed ``next_sigma`` is below
            1e-14. Otherwise, a tensor that takes ``x + mean_slope * dt``
            wherever ``next_sigma > 0`` and ``euler_step`` elsewhere.
        """
        if torch.sum(next_sigma) < 1e-14:
            # Every noise level is (numerically) zero, assuming sigmas are
            # non-negative. Skip the extra network evaluation and keep Euler.
            return euler_step

        denoised_next = self.denoise(euler_step, denoiser, next_sigma, cond, uc)
        slope_next = to_d(euler_step, next_sigma, denoised_next)
        mean_slope = (d + slope_next) / 2.0

        # Only correct entries whose noise level is non-zero. Where
        # next_sigma == 0, slope_next may be non-finite, and torch.where
        # discards that branch in favour of euler_step.
        has_noise = append_dims(next_sigma, x.ndim) > 0.0
        return torch.where(has_noise, x + mean_slope * dt, euler_step)
| 238 | |
| 239 | |
| 240 | class EulerAncestralSampler(AncestralSampler): |
no outgoing calls
no test coverage detected