Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 for the given timesteps using the provided noise scheduler. Args: noise_scheduler (`NoiseScheduler`): an object containing the noise schedule parameters, specifically `alphas_cumprod`, which is used to compute the SNR values.
(noise_scheduler, timesteps)
| 74 | |
| 75 | |
def compute_snr(noise_scheduler, timesteps):
    """
    Compute the signal-to-noise ratio (SNR) at the given timesteps of a noise schedule.

    Follows
    https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
    With ``a = noise_scheduler.alphas_cumprod``, the signal scale is ``sqrt(a)``, the noise
    scale is ``sqrt(1 - a)``, and ``SNR = (sqrt(a) / sqrt(1 - a)) ** 2``.

    Args:
        noise_scheduler (`NoiseScheduler`):
            An object exposing an ``alphas_cumprod`` tensor. That tensor is indexed by ``timesteps``.
        timesteps (`torch.Tensor`):
            A tensor of timesteps, used as an index into ``alphas_cumprod``.

    Returns:
        `torch.Tensor`: A float32 tensor of SNR values, with the same shape as ``timesteps``
        and on ``timesteps.device``. An entry where ``alphas_cumprod`` equals 1 divides by zero.
    """
    device = timesteps.device
    target_shape = timesteps.shape

    def _gather(schedule):
        # Move the schedule to the timesteps' device, pick the per-timestep entries,
        # cast to float32, then broadcast to the timesteps' shape.
        # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
        picked = schedule.to(device=device)[timesteps].float()
        while picked.dim() < len(target_shape):
            picked = picked[..., None]
        return picked.expand(target_shape)

    cumprod = noise_scheduler.alphas_cumprod
    # Square roots are taken on the full schedule, before the move and the gather.
    signal_scale = _gather(cumprod**0.5)
    noise_scale = _gather((1.0 - cumprod) ** 0.5)

    return (signal_scale / noise_scale) ** 2
| 111 | |
| 112 | |
| 113 | def compute_confidence_aware_loss( |
no test coverage detected
searching dependent graphs…