(
*,
steps=1000,
learn_sigma=False,
sigma_small=False,
noise_schedule="linear",
use_kl=False,
predict_xstart=False,
rescale_timesteps=False,
rescale_learned_sigmas=False,
timestep_respacing="",
)
| 384 | |
| 385 | |
def create_gaussian_diffusion(
    *,
    steps=1000,
    learn_sigma=False,
    sigma_small=False,
    noise_schedule="linear",
    use_kl=False,
    predict_xstart=False,
    rescale_timesteps=False,
    rescale_learned_sigmas=False,
    timestep_respacing="",
):
    """Build a ``SpacedDiffusion`` configured from keyword-only flags.

    Args:
        steps: Number of diffusion steps; passed to both
            ``gd.get_named_beta_schedule`` and ``space_timesteps``.
        learn_sigma: If true, the variance type is ``LEARNED_RANGE``
            and ``sigma_small`` is ignored.
        sigma_small: When ``learn_sigma`` is false, selects
            ``FIXED_SMALL`` (true) instead of ``FIXED_LARGE`` (false).
        noise_schedule: Schedule name handed to
            ``gd.get_named_beta_schedule``.
        use_kl: If true, use the ``RESCALED_KL`` loss (takes precedence
            over ``rescale_learned_sigmas``).
        predict_xstart: If true, the model mean type is ``START_X``;
            otherwise ``EPSILON``.
        rescale_timesteps: Forwarded unchanged to ``SpacedDiffusion``.
        rescale_learned_sigmas: If true and ``use_kl`` is false, use the
            ``RESCALED_MSE`` loss; otherwise plain ``MSE``.
        timestep_respacing: Respacing spec for ``space_timesteps``. A
            falsy value is replaced by ``[steps]`` (presumably meaning
            "keep every step" — confirm against ``space_timesteps``).

    Returns:
        A ``SpacedDiffusion`` instance built from the options above.
    """
    # Computed first so an unknown schedule name fails before any other work.
    schedule_betas = gd.get_named_beta_schedule(noise_schedule, steps)

    # Loss priority: KL beats rescaled MSE, which beats plain MSE.
    if use_kl:
        chosen_loss = gd.LossType.RESCALED_KL
    elif rescale_learned_sigmas:
        chosen_loss = gd.LossType.RESCALED_MSE
    else:
        chosen_loss = gd.LossType.MSE

    # An empty/falsy respacing spec falls back to a single section of `steps`.
    respacing_spec = timestep_respacing or [steps]
    kept_timesteps = space_timesteps(steps, respacing_spec)

    mean_kind = (
        gd.ModelMeanType.START_X if predict_xstart else gd.ModelMeanType.EPSILON
    )

    # A learned variance overrides the fixed small/large choice.
    if learn_sigma:
        variance_kind = gd.ModelVarType.LEARNED_RANGE
    elif sigma_small:
        variance_kind = gd.ModelVarType.FIXED_SMALL
    else:
        variance_kind = gd.ModelVarType.FIXED_LARGE

    return SpacedDiffusion(
        use_timesteps=kept_timesteps,
        betas=schedule_betas,
        model_mean_type=mean_kind,
        model_var_type=variance_kind,
        loss_type=chosen_loss,
        rescale_timesteps=rescale_timesteps,
    )
| 425 | |
| 426 | |
| 427 | def add_dict_to_argparser(parser, default_dict): |
no test coverage detected