MCPcopy
hub / github.com/hanzhanggit/StackGAN / train

Method train

stageII/trainer.py:459–535  ·  view source on GitHub ↗
(self)

Source retrieved from the content-addressed store (hash-verified)

457 return log_vals
458
459 def train(self):
460 config = tf.ConfigProto(allow_soft_placement=True)
461 with tf.Session(config=config) as sess:
462 with tf.device("/gpu:%d" % cfg.GPU_ID):
463 counter = self.build_model(sess)
464 saver = tf.train.Saver(tf.all_variables(),
465 keep_checkpoint_every_n_hours=5)
466
467 # summary_op = tf.merge_all_summaries()
468 summary_writer = tf.train.SummaryWriter(self.log_dir,
469 sess.graph)
470
471 if cfg.TRAIN.FINETUNE_LR:
472 keys = ["hr_d_loss", "hr_g_loss", "d_loss", "g_loss"]
473 else:
474 keys = ["d_loss", "g_loss"]
475 log_vars = []
476 log_keys = []
477 for k, v in self.log_vars:
478 if k in keys:
479 log_vars.append(v)
480 log_keys.append(k)
481 generator_lr = cfg.TRAIN.GENERATOR_LR
482 discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
483 lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
484 number_example = self.dataset.train._num_examples
485 updates_per_epoch = int(number_example / self.batch_size)
486 # int((counter + lr_decay_step/2) / lr_decay_step)
487 decay_start = cfg.TRAIN.PRETRAINED_EPOCH
488 epoch_start = int(counter / updates_per_epoch)
489 for epoch in range(epoch_start, self.max_epoch):
490 widgets = ["epoch #%d|" % epoch,
491 Percentage(), Bar(), ETA()]
492 pbar = ProgressBar(maxval=updates_per_epoch,
493 widgets=widgets)
494 pbar.start()
495
496 if epoch % lr_decay_step == 0 and epoch > decay_start:
497 generator_lr *= 0.5
498 discriminator_lr *= 0.5
499
500 all_log_vals = []
501 for i in range(updates_per_epoch):
502 pbar.update(i)
503 log_vals = self.train_one_step(generator_lr,
504 discriminator_lr,
505 counter, summary_writer,
506 log_vars, sess)
507 all_log_vals.append(log_vals)
508 # save checkpoint
509 counter += 1
510 if counter % self.snapshot_interval == 0:
511 snapshot_path = "%s/%s_%s.ckpt" %\
512 (self.checkpoint_dir,
513 self.exp_name,
514 str(counter))
515 fn = saver.save(sess, snapshot_path)
516 print("Model saved in file: %s" % fn)

Callers (1)

run_exp.py · File · 0.45

Calls (3)

build_model · Method · 0.95
train_one_step · Method · 0.95
epoch_sum_images · Method · 0.95

Tested by

No test coverage detected.