MCPcopy
hub / github.com/openai/improved-diffusion / __init__

Method __init__

improved_diffusion/unet.py:301–437  ·  view source on GitHub ↗
(
        self,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        num_heads=1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
    )

Source from the content-addressed store, hash-verified

299 """
300
301 def __init__(
302 self,
303 in_channels,
304 model_channels,
305 out_channels,
306 num_res_blocks,
307 attention_resolutions,
308 dropout=0,
309 channel_mult=(1, 2, 4, 8),
310 conv_resample=True,
311 dims=2,
312 num_classes=None,
313 use_checkpoint=False,
314 num_heads=1,
315 num_heads_upsample=-1,
316 use_scale_shift_norm=False,
317 ):
318 super().__init__()
319
320 if num_heads_upsample == -1:
321 num_heads_upsample = num_heads
322
323 self.in_channels = in_channels
324 self.model_channels = model_channels
325 self.out_channels = out_channels
326 self.num_res_blocks = num_res_blocks
327 self.attention_resolutions = attention_resolutions
328 self.dropout = dropout
329 self.channel_mult = channel_mult
330 self.conv_resample = conv_resample
331 self.num_classes = num_classes
332 self.use_checkpoint = use_checkpoint
333 self.num_heads = num_heads
334 self.num_heads_upsample = num_heads_upsample
335
336 time_embed_dim = model_channels * 4
337 self.time_embed = nn.Sequential(
338 linear(model_channels, time_embed_dim),
339 SiLU(),
340 linear(time_embed_dim, time_embed_dim),
341 )
342
343 if self.num_classes is not None:
344 self.label_emb = nn.Embedding(num_classes, time_embed_dim)
345
346 self.input_blocks = nn.ModuleList(
347 [
348 TimestepEmbedSequential(
349 conv_nd(dims, in_channels, model_channels, 3, padding=1)
350 )
351 ]
352 )
353 input_block_chans = [model_channels]
354 ch = model_channels
355 ds = 1
356 for level, mult in enumerate(channel_mult):
357 for _ in range(num_res_blocks):
358 layers = [

Callers

nothing calls this directly

Calls 11

linear · Function · 0.85
SiLU · Class · 0.85
conv_nd · Function · 0.85
ResBlock · Class · 0.85
AttentionBlock · Class · 0.85
Downsample · Class · 0.85
Upsample · Class · 0.85
normalization · Function · 0.85
zero_module · Function · 0.85
__init__ · Method · 0.45

Tested by

no test coverage detected