MCPcopy Index your code
hub / github.com/tensorlayer/SRGAN / train

Function train

train.py:73–143  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

71 return train_ds
72
73def train():
74 G = get_G((batch_size, 96, 96, 3))
75 D = get_D((batch_size, 384, 384, 3))
76 VGG = tl.models.vgg19(pretrained=True, end_with='pool4', mode='static')
77
78 lr_v = tf.Variable(lr_init)
79 g_optimizer_init = tf.optimizers.Adam(lr_v, beta_1=beta1)
80 g_optimizer = tf.optimizers.Adam(lr_v, beta_1=beta1)
81 d_optimizer = tf.optimizers.Adam(lr_v, beta_1=beta1)
82
83 G.train()
84 D.train()
85 VGG.train()
86
87 train_ds = get_train_data()
88
89 ## initialize learning (G)
90 n_step_epoch = round(n_epoch_init // batch_size)
91 for epoch in range(n_epoch_init):
92 for step, (lr_patchs, hr_patchs) in enumerate(train_ds):
93 if lr_patchs.shape[0] != batch_size: # if the remaining data in this epoch < batch_size
94 break
95 step_time = time.time()
96 with tf.GradientTape() as tape:
97 fake_hr_patchs = G(lr_patchs)
98 mse_loss = tl.cost.mean_squared_error(fake_hr_patchs, hr_patchs, is_mean=True)
99 grad = tape.gradient(mse_loss, G.trainable_weights)
100 g_optimizer_init.apply_gradients(zip(grad, G.trainable_weights))
101 print("Epoch: [{}/{}] step: [{}/{}] time: {:.3f}s, mse: {:.3f} ".format(
102 epoch, n_epoch_init, step, n_step_epoch, time.time() - step_time, mse_loss))
103 if (epoch != 0) and (epoch % 10 == 0):
104 tl.vis.save_images(fake_hr_patchs.numpy(), [2, 4], os.path.join(save_dir, 'train_g_init_{}.png'.format(epoch)))
105
106 ## adversarial learning (G, D)
107 n_step_epoch = round(n_epoch // batch_size)
108 for epoch in range(n_epoch):
109 for step, (lr_patchs, hr_patchs) in enumerate(train_ds):
110 if lr_patchs.shape[0] != batch_size: # if the remaining data in this epoch < batch_size
111 break
112 step_time = time.time()
113 with tf.GradientTape(persistent=True) as tape:
114 fake_patchs = G(lr_patchs)
115 logits_fake = D(fake_patchs)
116 logits_real = D(hr_patchs)
117 feature_fake = VGG((fake_patchs+1)/2.) # the pre-trained VGG uses the input range of [0, 1]
118 feature_real = VGG((hr_patchs+1)/2.)
119 d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real))
120 d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake))
121 d_loss = d_loss1 + d_loss2
122 g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(logits_fake, tf.ones_like(logits_fake))
123 mse_loss = tl.cost.mean_squared_error(fake_patchs, hr_patchs, is_mean=True)
124 vgg_loss = 2e-6 * tl.cost.mean_squared_error(feature_fake, feature_real, is_mean=True)
125 g_loss = mse_loss + vgg_loss + g_gan_loss
126 grad = tape.gradient(g_loss, G.trainable_weights)
127 g_optimizer.apply_gradients(zip(grad, G.trainable_weights))
128 grad = tape.gradient(d_loss, D.trainable_weights)
129 d_optimizer.apply_gradients(zip(grad, D.trainable_weights))
130 print("Epoch: [{}/{}] step: [{}/{}] time: {:.3f}s, g_loss(mse:{:.3f}, vgg:{:.3f}, adv:{:.3f}) d_loss: {:.3f}".format(

Callers 1

train.py · File · 0.85

Calls 3

get_G · Function · 0.90
get_D · Function · 0.90
get_train_data · Function · 0.85

Tested by

no test coverage detected