A diffusion process that can skip steps in a base diffusion process. `use_timesteps` is a collection (sequence or set) of timesteps from the original diffusion process to retain; `kwargs` are the keyword arguments used to create the base diffusion process.
| 61 | |
| 62 | |
class SpacedDiffusion(GaussianDiffusion):
    """
    A diffusion process that runs on a retained subset of the timesteps of a
    base diffusion process, skipping the rest.

    :param use_timesteps: a collection (sequence or set) of timestep indices
                          from the original diffusion process to keep; it is
                          stored as a set.
    :param kwargs: the kwargs used to build the base diffusion process. Must
                   contain "betas", which is replaced by the re-derived
                   schedule before the parent constructor runs.
    """

    def __init__(self, use_timesteps, **kwargs):
        self.use_timesteps = set(use_timesteps)
        self.timestep_map = []
        self.original_num_steps = len(kwargs["betas"])

        # Build the full (unspaced) process once, only to read its cumulative
        # alpha products.
        full_process = GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa

        # For each retained step, pick the beta that takes the previous
        # retained cumulative product to the current one. Presumably this
        # keeps the spaced schedule's cumulative products equal to the base
        # schedule's at the retained indices -- confirm against how
        # GaussianDiffusion derives alphas_cumprod.
        spaced_betas = []
        previous_cumprod = 1.0
        for step, cumprod in enumerate(full_process.alphas_cumprod):
            if step not in self.use_timesteps:
                continue
            spaced_betas.append(1 - cumprod / previous_cumprod)
            previous_cumprod = cumprod
            # Position in timestep_map is the spaced index; the stored value
            # is the original timestep.
            self.timestep_map.append(step)

        kwargs["betas"] = np.array(spaced_betas)
        super().__init__(**kwargs)

    def p_mean_variance(self, model, *args, **kwargs):  # pylint: disable=signature-differs
        """Delegate to the parent implementation with `model` wrapped."""
        wrapped = self._wrap_model(model)
        return super().p_mean_variance(wrapped, *args, **kwargs)

    def training_losses(self, model, *args, **kwargs):  # pylint: disable=signature-differs
        """Delegate to the parent implementation with `model` wrapped."""
        wrapped = self._wrap_model(model)
        return super().training_losses(wrapped, *args, **kwargs)

    def condition_mean(self, cond_fn, *args, **kwargs):
        """Delegate to the parent implementation with `cond_fn` wrapped."""
        wrapped = self._wrap_model(cond_fn)
        return super().condition_mean(wrapped, *args, **kwargs)

    def condition_score(self, cond_fn, *args, **kwargs):
        """Delegate to the parent implementation with `cond_fn` wrapped."""
        wrapped = self._wrap_model(cond_fn)
        return super().condition_score(wrapped, *args, **kwargs)

    def _wrap_model(self, model):
        """
        Return `model` wrapped in a _WrappedModel, unless it already is one.

        The wrapper receives the timestep map, the rescale flag and the
        original step count.
        """
        if not isinstance(model, _WrappedModel):
            # NOTE(review): self.rescale_timesteps is not set in this class;
            # presumably provided by GaussianDiffusion -- confirm.
            return _WrappedModel(
                model,
                self.timestep_map,
                self.rescale_timesteps,
                self.original_num_steps,
            )
        return model

    def _scale_timesteps(self, t):
        """Return `t` unchanged."""
        # Scaling is done by the wrapped model.
        return t
| 114 | |
| 115 | |
| 116 | class _WrappedModel: |
no outgoing calls
no test coverage detected