(self, num_steps: int)
| 75 | super().__init__(betas, parameterization, rescale_cfg) |
| 76 | |
def make_schedule(self, num_steps: int) -> None:
    """Respace the training noise schedule down to ``num_steps`` steps.

    A subset of training timesteps is chosen with ``space_timesteps``. New
    betas are derived for that subset so that the running product of
    ``1 - beta`` telescopes back to the training ``alpha_cumprod`` value at
    each kept timestep (beta_k = 1 - abar_k / abar_{k-1}, with abar_{-1} = 1).

    Side effects:
        * ``self.timesteps`` is set to the sorted kept timestep indices
          (int32 array).
        * The derived coefficient arrays are registered via ``self.register``
          under the names ``sqrt_alphas_cumprod``,
          ``sqrt_one_minus_alphas_cumprod``, ``sqrt_recip_alphas_cumprod``,
          ``sqrt_recipm1_alphas_cumprod``, ``posterior_variance``,
          ``posterior_log_variance_clipped``, ``posterior_mean_coef1`` and
          ``posterior_mean_coef2``.

    Args:
        num_steps: Number of steps to keep. Passed to ``space_timesteps`` as a
            single-section spacing string (``str(num_steps)``).

    NOTE(review): if only one timestep is kept, ``posterior_variance`` has
    length 1 and ``posterior_variance[1]`` below raises IndexError. Confirm
    whether single-step schedules need to be supported.
    """
    used_timesteps = space_timesteps(self.num_timesteps, str(num_steps))

    # Walk the training schedule and keep only the selected timesteps. Each
    # new beta is the ratio between consecutive *kept* alpha_cumprod values,
    # so cumprod(1 - betas) reproduces the training alpha_cumprod at those steps.
    betas = []
    last_alpha_cumprod = 1.0
    for i, alpha_cumprod in enumerate(self.training_alphas_cumprod):
        if i in used_timesteps:
            betas.append(1 - alpha_cumprod / last_alpha_cumprod)
            last_alpha_cumprod = alpha_cumprod
    # sorted() accepts any iterable; no intermediate list() copy is needed.
    self.timesteps = np.array(sorted(used_timesteps), dtype=np.int32)  # e.g. [0, 10, 20, ...]

    betas = np.array(betas, dtype=np.float64)
    alphas = 1.0 - betas
    alphas_cumprod = np.cumprod(alphas, axis=0)
    # Shifted right by one, with 1.0 standing in for "before the first step".
    alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

    sqrt_recip_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod)
    sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod - 1)
    posterior_variance = (
        betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
    )
    # posterior_variance[0] is 0 because alphas_cumprod_prev[0] == 1.0, so its
    # log would be -inf. Reuse the value at index 1 for the first entry instead.
    posterior_log_variance_clipped = np.log(
        np.append(posterior_variance[1], posterior_variance[1:])
    )
    posterior_mean_coef1 = (
        betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
    )
    posterior_mean_coef2 = (
        (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
    )

    self.register("sqrt_alphas_cumprod", np.sqrt(alphas_cumprod))
    self.register("sqrt_one_minus_alphas_cumprod", np.sqrt(1 - alphas_cumprod))
    self.register("sqrt_recip_alphas_cumprod", sqrt_recip_alphas_cumprod)
    self.register("sqrt_recipm1_alphas_cumprod", sqrt_recipm1_alphas_cumprod)
    self.register("posterior_variance", posterior_variance)
    self.register("posterior_log_variance_clipped", posterior_log_variance_clipped)
    self.register("posterior_mean_coef1", posterior_mean_coef1)
    self.register("posterior_mean_coef2", posterior_mean_coef2)
| 117 | |
| 118 | def q_posterior_mean_variance( |
| 119 | self, x_start: torch.Tensor, x_t: torch.Tensor, t: torch.Tensor |
no test coverage detected