MCPcopy
hub / github.com/hojonathanho/diffusion / resnet_block

Function resnet_block

diffusion_tf/models/unet.py:37–63  ·  view source on GitHub ↗
(x, *, temb, name, out_ch=None, conv_shortcut=False, dropout)

Source from the content-addressed store, hash-verified

35
36
def resnet_block(x, *, temb, name, out_ch=None, conv_shortcut=False, dropout):
    """Residual block conditioned on an embedding tensor.

    Computes ``shortcut(x) + h`` where the residual branch ``h`` is:
    norm1 -> nonlinearity -> conv1 -> (+ projected temb) -> norm2 ->
    nonlinearity -> dropout -> conv2 (conv2 built with ``init_scale=0.``).

    Args:
      x: rank-4 input tensor (the 4-way shape unpacking enforces rank 4).
        The last axis is treated as channels. Presumably NHWC: the embedding
        projection is broadcast over axes 1 and 2. TODO confirm against caller.
      temb: embedding tensor. It goes through ``nonlinearity`` and the
        ``temb_proj`` dense layer, then is indexed ``[:, None, None, :]``, so it
        is assumed to be rank 2 (batch, features). TODO confirm against caller.
        It is also passed to both ``normalize`` calls.
      name: variable scope under which all of this block's variables are created.
      out_ch: number of output channels; defaults to the input channel count.
      conv_shortcut: if the channel count changes, project the shortcut with
        ``nn.conv2d`` instead of ``nn.nin``. Ignored when ``out_ch`` equals the
        input channel count, because the shortcut is then the identity.
      dropout: dropout rate, passed as ``tf.nn.dropout(h, rate=dropout)``.

    Returns:
      The tensor ``x_shortcut + h``, with ``out_ch`` channels.

    Raises:
      ValueError: if the (possibly projected) shortcut and the residual branch
        do not have equal static shapes.
    """
    # Only the channel count is needed; the 4-way unpack still rejects
    # inputs whose rank is not 4.
    _, _, _, C = x.shape
    if out_ch is None:
        out_ch = C

    with tf.variable_scope(name):
        h = nonlinearity(normalize(x, temb=temb, name='norm1'))
        h = nn.conv2d(h, name='conv1', num_units=out_ch)

        # Add in the embedding projection, broadcast over axes 1 and 2.
        h += nn.dense(nonlinearity(temb), name='temb_proj', num_units=out_ch)[:, None, None, :]

        h = nonlinearity(normalize(h, temb=temb, name='norm2'))
        h = tf.nn.dropout(h, rate=dropout)
        # NOTE(review): init_scale=0. presumably makes the residual branch start
        # near zero, so the block initially acts like its shortcut. Verify
        # against nn.conv2d.
        h = nn.conv2d(h, name='conv2', num_units=out_ch, init_scale=0.)

        # Project the shortcut only when the channel count changes; otherwise
        # x passes through unchanged (and conv_shortcut has no effect).
        if C != out_ch:
            if conv_shortcut:
                x = nn.conv2d(x, name='conv_shortcut', num_units=out_ch)
            else:
                x = nn.nin(x, name='nin_shortcut', num_units=out_ch)

        # Explicit check rather than `assert`, which is removed under
        # `python -O`. It uses `==` so it tests exactly the condition the
        # former assert did.
        shapes_match = x.shape == h.shape
        if not shapes_match:
            raise ValueError(
                'resnet_block {!r}: shortcut shape {} does not match residual '
                'shape {}'.format(name, x.shape, h.shape))
        print('{}: x={} temb={}'.format(tf.get_default_graph().get_name_scope(), x.shape, temb.shape))
        return x + h
64
65
66def attn_block(x, *, name, temb):

Callers 1

model (Function) · 0.85

Calls 2

nonlinearity (Function) · 0.85
normalize (Function) · 0.85

Tested by

no test coverage detected