(x, *, temb, name, out_ch=None, conv_shortcut=False, dropout)
| 35 | |
| 36 | |
def resnet_block(x, *, temb, name, out_ch=None, conv_shortcut=False, dropout):
    """Residual block whose inner branch is conditioned on a timestep embedding.

    Residual branch, built inside ``tf.variable_scope(name)`` in this order:
    ``normalize('norm1')`` -> ``nonlinearity`` -> ``nn.conv2d('conv1')`` ->
    add the ``nn.dense('temb_proj')`` projection of ``nonlinearity(temb)`` ->
    ``normalize('norm2')`` -> ``nonlinearity`` -> ``tf.nn.dropout`` ->
    ``nn.conv2d('conv2', init_scale=0.)``.

    Skip path: ``x`` itself, unless the requested output channel count differs
    from the input's last axis. In that case ``x`` is first projected with
    ``nn.conv2d('conv_shortcut')`` (when ``conv_shortcut`` is true) or with
    ``nn.nin('nin_shortcut')`` (otherwise).

    Args:
        x: Input tensor. Its shape is unpacked into exactly four entries, and
            the last one is treated as the channel count. Presumably (B, H, W, C)
            channels-last — TODO confirm against callers.
        temb: Timestep embedding. It is passed to ``normalize``, and its
            ``temb_proj`` projection is indexed ``[:, None, None, :]`` before
            being added. Presumably (B, temb_dim) — confirm.
        name: Variable-scope name wrapping every layer created here.
        out_ch: Output channel count. Defaults to the input channel count.
        conv_shortcut: Selects ``nn.conv2d`` over ``nn.nin`` for the skip
            projection. It only matters when the channel counts differ.
        dropout: Value forwarded as ``rate`` to ``tf.nn.dropout``.

    Returns:
        ``skip + residual``, where ``skip`` is ``x`` or its channel projection.

    Raises:
        AssertionError: If the skip path and residual branch shapes differ.

    Side effects:
        Prints the current name scope together with the skip-path and ``temb``
        shapes.
    """
    # Unpacking into four names also rejects inputs whose shape is not rank 4.
    _batch, _height, _width, in_ch = x.shape
    width = in_ch if out_ch is None else out_ch

    with tf.variable_scope(name):
        residual = nonlinearity(normalize(x, temb=temb, name='norm1'))
        residual = nn.conv2d(residual, name='conv1', num_units=width)

        # Project the timestep embedding to `width` channels, then insert two
        # singleton axes so it broadcasts across the middle dimensions of
        # `residual`.
        temb_bias = nn.dense(nonlinearity(temb), name='temb_proj', num_units=width)
        residual = residual + temb_bias[:, None, None, :]

        residual = nonlinearity(normalize(residual, temb=temb, name='norm2'))
        residual = tf.nn.dropout(residual, rate=dropout)
        # NOTE(review): init_scale=0. presumably zero-initialises conv2 so the
        # block starts close to the identity; depends on nn.conv2d — confirm.
        residual = nn.conv2d(residual, name='conv2', num_units=width, init_scale=0.)

        skip = x
        if in_ch != width:
            # The channel counts differ, so project the skip path to match.
            skip = (
                nn.conv2d(x, name='conv_shortcut', num_units=width)
                if conv_shortcut
                else nn.nin(x, name='nin_shortcut', num_units=width)
            )

        assert skip.shape == residual.shape
        print('{}: x={} temb={}'.format(tf.get_default_graph().get_name_scope(), skip.shape, temb.shape))
        return skip + residual
| 64 | |
| 65 | |
| 66 | def attn_block(x, *, name, temb): |
no test coverage detected