MCPcopy Index your code
hub / github.com/openai/improved-diffusion / create_model

Function create_model

improved_diffusion/script_util.py:86–125  ·  view source on GitHub ↗
(
    image_size,
    num_channels,
    num_res_blocks,
    learn_sigma,
    class_cond,
    use_checkpoint,
    attention_resolutions,
    num_heads,
    num_heads_upsample,
    use_scale_shift_norm,
    dropout,
)

Source from the content-addressed store, hash-verified

84
85
def create_model(
    image_size,
    num_channels,
    num_res_blocks,
    learn_sigma,
    class_cond,
    use_checkpoint,
    attention_resolutions,
    num_heads,
    num_heads_upsample,
    use_scale_shift_norm,
    dropout,
):
    """
    Build a UNetModel configured for one of the supported image sizes.

    The channel multiplier tuple is selected from ``image_size``; only 256,
    64 and 32 are accepted.

    :param image_size: image size used to pick ``channel_mult`` and to
                       convert the attention resolutions.
    :param num_channels: forwarded to UNetModel as ``model_channels``.
    :param num_res_blocks: forwarded to UNetModel unchanged.
    :param learn_sigma: if truthy the model gets 6 output channels,
                        otherwise 3.
    :param class_cond: if truthy ``NUM_CLASSES`` is passed as
                       ``num_classes``, otherwise None.
    :param use_checkpoint: forwarded to UNetModel unchanged.
    :param attention_resolutions: comma-separated string of integers; each
                                  entry ``res`` becomes
                                  ``image_size // int(res)``.
    :param num_heads: forwarded to UNetModel unchanged.
    :param num_heads_upsample: forwarded to UNetModel unchanged.
    :param use_scale_shift_norm: forwarded to UNetModel unchanged.
    :param dropout: forwarded to UNetModel unchanged.
    :return: the constructed UNetModel instance.
    :raises ValueError: if ``image_size`` is not 256, 64 or 32, or if an
                        entry of ``attention_resolutions`` is not an integer
                        literal.
    """
    # Scan (size, multipliers) pairs with ``==`` rather than hashing into a
    # dict, so the comparison semantics stay those of a plain equality test.
    known_sizes = (
        (256, (1, 1, 2, 2, 4, 4)),
        (64, (1, 2, 3, 4)),
        (32, (1, 2, 2, 2)),
    )
    for size, multipliers in known_sizes:
        if image_size == size:
            channel_mult = multipliers
            break
    else:
        raise ValueError(f"unsupported image size: {image_size}")

    # NOTE(review): presumably these are downsample factors at which
    # attention is applied -- confirm against UNetModel.
    attention_ds = tuple(
        image_size // int(res) for res in attention_resolutions.split(",")
    )

    out_channels = 6 if learn_sigma else 3
    # NUM_CLASSES is only looked up when class conditioning is requested.
    num_classes = NUM_CLASSES if class_cond else None

    return UNetModel(
        in_channels=3,
        model_channels=num_channels,
        out_channels=out_channels,
        num_res_blocks=num_res_blocks,
        attention_resolutions=attention_ds,
        dropout=dropout,
        channel_mult=channel_mult,
        num_classes=num_classes,
        use_checkpoint=use_checkpoint,
        num_heads=num_heads,
        num_heads_upsample=num_heads_upsample,
        use_scale_shift_norm=use_scale_shift_norm,
    )
126
127
128def sr_model_and_diffusion_defaults():

Callers 1

Calls 1

UNetModel · Class · 0.85

Tested by

no test coverage detected