MCPcopy
hub / github.com/openai/guided-diffusion / create_classifier

Function create_classifier

guided_diffusion/script_util.py:228–266  ·  view source on GitHub ↗
(
    image_size,
    classifier_use_fp16,
    classifier_width,
    classifier_depth,
    classifier_attention_resolutions,
    classifier_use_scale_shift_norm,
    classifier_resblock_updown,
    classifier_pool,
)

Source from the content-addressed store, hash-verified

226
227
def create_classifier(
    image_size,
    classifier_use_fp16,
    classifier_width,
    classifier_depth,
    classifier_attention_resolutions,
    classifier_use_scale_shift_norm,
    classifier_resblock_updown,
    classifier_pool,
    *,
    in_channels=3,
    out_channels=1000,
):
    """Build the classifier network as an ``EncoderUNetModel``.

    :param image_size: input resolution; must be 64, 128, 256 or 512. It
        selects the ``channel_mult`` schedule passed to the model.
    :param classifier_use_fp16: forwarded as ``use_fp16``.
    :param classifier_width: forwarded as ``model_channels``.
    :param classifier_depth: forwarded as ``num_res_blocks``.
    :param classifier_attention_resolutions: comma-separated resolutions
        such as ``"32,16,8"``. Each entry ``res`` is converted to
        ``image_size // res`` and the results are forwarded as
        ``attention_resolutions``. Blank entries (an empty string, stray
        whitespace, a trailing comma) are ignored, so ``""`` produces an
        empty tuple.
    :param classifier_use_scale_shift_norm: forwarded as
        ``use_scale_shift_norm``.
    :param classifier_resblock_updown: forwarded as ``resblock_updown``.
    :param classifier_pool: forwarded as ``pool``.
    :param in_channels: forwarded as ``in_channels`` (default 3, the value
        previously hard-coded here).
    :param out_channels: forwarded as ``out_channels`` (default 1000, the
        value previously hard-coded here).
    :return: the constructed ``EncoderUNetModel``.
    :raises ValueError: if ``image_size`` is not supported, or an attention
        resolution entry is not a positive integer.
    """
    if image_size == 512:
        channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
    elif image_size == 256:
        channel_mult = (1, 1, 2, 2, 4, 4)
    elif image_size == 128:
        channel_mult = (1, 1, 2, 3, 4)
    elif image_size == 64:
        channel_mult = (1, 2, 3, 4)
    else:
        raise ValueError(f"unsupported image size: {image_size}")

    attention_ds = []
    for entry in classifier_attention_resolutions.split(","):
        entry = entry.strip()
        if not entry:
            # Tolerate "" and stray/trailing commas instead of failing in int("").
            continue
        try:
            resolution = int(entry)
        except ValueError as err:
            raise ValueError(
                f"attention resolution must be an integer, got {entry!r}"
            ) from err
        if resolution <= 0:
            # 0 would raise ZeroDivisionError below; negatives would yield a
            # meaningless negative downsample rate.
            raise ValueError(
                f"attention resolution must be a positive integer, got {entry!r}"
            )
        attention_ds.append(image_size // resolution)

    return EncoderUNetModel(
        image_size=image_size,
        in_channels=in_channels,
        model_channels=classifier_width,
        out_channels=out_channels,
        num_res_blocks=classifier_depth,
        attention_resolutions=tuple(attention_ds),
        channel_mult=channel_mult,
        use_fp16=classifier_use_fp16,
        num_head_channels=64,
        use_scale_shift_norm=classifier_use_scale_shift_norm,
        resblock_updown=classifier_resblock_updown,
        pool=classifier_pool,
    )
267
268
269def sr_model_and_diffusion_defaults():

Callers 2

main — Function · 0.90

Calls 1

EncoderUNetModel — Class · 0.85

Tested by

no test coverage detected