(
large_size,
small_size,
class_cond,
learn_sigma,
num_channels,
num_res_blocks,
num_heads,
num_head_channels,
num_heads_upsample,
attention_resolutions,
dropout,
diffusion_steps,
noise_schedule,
timestep_respacing,
use_kl,
predict_xstart,
rescale_timesteps,
rescale_learned_sigmas,
use_checkpoint,
use_scale_shift_norm,
resblock_updown,
use_fp16,
)
| 278 | |
| 279 | |
def sr_create_model_and_diffusion(
    large_size,
    small_size,
    class_cond,
    learn_sigma,
    num_channels,
    num_res_blocks,
    num_heads,
    num_head_channels,
    num_heads_upsample,
    attention_resolutions,
    dropout,
    diffusion_steps,
    noise_schedule,
    timestep_respacing,
    use_kl,
    predict_xstart,
    rescale_timesteps,
    rescale_learned_sigmas,
    use_checkpoint,
    use_scale_shift_norm,
    resblock_updown,
    use_fp16,
):
    """Create a model and a diffusion object from one flat set of options.

    This is a thin forwarding wrapper. ``large_size``, ``small_size``,
    ``num_channels`` and ``num_res_blocks`` are handed positionally to
    ``sr_create_model`` together with the remaining model options as
    keywords; the diffusion options go to ``create_gaussian_diffusion``
    (``diffusion_steps`` is passed under the keyword ``steps``).
    ``learn_sigma`` is the only argument forwarded to both callees.

    Returns:
        A ``(model, diffusion)`` tuple: the result of ``sr_create_model``
        followed by the result of ``create_gaussian_diffusion``.
    """
    # Keyword options consumed by the model factory.
    model_options = dict(
        learn_sigma=learn_sigma,
        class_cond=class_cond,
        use_checkpoint=use_checkpoint,
        attention_resolutions=attention_resolutions,
        num_heads=num_heads,
        num_head_channels=num_head_channels,
        num_heads_upsample=num_heads_upsample,
        use_scale_shift_norm=use_scale_shift_norm,
        dropout=dropout,
        resblock_updown=resblock_updown,
        use_fp16=use_fp16,
    )
    # Keyword options consumed by the diffusion factory.
    diffusion_options = dict(
        steps=diffusion_steps,
        learn_sigma=learn_sigma,
        noise_schedule=noise_schedule,
        use_kl=use_kl,
        predict_xstart=predict_xstart,
        rescale_timesteps=rescale_timesteps,
        rescale_learned_sigmas=rescale_learned_sigmas,
        timestep_respacing=timestep_respacing,
    )

    # The model is built before the diffusion, preserving the call order.
    built_model = sr_create_model(
        large_size,
        small_size,
        num_channels,
        num_res_blocks,
        **model_options,
    )
    built_diffusion = create_gaussian_diffusion(**diffusion_options)
    return built_model, built_diffusion
| 332 | |
| 333 | |
| 334 | def sr_create_model( |
no test coverage detected