(self)
| 300 | return counter |
| 301 | |
| 302 | def train(self): |
| 303 | config = tf.ConfigProto(allow_soft_placement=True) |
| 304 | with tf.Session(config=config) as sess: |
| 305 | with tf.device("/gpu:%d" % cfg.GPU_ID): |
| 306 | counter = self.build_model(sess) |
| 307 | saver = tf.train.Saver(tf.all_variables(), |
| 308 | keep_checkpoint_every_n_hours=2) |
| 309 | |
| 310 | # summary_op = tf.merge_all_summaries() |
| 311 | summary_writer = tf.train.SummaryWriter(self.log_dir, |
| 312 | sess.graph) |
| 313 | |
| 314 | keys = ["d_loss", "g_loss"] |
| 315 | log_vars = [] |
| 316 | log_keys = [] |
| 317 | for k, v in self.log_vars: |
| 318 | if k in keys: |
| 319 | log_vars.append(v) |
| 320 | log_keys.append(k) |
| 321 | # print(k, v) |
| 322 | generator_lr = cfg.TRAIN.GENERATOR_LR |
| 323 | discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR |
| 324 | num_embedding = cfg.TRAIN.NUM_EMBEDDING |
| 325 | lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH |
| 326 | number_example = self.dataset.train._num_examples |
| 327 | updates_per_epoch = int(number_example / self.batch_size) |
| 328 | epoch_start = int(counter / updates_per_epoch) |
| 329 | for epoch in range(epoch_start, self.max_epoch): |
| 330 | widgets = ["epoch #%d|" % epoch, |
| 331 | Percentage(), Bar(), ETA()] |
| 332 | pbar = ProgressBar(maxval=updates_per_epoch, |
| 333 | widgets=widgets) |
| 334 | pbar.start() |
| 335 | |
| 336 | if epoch % lr_decay_step == 0 and epoch != 0: |
| 337 | generator_lr *= 0.5 |
| 338 | discriminator_lr *= 0.5 |
| 339 | |
| 340 | all_log_vals = [] |
| 341 | for i in range(updates_per_epoch): |
| 342 | pbar.update(i) |
| 343 | # training d |
| 344 | images, wrong_images, embeddings, _, _ =\ |
| 345 | self.dataset.train.next_batch(self.batch_size, |
| 346 | num_embedding) |
| 347 | feed_dict = {self.images: images, |
| 348 | self.wrong_images: wrong_images, |
| 349 | self.embeddings: embeddings, |
| 350 | self.generator_lr: generator_lr, |
| 351 | self.discriminator_lr: discriminator_lr |
| 352 | } |
| 353 | # train d |
| 354 | feed_out = [self.discriminator_trainer, |
| 355 | self.d_sum, |
| 356 | self.hist_sum, |
| 357 | log_vars] |
| 358 | _, d_sum, hist_sum, log_vals = sess.run(feed_out, |
| 359 | feed_dict) |
no test coverage detected