| 377 | return img_summary, hr_img_summary |
| 378 | |
def build_model(self, sess):
    """Build the training graph, initialize variables, and optionally restore.

    Calls ``self.init_opt()`` to construct the optimization ops and runs the
    variable initializer in ``sess``. When ``self.model_path`` is non-empty,
    it then restores every trainable variable whose name starts with
    ``'g_'`` or ``'d_'`` from that checkpoint.

    Args:
        sess: The TensorFlow session used to run the initializer and the
            restore.

    Returns:
        int: The iteration counter to resume from. This is 0 for a fresh
        model. Otherwise it is the last ``_<digits>`` group in the checkpoint
        file name (e.g. ``model_82000.ckpt`` -> 82000).

    Raises:
        ValueError: If ``self.model_path`` is set but its file name contains
            no ``_<digits>`` group.
    """
    import os
    import re

    self.init_opt()

    # tf.initialize_all_variables is deprecated in favour of
    # tf.global_variables_initializer; use the newer name when available and
    # fall back for older TensorFlow releases.
    if hasattr(tf, 'global_variables_initializer'):
        init_op = tf.global_variables_initializer()
    else:
        init_op = tf.initialize_all_variables()
    sess.run(init_op)

    # Truthiness covers both "" and None (len(None) would raise TypeError).
    if not self.model_path:
        print("Created model with fresh parameters.")
        return 0

    print("Reading model parameters from %s" % self.model_path)
    # Only generator ('g_') and discriminator ('d_') trainable variables are
    # restored; every other variable keeps its freshly initialized value.
    restore_vars = [var for var in tf.trainable_variables()
                    if var.name.startswith(('g_', 'd_'))]
    saver = tf.train.Saver(restore_vars)
    saver.restore(sess, self.model_path)

    # Recover the step count from the file name. The previous slice between
    # the last '_' and the last '.' dropped the final digit when the name had
    # no extension, and produced '' when a '.' appeared only in the directory.
    steps = re.findall(r'_(\d+)', os.path.basename(self.model_path))
    if not steps:
        raise ValueError(
            "Cannot parse iteration counter from model path %r"
            % self.model_path)
    return int(steps[-1])
| 403 | |
| 404 | def train_one_step(self, generator_lr, |
| 405 | discriminator_lr, |