()
| 71 | return train_ds |
| 72 | |
| 73 | def train(): |
| 74 | G = get_G((batch_size, 96, 96, 3)) |
| 75 | D = get_D((batch_size, 384, 384, 3)) |
| 76 | VGG = tl.models.vgg19(pretrained=True, end_with='pool4', mode='static') |
| 77 | |
| 78 | lr_v = tf.Variable(lr_init) |
| 79 | g_optimizer_init = tf.optimizers.Adam(lr_v, beta_1=beta1) |
| 80 | g_optimizer = tf.optimizers.Adam(lr_v, beta_1=beta1) |
| 81 | d_optimizer = tf.optimizers.Adam(lr_v, beta_1=beta1) |
| 82 | |
| 83 | G.train() |
| 84 | D.train() |
| 85 | VGG.train() |
| 86 | |
| 87 | train_ds = get_train_data() |
| 88 | |
| 89 | ## initialize learning (G) |
| 90 | n_step_epoch = round(n_epoch_init // batch_size) |
| 91 | for epoch in range(n_epoch_init): |
| 92 | for step, (lr_patchs, hr_patchs) in enumerate(train_ds): |
| 93 | if lr_patchs.shape[0] != batch_size: # if the remaining data in this epoch < batch_size |
| 94 | break |
| 95 | step_time = time.time() |
| 96 | with tf.GradientTape() as tape: |
| 97 | fake_hr_patchs = G(lr_patchs) |
| 98 | mse_loss = tl.cost.mean_squared_error(fake_hr_patchs, hr_patchs, is_mean=True) |
| 99 | grad = tape.gradient(mse_loss, G.trainable_weights) |
| 100 | g_optimizer_init.apply_gradients(zip(grad, G.trainable_weights)) |
| 101 | print("Epoch: [{}/{}] step: [{}/{}] time: {:.3f}s, mse: {:.3f} ".format( |
| 102 | epoch, n_epoch_init, step, n_step_epoch, time.time() - step_time, mse_loss)) |
| 103 | if (epoch != 0) and (epoch % 10 == 0): |
| 104 | tl.vis.save_images(fake_hr_patchs.numpy(), [2, 4], os.path.join(save_dir, 'train_g_init_{}.png'.format(epoch))) |
| 105 | |
| 106 | ## adversarial learning (G, D) |
| 107 | n_step_epoch = round(n_epoch // batch_size) |
| 108 | for epoch in range(n_epoch): |
| 109 | for step, (lr_patchs, hr_patchs) in enumerate(train_ds): |
| 110 | if lr_patchs.shape[0] != batch_size: # if the remaining data in this epoch < batch_size |
| 111 | break |
| 112 | step_time = time.time() |
| 113 | with tf.GradientTape(persistent=True) as tape: |
| 114 | fake_patchs = G(lr_patchs) |
| 115 | logits_fake = D(fake_patchs) |
| 116 | logits_real = D(hr_patchs) |
| 117 | feature_fake = VGG((fake_patchs+1)/2.) # the pre-trained VGG uses the input range of [0, 1] |
| 118 | feature_real = VGG((hr_patchs+1)/2.) |
| 119 | d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real)) |
| 120 | d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake)) |
| 121 | d_loss = d_loss1 + d_loss2 |
| 122 | g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(logits_fake, tf.ones_like(logits_fake)) |
| 123 | mse_loss = tl.cost.mean_squared_error(fake_patchs, hr_patchs, is_mean=True) |
| 124 | vgg_loss = 2e-6 * tl.cost.mean_squared_error(feature_fake, feature_real, is_mean=True) |
| 125 | g_loss = mse_loss + vgg_loss + g_gan_loss |
| 126 | grad = tape.gradient(g_loss, G.trainable_weights) |
| 127 | g_optimizer.apply_gradients(zip(grad, G.trainable_weights)) |
| 128 | grad = tape.gradient(d_loss, D.trainable_weights) |
| 129 | d_optimizer.apply_gradients(zip(grad, D.trainable_weights)) |
| 130 | print("Epoch: [{}/{}] step: [{}/{}] time: {:.3f}s, g_loss(mse:{:.3f}, vgg:{:.3f}, adv:{:.3f}) d_loss: {:.3f}".format( |
no test coverage detected