MCPcopy
hub / github.com/openai/guided-diffusion / create_model

Function create_model

guided_diffusion/script_util.py:130–184  ·  view source on GitHub ↗
(
    image_size,
    num_channels,
    num_res_blocks,
    channel_mult="",
    learn_sigma=False,
    class_cond=False,
    use_checkpoint=False,
    attention_resolutions="16",
    num_heads=1,
    num_head_channels=-1,
    num_heads_upsample=-1,
    use_scale_shift_norm=False,
    dropout=0,
    resblock_updown=False,
    use_fp16=False,
    use_new_attention_order=False,
)

Source from the content-addressed store, hash-verified

128
129
def _parse_channel_mult(spec):
    """Convert a comma-separated multiplier string into a tuple of numbers."""
    multipliers = []
    for item in spec.split(","):
        # Whole numbers stay ints, exactly as before; anything else is tried
        # as a float so fractional multipliers such as "0.5" are accepted.
        try:
            multipliers.append(int(item))
        except ValueError:
            multipliers.append(float(item))
    return tuple(multipliers)


def create_model(
    image_size,
    num_channels,
    num_res_blocks,
    channel_mult="",
    learn_sigma=False,
    class_cond=False,
    use_checkpoint=False,
    attention_resolutions="16",
    num_heads=1,
    num_head_channels=-1,
    num_heads_upsample=-1,
    use_scale_shift_norm=False,
    dropout=0,
    resblock_updown=False,
    use_fp16=False,
    use_new_attention_order=False,
):
    """Build a ``UNetModel`` from flat, command-line style hyperparameters.

    Args:
        image_size: Image resolution. Selects the default ``channel_mult``
            and is divided by each attention resolution.
        num_channels: Passed to ``UNetModel`` as ``model_channels``.
        num_res_blocks: Forwarded unchanged.
        channel_mult: Comma-separated multipliers, e.g. ``"1,2,3,4"``.
            Entries may be fractional (``"0.5,1,1,2,2,4,4"``). An empty
            string selects a built-in tuple for sizes 512, 256, 128 and 64.
        learn_sigma: When true the model gets 6 output channels instead of 3.
        class_cond: When true ``num_classes`` is ``NUM_CLASSES``, else ``None``.
        attention_resolutions: Comma-separated integers; each is converted to
            ``image_size // resolution`` before being handed to the model.
        use_checkpoint, num_heads, num_head_channels, num_heads_upsample,
        use_scale_shift_norm, dropout, resblock_updown, use_fp16,
        use_new_attention_order: Forwarded unchanged to ``UNetModel``.

    Returns:
        The constructed ``UNetModel`` (always built with ``in_channels=3``).

    Raises:
        ValueError: If ``channel_mult`` is empty and ``image_size`` has no
            built-in default, or if an entry of ``channel_mult`` or
            ``attention_resolutions`` cannot be parsed as a number.
    """
    if channel_mult == "":
        if image_size == 512:
            channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
        elif image_size == 256:
            channel_mult = (1, 1, 2, 2, 4, 4)
        elif image_size == 128:
            channel_mult = (1, 1, 2, 3, 4)
        elif image_size == 64:
            channel_mult = (1, 2, 3, 4)
        else:
            raise ValueError(f"unsupported image size: {image_size}")
    else:
        # The 512px default above contains 0.5, so the string form has to
        # accept non-integer multipliers too.
        channel_mult = _parse_channel_mult(channel_mult)

    # The option lists resolutions; the model receives image_size // resolution.
    attention_ds = tuple(
        image_size // int(res) for res in attention_resolutions.split(",")
    )

    return UNetModel(
        image_size=image_size,
        in_channels=3,
        model_channels=num_channels,
        out_channels=(3 if not learn_sigma else 6),
        num_res_blocks=num_res_blocks,
        attention_resolutions=attention_ds,
        dropout=dropout,
        channel_mult=channel_mult,
        num_classes=(NUM_CLASSES if class_cond else None),
        use_checkpoint=use_checkpoint,
        use_fp16=use_fp16,
        num_heads=num_heads,
        num_head_channels=num_head_channels,
        num_heads_upsample=num_heads_upsample,
        use_scale_shift_norm=use_scale_shift_norm,
        resblock_updown=resblock_updown,
        use_new_attention_order=use_new_attention_order,
    )
185
186
187def create_classifier_and_diffusion(

Callers 1

Calls 1

UNetModel · Class · 0.85

Tested by

no test coverage detected