| 7 | from tensorlayer.models import Model |
| 8 | |
def get_G(input_shape, n_res_blocks=16):
    """Build the residual super-resolution generator as a static TensorLayer ``Model``.

    Layers are created in this order. The order is kept stable so that
    automatically generated layer names, and any weights saved against them,
    stay the same:

    1. 3x3 conv, 64 filters, ReLU.
    2. ``n_res_blocks`` residual blocks, each
       conv -> BatchNorm(ReLU) -> conv -> BatchNorm, added to the block input.
    3. 3x3 conv + BatchNorm, added to the output of step 1 (long skip connection).
    4. Two upsampling stages, each a 3x3 conv with 256 filters followed by
       ``SubpixelConv2d(scale=2)`` with ReLU. Presumably this gives a 4x
       spatial upscale overall; confirm against the SubpixelConv2d docs.
    5. 1x1 conv to 3 channels with tanh.

    Args:
        input_shape: Shape passed to ``Input``. Presumably
            ``[batch, height, width, channels]``; TODO confirm against callers.
        n_res_blocks: Number of residual blocks in step 2. Defaults to 16,
            the value previously hard-coded here.

    Returns:
        A ``Model`` named ``"generator"`` that maps the input to the tanh output.

    Raises:
        ValueError: If ``n_res_blocks`` is negative.
    """
    if n_res_blocks < 0:
        raise ValueError(f"n_res_blocks must be non-negative, got {n_res_blocks}")

    w_init = tf.random_normal_initializer(stddev=0.02)
    g_init = tf.random_normal_initializer(1.0, 0.02)

    nin = Input(input_shape)
    n = Conv2d(64, (3, 3), (1, 1), act=tf.nn.relu, padding='SAME', W_init=w_init)(nin)
    skip = n  # input of the long skip connection added after the residual blocks

    # Residual blocks. The convs use b_init=None because each is followed by
    # a BatchNorm, which makes a separate conv bias redundant.
    for _ in range(n_res_blocks):
        res = Conv2d(64, (3, 3), (1, 1), padding='SAME', W_init=w_init, b_init=None)(n)
        res = BatchNorm(act=tf.nn.relu, gamma_init=g_init)(res)
        res = Conv2d(64, (3, 3), (1, 1), padding='SAME', W_init=w_init, b_init=None)(res)
        res = BatchNorm(gamma_init=g_init)(res)
        n = Elementwise(tf.add)([n, res])

    n = Conv2d(64, (3, 3), (1, 1), padding='SAME', W_init=w_init, b_init=None)(n)
    n = BatchNorm(gamma_init=g_init)(n)
    n = Elementwise(tf.add)([n, skip])
    # Residual blocks end

    # Two identical upsampling stages. With n_out_channels=None, SubpixelConv2d
    # derives its output channel count from its input.
    for _ in range(2):
        n = Conv2d(256, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init)(n)
        n = SubpixelConv2d(scale=2, n_out_channels=None, act=tf.nn.relu)(n)

    out = Conv2d(3, (1, 1), (1, 1), act=tf.nn.tanh, padding='SAME', W_init=w_init)(n)
    return Model(inputs=nin, outputs=out, name="generator")
| 40 | |
| 41 | def get_D(input_shape): |
| 42 | w_init = tf.random_normal_initializer(stddev=0.02) |