MCPcopy
hub / github.com/hojonathanho/diffusion / model

Function model

diffusion_tf/models/unet.py:87–145  ·  view source on GitHub ↗
(x, *, t, y, name, num_classes, reuse=tf.AUTO_REUSE, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
          attn_resolutions, dropout=0., resamp_with_conv=True)

Source from the content-addressed store, hash-verified

85
86
def model(x, *, t, y, name, num_classes, reuse=tf.AUTO_REUSE, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
          attn_resolutions, dropout=0., resamp_with_conv=True):
    """Build the U-Net graph under variable scope `name` and return its output.

    Layout, as constructed below:
    - a timestep-embedding MLP;
    - an input conv;
    - `len(ch_mult)` downsampling levels of residual blocks, each optionally
      followed by attention;
    - a middle resnet/attention/resnet stage;
    - mirrored upsampling levels that concatenate the stored skip activations;
    - a final normalize + nonlinearity + conv to `out_ch` channels.

    Args:
      x: float32 4-D tensor whose dims 1 and 2 are equal (asserted below).
        Presumably a batch of square NHWC images -- TODO confirm.
      t: int32/int64 timestep tensor, embedded via
        `nn.get_timestep_embedding(t, ch)`. The `[B, ch * 4]` shape assert
        implies one timestep per batch element.
      y: class labels; must be None (class conditioning is not supported).
      name: variable scope name wrapping every variable of the model.
      num_classes: must be 1 (see `y`).
      reuse: forwarded to `tf.variable_scope`.
      ch: base channel count; the timestep embedding width is `ch * 4`.
      out_ch: number of channels of the returned tensor.
      ch_mult: per-level channel multipliers. Its length is the number of
        resolution levels; level `i` uses `ch * ch_mult[i]` channels.
      num_res_blocks: residual blocks per level on the downsampling path.
        The upsampling path uses one extra per level (see below).
      attn_resolutions: spatial sizes (compared against `h.shape[1]`) at
        which an attention block follows each residual block.
      dropout: forwarded to every `resnet_block`.
      resamp_with_conv: forwarded as `with_conv` to `downsample`/`upsample`.

    Returns:
      Tensor of shape `x.shape[:3] + [out_ch]` (asserted before returning).

    Raises:
      AssertionError: on unsupported arguments or internal shape mismatches.
        These are `assert` statements, so they are skipped under `python -O`.
    """
    B, S, _, _ = x.shape
    assert x.dtype == tf.float32 and x.shape[2] == S
    assert t.dtype in [tf.int32, tf.int64]
    num_resolutions = len(ch_mult)

    assert num_classes == 1 and y is None, 'not supported'
    del y

    with tf.variable_scope(name, reuse=reuse):
        # Timestep embedding
        with tf.variable_scope('temb'):
            temb = nn.get_timestep_embedding(t, ch)
            temb = nn.dense(temb, name='dense0', num_units=ch * 4)
            temb = nn.dense(nonlinearity(temb), name='dense1', num_units=ch * 4)
            assert temb.shape == [B, ch * 4]

        # Downsampling. `hs` is a stack of activations saved for the skip
        # connections consumed (in reverse order) by the upsampling path.
        hs = [nn.conv2d(x, name='conv_in', num_units=ch)]
        for i_level in range(num_resolutions):
            with tf.variable_scope('down_{}'.format(i_level)):
                # Residual blocks for this resolution
                for i_block in range(num_res_blocks):
                    h = resnet_block(
                        hs[-1], name='block_{}'.format(i_block), temb=temb, out_ch=ch * ch_mult[i_level], dropout=dropout)
                    if h.shape[1] in attn_resolutions:
                        h = attn_block(h, name='attn_{}'.format(i_block), temb=temb)
                    hs.append(h)
                # Downsample (every level except the coarsest)
                if i_level != num_resolutions - 1:
                    hs.append(downsample(hs[-1], name='downsample', with_conv=resamp_with_conv))

        # Middle. The attention scope name is a fixed literal; it must not
        # depend on the loop variable left over from the downsampling path.
        with tf.variable_scope('mid'):
            h = hs[-1]
            h = resnet_block(h, temb=temb, name='block_1', dropout=dropout)
            h = attn_block(h, name='attn_1', temb=temb)
            h = resnet_block(h, temb=temb, name='block_2', dropout=dropout)

        # Upsampling
        for i_level in reversed(range(num_resolutions)):
            with tf.variable_scope('up_{}'.format(i_level)):
                # num_res_blocks + 1 blocks per level. Across all levels the
                # down path pushed conv_in, num_res_blocks outputs per level,
                # and one downsample per non-final level:
                # 1 + L*n + (L-1) = L*(n+1) entries, exactly what is popped here.
                for i_block in range(num_res_blocks + 1):
                    h = resnet_block(tf.concat([h, hs.pop()], axis=-1), name='block_{}'.format(i_block),
                                     temb=temb, out_ch=ch * ch_mult[i_level], dropout=dropout)
                    if h.shape[1] in attn_resolutions:
                        h = attn_block(h, name='attn_{}'.format(i_block), temb=temb)
                # Upsample (every level except the finest)
                if i_level != 0:
                    h = upsample(h, name='upsample', with_conv=resamp_with_conv)
        # Every saved skip activation must have been consumed.
        assert not hs

        # End. `init_scale=0.` presumably zero-initializes the output conv;
        # see nn.conv2d to confirm.
        h = nonlinearity(normalize(h, temb=temb, name='norm_out'))
        h = nn.conv2d(h, name='conv_out', num_units=out_ch, init_scale=0.)
        assert h.shape == x.shape[:3] + [out_ch]
        return h

Callers

nothing calls this directly

Calls 6

nonlinearity — Function · 0.85
resnet_block — Function · 0.85
attn_block — Function · 0.85
downsample — Function · 0.85
upsample — Function · 0.85
normalize — Function · 0.85

Tested by

no test coverage detected