| 8 | |
| 9 | |
def create_diffusion(
    timestep_respacing,
    noise_schedule="linear",
    use_kl=False,
    sigma_small=False,
    predict_xstart=False,
    learn_sigma=True,
    rescale_learned_sigmas=False,
    diffusion_steps=1000,
):
    """Assemble a ``SpacedDiffusion`` from high-level configuration flags.

    Args:
        timestep_respacing: Respacing spec forwarded to ``space_timesteps``.
            ``None`` or ``""`` is replaced by ``[diffusion_steps]``.
        noise_schedule: Schedule name handed to
            ``gd.get_named_beta_schedule``.
        use_kl: When truthy, the loss type is ``RESCALED_KL``; this takes
            precedence over ``rescale_learned_sigmas``.
        sigma_small: Only consulted when ``learn_sigma`` is falsy; selects
            ``FIXED_SMALL`` instead of ``FIXED_LARGE``.
        predict_xstart: When truthy, the mean type is ``START_X``;
            otherwise ``EPSILON``.
        learn_sigma: When truthy, the variance type is ``LEARNED_RANGE``.
        rescale_learned_sigmas: When truthy (and ``use_kl`` is falsy), the
            loss type is ``RESCALED_MSE``; otherwise plain ``MSE``.
        diffusion_steps: Step count passed to both the beta schedule and
            ``space_timesteps``.

    Returns:
        The ``SpacedDiffusion`` instance built from the resolved settings.
    """
    beta_schedule = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)

    # use_kl wins over rescale_learned_sigmas; plain MSE is the fallback.
    chosen_loss = (
        gd.LossType.RESCALED_KL
        if use_kl
        else gd.LossType.RESCALED_MSE
        if rescale_learned_sigmas
        else gd.LossType.MSE
    )

    # A missing or empty respacing spec becomes a single section covering
    # all of diffusion_steps.
    if timestep_respacing is None or timestep_respacing == "":
        timestep_respacing = [diffusion_steps]
    kept_steps = space_timesteps(diffusion_steps, timestep_respacing)

    if predict_xstart:
        mean_kind = gd.ModelMeanType.START_X
    else:
        mean_kind = gd.ModelMeanType.EPSILON

    # sigma_small only matters when the variance is not learned.
    if learn_sigma:
        var_kind = gd.ModelVarType.LEARNED_RANGE
    elif sigma_small:
        var_kind = gd.ModelVarType.FIXED_SMALL
    else:
        var_kind = gd.ModelVarType.FIXED_LARGE

    return SpacedDiffusion(
        use_timesteps=kept_steps,
        betas=beta_schedule,
        model_mean_type=mean_kind,
        model_var_type=var_kind,
        loss_type=chosen_loss,
    )