(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps)
| 20 | |
| 21 | |
def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
    """Build the per-timestep beta array for the named schedule.

    Args:
        beta_schedule: Name of the schedule. One of:
            'quad'     -- evenly spaced between sqrt(beta_start) and
                          sqrt(beta_end), then squared.
            'linear'   -- evenly spaced between beta_start and beta_end.
            'warmup10' -- delegated to `_warmup_beta` with 0.1 as its last
                          argument.
            'warmup50' -- delegated to `_warmup_beta` with 0.5 as its last
                          argument.
            'const'    -- every entry equals beta_end (beta_start is unused).
            'jsd'      -- 1/T, 1/(T-1), ..., 1 with
                          T = num_diffusion_timesteps (beta_start and
                          beta_end are unused).
        beta_start: First beta value (keyword-only).
        beta_end: Last beta value (keyword-only).
        num_diffusion_timesteps: Number of entries to generate (keyword-only).

    Returns:
        A NumPy array of shape (num_diffusion_timesteps,). The branches
        computed directly in this function produce float64 values.

    Raises:
        NotImplementedError: If `beta_schedule` is not one of the names
            above; the unrecognised name is passed as the exception argument.
    """
    # NOTE(review): `_warmup_beta` is defined elsewhere in this file; its last
    # argument is presumably the warmup fraction -- confirm against the helper.
    if beta_schedule == 'quad':
        # Interpolate in sqrt-space, then square back.
        betas = np.linspace(
            beta_start ** 0.5, beta_end ** 0.5, num_diffusion_timesteps, dtype=np.float64
        ) ** 2
    elif beta_schedule == 'linear':
        betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
    elif beta_schedule == 'warmup10':
        betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
    elif beta_schedule == 'warmup50':
        betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
    elif beta_schedule == 'const':
        betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
    elif beta_schedule == 'jsd':  # 1/T, 1/(T-1), 1/(T-2), ..., 1
        betas = 1. / np.linspace(
            num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
        )
    else:
        raise NotImplementedError(beta_schedule)
    # Internal sanity check: every branch must yield exactly one beta per step.
    assert betas.shape == (num_diffusion_timesteps,)
    return betas
| 39 | |
| 40 | |
| 41 | def noise_like(shape, noise_fn=tf.random_normal, repeat=False, dtype=tf.float32): |
no test coverage detected