MCPcopy
hub / github.com/hanzhanggit/StackGAN / train

Method train

stageI/trainer.py:302–394  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

300 return counter
301
302 def train(self):
303 config = tf.ConfigProto(allow_soft_placement=True)
304 with tf.Session(config=config) as sess:
305 with tf.device("/gpu:%d" % cfg.GPU_ID):
306 counter = self.build_model(sess)
307 saver = tf.train.Saver(tf.all_variables(),
308 keep_checkpoint_every_n_hours=2)
309
310 # summary_op = tf.merge_all_summaries()
311 summary_writer = tf.train.SummaryWriter(self.log_dir,
312 sess.graph)
313
314 keys = ["d_loss", "g_loss"]
315 log_vars = []
316 log_keys = []
317 for k, v in self.log_vars:
318 if k in keys:
319 log_vars.append(v)
320 log_keys.append(k)
321 # print(k, v)
322 generator_lr = cfg.TRAIN.GENERATOR_LR
323 discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
324 num_embedding = cfg.TRAIN.NUM_EMBEDDING
325 lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
326 number_example = self.dataset.train._num_examples
327 updates_per_epoch = int(number_example / self.batch_size)
328 epoch_start = int(counter / updates_per_epoch)
329 for epoch in range(epoch_start, self.max_epoch):
330 widgets = ["epoch #%d|" % epoch,
331 Percentage(), Bar(), ETA()]
332 pbar = ProgressBar(maxval=updates_per_epoch,
333 widgets=widgets)
334 pbar.start()
335
336 if epoch % lr_decay_step == 0 and epoch != 0:
337 generator_lr *= 0.5
338 discriminator_lr *= 0.5
339
340 all_log_vals = []
341 for i in range(updates_per_epoch):
342 pbar.update(i)
343 # training d
344 images, wrong_images, embeddings, _, _ =\
345 self.dataset.train.next_batch(self.batch_size,
346 num_embedding)
347 feed_dict = {self.images: images,
348 self.wrong_images: wrong_images,
349 self.embeddings: embeddings,
350 self.generator_lr: generator_lr,
351 self.discriminator_lr: discriminator_lr
352 }
353 # train d
354 feed_out = [self.discriminator_trainer,
355 self.d_sum,
356 self.hist_sum,
357 log_vars]
358 _, d_sum, hist_sum, log_vals = sess.run(feed_out,
359 feed_dict)

Callers 1

run_exp.py · File · 0.45

Calls 3

build_model · Method · 0.95
epoch_sum_images · Method · 0.95
next_batch · Method · 0.80

Tested by

no test coverage detected