(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True)
| 44 | |
| 45 | |
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
    """Select the subset of DDPM timesteps visited by a DDIM sampler.

    Args:
        ddim_discr_method: Discretization scheme.
            ``'uniform'`` takes every ``num_ddpm_timesteps // num_ddim_timesteps``-th
            step starting at 0.
            ``'quad'`` spaces ``num_ddim_timesteps`` steps quadratically over
            ``[0, 0.8 * num_ddpm_timesteps]`` and truncates them to integers.
        num_ddim_timesteps: Requested number of sampling steps.
        num_ddpm_timesteps: Number of timesteps in the underlying DDPM schedule.
        verbose: If true, print the selected timesteps.

    Returns:
        A numpy integer array of the selected timesteps, each shifted by +1.

    Raises:
        NotImplementedError: If ``ddim_discr_method`` is not ``'uniform'`` or ``'quad'``.
        ValueError: For ``'uniform'``, when ``num_ddim_timesteps`` is not in
            ``[1, num_ddpm_timesteps]``. Outside that range the integer stride
            would be zero or negative.

    Notes:
        * ``'uniform'`` yields ``ceil(num_ddpm_timesteps / stride)`` entries.
          This can exceed ``num_ddim_timesteps`` when the division is inexact;
          e.g. 1000 DDPM steps with 300 DDIM steps gives 334 entries.
        * ``'quad'`` truncates with ``astype(int)``, so it can contain repeated
          values when many steps are requested.
        * Because of the +1 shift, the largest ``'uniform'`` value can equal
          ``num_ddpm_timesteps``; e.g. 1000 / 300 ends at 1000.
          NOTE(review): if callers use these values to index a table of length
          ``num_ddpm_timesteps``, that last entry is out of bounds. Confirm.
    """
    if ddim_discr_method == 'uniform':
        # Guard the integer stride. A zero stride made ``range`` raise an opaque
        # error, and a negative one silently produced an empty schedule.
        if not 1 <= num_ddim_timesteps <= num_ddpm_timesteps:
            raise ValueError(
                f'num_ddim_timesteps must be in [1, num_ddpm_timesteps={num_ddpm_timesteps}] '
                f'for the "uniform" method, got {num_ddim_timesteps}'
            )
        stride = num_ddpm_timesteps // num_ddim_timesteps
        ddim_timesteps = np.arange(0, num_ddpm_timesteps, stride)
    elif ddim_discr_method == 'quad':
        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
    else:
        raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')

    # add one to get the final alpha values right (the ones from first scale to data during sampling)
    steps_out = ddim_timesteps + 1
    if verbose:
        print(f'Selected timesteps for ddim sampler: {steps_out}')
    return steps_out
| 61 | |
| 62 | |
| 63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): |
no test coverage detected