(self)
| 457 | return log_vals |
| 458 | |
| 459 | def train(self): |
| 460 | config = tf.ConfigProto(allow_soft_placement=True) |
| 461 | with tf.Session(config=config) as sess: |
| 462 | with tf.device("/gpu:%d" % cfg.GPU_ID): |
| 463 | counter = self.build_model(sess) |
| 464 | saver = tf.train.Saver(tf.all_variables(), |
| 465 | keep_checkpoint_every_n_hours=5) |
| 466 | |
| 467 | # summary_op = tf.merge_all_summaries() |
| 468 | summary_writer = tf.train.SummaryWriter(self.log_dir, |
| 469 | sess.graph) |
| 470 | |
| 471 | if cfg.TRAIN.FINETUNE_LR: |
| 472 | keys = ["hr_d_loss", "hr_g_loss", "d_loss", "g_loss"] |
| 473 | else: |
| 474 | keys = ["d_loss", "g_loss"] |
| 475 | log_vars = [] |
| 476 | log_keys = [] |
| 477 | for k, v in self.log_vars: |
| 478 | if k in keys: |
| 479 | log_vars.append(v) |
| 480 | log_keys.append(k) |
| 481 | generator_lr = cfg.TRAIN.GENERATOR_LR |
| 482 | discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR |
| 483 | lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH |
| 484 | number_example = self.dataset.train._num_examples |
| 485 | updates_per_epoch = int(number_example / self.batch_size) |
| 486 | # int((counter + lr_decay_step/2) / lr_decay_step) |
| 487 | decay_start = cfg.TRAIN.PRETRAINED_EPOCH |
| 488 | epoch_start = int(counter / updates_per_epoch) |
| 489 | for epoch in range(epoch_start, self.max_epoch): |
| 490 | widgets = ["epoch #%d|" % epoch, |
| 491 | Percentage(), Bar(), ETA()] |
| 492 | pbar = ProgressBar(maxval=updates_per_epoch, |
| 493 | widgets=widgets) |
| 494 | pbar.start() |
| 495 | |
| 496 | if epoch % lr_decay_step == 0 and epoch > decay_start: |
| 497 | generator_lr *= 0.5 |
| 498 | discriminator_lr *= 0.5 |
| 499 | |
| 500 | all_log_vals = [] |
| 501 | for i in range(updates_per_epoch): |
| 502 | pbar.update(i) |
| 503 | log_vals = self.train_one_step(generator_lr, |
| 504 | discriminator_lr, |
| 505 | counter, summary_writer, |
| 506 | log_vars, sess) |
| 507 | all_log_vals.append(log_vals) |
| 508 | # save checkpoint |
| 509 | counter += 1 |
| 510 | if counter % self.snapshot_interval == 0: |
| 511 | snapshot_path = "%s/%s_%s.ckpt" %\ |
| 512 | (self.checkpoint_dir, |
| 513 | self.exp_name, |
| 514 | str(counter)) |
| 515 | fn = saver.save(sess, snapshot_path) |
| 516 | print("Model saved in file: %s" % fn) |
no test coverage detected