| 65 | |
| 66 | |
def step(self, model_output, timestep, sample, to_final=False):
    """Advance ``sample`` by one scheduler step via ``self.denoise``.

    ``alpha_prod_t`` is read from ``self.alphas_cumprod`` at the current
    timestep. ``alpha_prod_t_prev`` is read at the entry of
    ``self.timesteps`` that follows the one nearest to the current timestep,
    or set to ``1.0`` when there is no following entry or ``to_final`` is
    requested.

    Args:
        model_output: Model prediction; forwarded unchanged to
            ``self.denoise``.
        timestep: Current timestep. May be a Python number or a tensor on
            any device. Only its first element is used.
        sample: Current sample; forwarded unchanged to ``self.denoise``.
        to_final: If True, use ``alpha_prod_t_prev = 1.0`` regardless of
            where ``timestep`` sits in the schedule.

    Returns:
        Whatever ``self.denoise(model_output, sample, alpha_prod_t,
        alpha_prod_t_prev)`` returns.
    """
    # Normalise to a single CPU scalar tensor. Plain ints/floats then work as
    # well as (possibly GPU) tensors, and both lookups below agree on the
    # value. Previously `.flatten()` ran before the isinstance check, so
    # non-tensor timesteps crashed with AttributeError.
    # NOTE(review): assumes self.timesteps lives on the CPU (the original
    # also moved timestep to CPU before subtracting) — confirm.
    timestep = torch.as_tensor(timestep).cpu().flatten()[0]
    alpha_prod_t = self.alphas_cumprod[int(timestep.item())]

    # Index of the schedule entry closest to the current timestep.
    timestep_id = int(torch.argmin((self.timesteps - timestep).abs()))
    if to_final or timestep_id + 1 >= len(self.timesteps):
        # No following schedule entry (or caller forced the final step).
        alpha_prod_t_prev = 1.0
    else:
        timestep_prev = int(self.timesteps[timestep_id + 1])
        alpha_prod_t_prev = self.alphas_cumprod[timestep_prev]

    return self.denoise(model_output, sample, alpha_prod_t, alpha_prod_t_prev)
| 79 | |
| 80 | |
| 81 | def return_to_timestep(self, timestep, sample, sample_stablized): |