(
# NOTE(review): parenthesized list of bare names. In order, they are exactly the
# parameters of create_model_and_diffusion defined further down. Evaluated as a
# standalone expression it would raise NameError unless every name is bound in
# the enclosing scope. It looks like paste/extraction residue; confirm against
# the preceding (unseen) lines and remove it if it is not referenced.
image_size,
class_cond,
learn_sigma,
num_channels,
num_res_blocks,
channel_mult,
num_heads,
num_head_channels,
num_heads_upsample,
attention_resolutions,
dropout,
diffusion_steps,
noise_schedule,
timestep_respacing,
use_kl,
predict_xstart,
rescale_timesteps,
rescale_learned_sigmas,
use_checkpoint,
use_scale_shift_norm,
resblock_updown,
use_fp16,
use_new_attention_order,
)
| 72 | |
| 73 | |
def create_model_and_diffusion(
    image_size,
    class_cond,
    learn_sigma,
    num_channels,
    num_res_blocks,
    channel_mult,
    num_heads,
    num_head_channels,
    num_heads_upsample,
    attention_resolutions,
    dropout,
    diffusion_steps,
    noise_schedule,
    timestep_respacing,
    use_kl,
    predict_xstart,
    rescale_timesteps,
    rescale_learned_sigmas,
    use_checkpoint,
    use_scale_shift_norm,
    resblock_updown,
    use_fp16,
    use_new_attention_order,
):
    """Build a network and its diffusion process from one flat set of settings.

    The network settings are forwarded to ``create_model`` and the diffusion
    settings to ``create_gaussian_diffusion``. ``learn_sigma`` is handed to
    both factories. ``diffusion_steps`` reaches ``create_gaussian_diffusion``
    under the keyword ``steps``.

    Returns:
        A ``(model, diffusion)`` tuple: the result of ``create_model`` first,
        then the result of ``create_gaussian_diffusion``.
    """
    # image_size, num_channels and num_res_blocks are passed positionally (in
    # that order); every other network setting is passed by keyword.
    network_options = {
        "channel_mult": channel_mult,
        "learn_sigma": learn_sigma,
        "class_cond": class_cond,
        "use_checkpoint": use_checkpoint,
        "attention_resolutions": attention_resolutions,
        "num_heads": num_heads,
        "num_head_channels": num_head_channels,
        "num_heads_upsample": num_heads_upsample,
        "use_scale_shift_norm": use_scale_shift_norm,
        "dropout": dropout,
        "resblock_updown": resblock_updown,
        "use_fp16": use_fp16,
        "use_new_attention_order": use_new_attention_order,
    }
    network = create_model(
        image_size,
        num_channels,
        num_res_blocks,
        **network_options,
    )

    # The diffusion factory takes keywords only; note the steps/diffusion_steps
    # rename.
    process_options = {
        "steps": diffusion_steps,
        "learn_sigma": learn_sigma,
        "noise_schedule": noise_schedule,
        "use_kl": use_kl,
        "predict_xstart": predict_xstart,
        "rescale_timesteps": rescale_timesteps,
        "rescale_learned_sigmas": rescale_learned_sigmas,
        "timestep_respacing": timestep_respacing,
    }
    process = create_gaussian_diffusion(**process_options)

    return network, process
| 128 | |
| 129 | |
| 130 | def create_model( |
no test coverage detected