(
image_size,
classifier_use_fp16,
classifier_width,
classifier_depth,
classifier_attention_resolutions,
classifier_use_scale_shift_norm,
classifier_resblock_updown,
classifier_pool,
)
| 226 | |
| 227 | |
def create_classifier(
    image_size,
    classifier_use_fp16,
    classifier_width,
    classifier_depth,
    classifier_attention_resolutions,
    classifier_use_scale_shift_norm,
    classifier_resblock_updown,
    classifier_pool,
):
    """Build an ``EncoderUNetModel`` classifier for 3-channel inputs.

    Args:
        image_size: Input resolution. Must be one of 64, 128, 256 or 512;
            each value selects a fixed ``channel_mult`` schedule.
        classifier_use_fp16: Forwarded as ``use_fp16``.
        classifier_width: Forwarded as ``model_channels``.
        classifier_depth: Forwarded as ``num_res_blocks``.
        classifier_attention_resolutions: Comma-separated string of
            resolutions (e.g. ``"32,16,8"``). Each entry ``res`` is converted
            to the downsample factor ``image_size // res`` that is passed as
            ``attention_resolutions``. Blank entries (an empty string, or a
            stray/trailing comma) are ignored.
        classifier_use_scale_shift_norm: Forwarded as ``use_scale_shift_norm``.
        classifier_resblock_updown: Forwarded as ``resblock_updown``.
        classifier_pool: Forwarded as ``pool``.

    Returns:
        An ``EncoderUNetModel`` with ``in_channels=3``, ``out_channels=1000``
        (presumably the ImageNet class count -- confirm against the dataset),
        and ``num_head_channels=64``.

    Raises:
        ValueError: If ``image_size`` is not supported, or if an attention
            resolution is not a positive integer.
    """
    if image_size == 512:
        channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
    elif image_size == 256:
        channel_mult = (1, 1, 2, 2, 4, 4)
    elif image_size == 128:
        channel_mult = (1, 1, 2, 3, 4)
    elif image_size == 64:
        channel_mult = (1, 2, 3, 4)
    else:
        raise ValueError(f"unsupported image size: {image_size}")

    attention_ds = []
    for entry in classifier_attention_resolutions.split(","):
        entry = entry.strip()
        if not entry:
            # An empty spec ("") or a stray/trailing comma previously crashed
            # inside int(""); treat it as "no attention at this slot".
            continue
        resolution = int(entry)
        if resolution <= 0:
            # 0 would raise ZeroDivisionError below, and a negative value would
            # silently yield a meaningless negative downsample factor.
            raise ValueError(
                f"attention resolutions must be positive integers, got {entry!r}"
            )
        # Convert a pixel resolution into the downsample factor relative to
        # the input size.
        attention_ds.append(image_size // resolution)

    return EncoderUNetModel(
        image_size=image_size,
        in_channels=3,
        model_channels=classifier_width,
        out_channels=1000,
        num_res_blocks=classifier_depth,
        attention_resolutions=tuple(attention_ds),
        channel_mult=channel_mult,
        use_fp16=classifier_use_fp16,
        num_head_channels=64,
        use_scale_shift_norm=classifier_use_scale_shift_norm,
        resblock_updown=classifier_resblock_updown,
        pool=classifier_pool,
    )
| 267 | |
| 268 | |
| 269 | def sr_model_and_diffusion_defaults(): |
no test coverage detected