(x, *, t, y, name, num_classes, reuse=tf.AUTO_REUSE, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
attn_resolutions, dropout=0., resamp_with_conv=True)
| 85 | |
| 86 | |
def model(x, *, t, y, name, num_classes, reuse=tf.AUTO_REUSE, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
          attn_resolutions, dropout=0., resamp_with_conv=True):
    """Build a timestep-conditioned U-Net and return its output tensor.

    The network is a downsampling path, a middle section (resnet / attention /
    resnet), and an upsampling path. Every activation produced on the way down
    is kept on a stack and concatenated (channel axis) back in on the way up.

    Args:
      x: float32 tensor of shape (B, S, S, C); spatial dims 1 and 2 must match.
      t: int32 or int64 tensor of timesteps. It is embedded with
        `nn.get_timestep_embedding` and fed to every block. Presumably of
        shape (B,), since the embedding is asserted to be (B, ch * 4); confirm.
      y: class labels. Conditioning is not supported, so this must be None.
      name: variable scope wrapping all variables created here.
      num_classes: must be 1 (class conditioning is not supported).
      reuse: `reuse` argument forwarded to `tf.variable_scope`.
      ch: base channel count. Used for `conv_in` and the timestep embedding;
        the embedding MLP width is `ch * 4`.
      out_ch: number of channels produced by the final `conv_out`.
      ch_mult: per-level channel multipliers. Level `i` blocks output
        `ch * ch_mult[i]` channels, and `len(ch_mult)` is the number of levels.
      num_res_blocks: resnet blocks per level on the down path. The up path
        uses `num_res_blocks + 1` per level.
      attn_resolutions: spatial sizes (`h.shape[1]`) after whose resnet blocks
        an attention block is inserted. The middle section always has attention.
      dropout: dropout rate forwarded to every `resnet_block`.
      resamp_with_conv: forwarded as `with_conv` to `downsample` / `upsample`.

    Returns:
      Tensor of shape `x.shape[:3] + [out_ch]`.

    Raises:
      AssertionError: on unsupported arguments (num_classes != 1 or y given),
        wrong dtypes, non-square spatial input, or an internal shape mismatch.
    """
    B, S, _, _ = x.shape
    assert x.dtype == tf.float32 and x.shape[2] == S
    assert t.dtype in [tf.int32, tf.int64]
    num_resolutions = len(ch_mult)

    assert num_classes == 1 and y is None, 'not supported'
    del y

    with tf.variable_scope(name, reuse=reuse):
        # Timestep embedding: sinusoidal-style features -> 2-layer MLP of width ch * 4.
        with tf.variable_scope('temb'):
            temb = nn.get_timestep_embedding(t, ch)
            temb = nn.dense(temb, name='dense0', num_units=ch * 4)
            temb = nn.dense(nonlinearity(temb), name='dense1', num_units=ch * 4)
            assert temb.shape == [B, ch * 4]

        # Downsampling. `hs` is a stack of skip activations consumed by the up path.
        hs = [nn.conv2d(x, name='conv_in', num_units=ch)]
        for i_level in range(num_resolutions):
            with tf.variable_scope('down_{}'.format(i_level)):
                # Residual blocks for this resolution
                for i_block in range(num_res_blocks):
                    h = resnet_block(
                        hs[-1], name='block_{}'.format(i_block), temb=temb, out_ch=ch * ch_mult[i_level],
                        dropout=dropout)
                    if h.shape[1] in attn_resolutions:
                        h = attn_block(h, name='attn_{}'.format(i_block), temb=temb)
                    hs.append(h)
                # Downsample between levels (not after the last one).
                if i_level != num_resolutions - 1:
                    hs.append(downsample(hs[-1], name='downsample', with_conv=resamp_with_conv))

        # Middle
        with tf.variable_scope('mid'):
            h = hs[-1]
            h = resnet_block(h, temb=temb, name='block_1', dropout=dropout)
            # Plain literal name: the former `.format(i_block)` was a no-op that raised
            # UnboundLocalError when num_res_blocks == 0 (loop variable never bound).
            h = attn_block(h, name='attn_1', temb=temb)
            h = resnet_block(h, temb=temb, name='block_2', dropout=dropout)

        # Upsampling
        for i_level in reversed(range(num_resolutions)):
            with tf.variable_scope('up_{}'.format(i_level)):
                # One extra block per level so every pushed skip activation is consumed.
                for i_block in range(num_res_blocks + 1):
                    h = resnet_block(
                        tf.concat([h, hs.pop()], axis=-1), name='block_{}'.format(i_block),
                        temb=temb, out_ch=ch * ch_mult[i_level], dropout=dropout)
                    if h.shape[1] in attn_resolutions:
                        h = attn_block(h, name='attn_{}'.format(i_block), temb=temb)
                # Upsample between levels (not after level 0).
                if i_level != 0:
                    h = upsample(h, name='upsample', with_conv=resamp_with_conv)
        # Every skip activation pushed on the way down must have been used.
        assert not hs

        # End
        h = nonlinearity(normalize(h, temb=temb, name='norm_out'))
        # init_scale=0. presumably zero-initialises the output conv so the initial
        # prediction is zero -- confirm against nn.conv2d.
        h = nn.conv2d(h, name='conv_out', num_units=out_ch, init_scale=0.)
        assert h.shape == x.shape[:3] + [out_ch]
        return h
nothing calls this directly
no test coverage detected