(self)
| 89 | return c, cfg.TRAIN.COEFF.KL * kl_loss |
| 90 | |
| 91 | def init_opt(self): |
| | # Build the train-phase graph (generator, losses, trainer, summaries), then a test-phase pass that calls sampler() under the "g_net" variable scope with reuse=True and builds visualization(cfg.TRAIN.NUM_COPY). |
| 92 | self.build_placeholder() |
| 93 | |
| 94 | with pt.defaults_scope(phase=pt.Phase.train): |
| 95 | with tf.variable_scope("g_net"): |
| 96 | # Generator forward pass: context c plus KL term from the embeddings, Gaussian noise z, then fake_images from the generator. Note tf.concat(1, [c, z]) uses the axis-first signature of TensorFlow < 1.0. |
| 97 | c, kl_loss = self.sample_encoded_context(self.embeddings) |
| 98 | z = tf.random_normal([self.batch_size, cfg.Z_DIM]) |
| 99 | self.log_vars.append(("hist_c", c)) |
| 100 | self.log_vars.append(("hist_z", z)) |
| 101 | fake_images = self.model.get_generator(tf.concat(1, [c, z])) |
| 102 | |
| 103 | # Discriminator/generator losses from real (self.images), wrong (self.wrong_images) and fake images; the KL term is then added to the generator loss. NOTE(review): indentation is lost in this listing; confirm this runs in the train phase but outside the "g_net" variable scope. |
| 104 | discriminator_loss, generator_loss =\ |
| 105 | self.compute_losses(self.images, |
| 106 | self.wrong_images, |
| 107 | fake_images, |
| 108 | self.embeddings) |
| 109 | generator_loss += kl_loss |
| 110 | self.log_vars.append(("g_loss_kl_loss", kl_loss)) |
| 111 | self.log_vars.append(("g_loss", generator_loss)) |
| 112 | self.log_vars.append(("d_loss", discriminator_loss)) |
| 113 | |
| 114 | # Build the trainer/optimizers from the total generator and discriminator losses. |
| 115 | self.prepare_trainer(generator_loss, discriminator_loss) |
| 116 | # Define the summary ops (self.g_sum, self.d_sum, ...). |
| 117 | self.define_summaries() |
| 118 | |
| 119 | with pt.defaults_scope(phase=pt.Phase.test): |
| 120 | with tf.variable_scope("g_net", reuse=True): |
| 121 | self.sampler() |
| 122 | self.visualization(cfg.TRAIN.NUM_COPY) |
| 123 | print("success") |
| 124 | |
| 125 | def sampler(self): |
| 126 | c, _ = self.sample_encoded_context(self.embeddings) |
no test coverage detected