(self)
| 105 | return c, cfg.TRAIN.COEFF.KL * kl_loss |
| 106 | |
def init_opt(self):
    """Build the two-stage training graph, its summaries, and the sampler.

    Stage I (``flag='lr'``), under variable scope ``"g_net"``:
        * ``self.embeddings`` is encoded by ``sample_encoded_context`` into
          a conditioning tensor ``c`` and its (already KL-weighted) loss.
        * noise ``z`` of shape ``[self.batch_size, cfg.Z_DIM]`` is drawn.
        * ``self.model.get_generator`` produces ``fake_images`` from
          ``c`` and ``z`` joined along dimension 1.

    Stage II (``flag='hr'``), under variable scope ``"hr_g_net"``:
        * the embeddings are encoded again into ``hr_c``.
        * ``self.model.hr_get_generator`` produces ``hr_fake_images`` from
          the stage-I ``fake_images`` and ``hr_c``.

    For each stage, ``compute_losses`` returns (discriminator, generator)
    losses. The stage's KL term is added to its generator loss, and both
    the KL term and the losses are appended to ``self.log_vars``. The four
    losses are passed to ``prepare_trainer`` and ``define_summaries`` is
    called. All of this runs in PrettyTensor's train phase.

    In the test phase, ``sampler()`` and
    ``visualization(cfg.TRAIN.NUM_COPY)`` are then called.
    """
    self.build_placeholder()

    with pt.defaults_scope(phase=pt.Phase.train):
        # #### Stage I: get output from G network ##########################
        with tf.variable_scope("g_net"):
            c, kl_loss = self.sample_encoded_context(self.embeddings)
            z = tf.random_normal([self.batch_size, cfg.Z_DIM])
            self.log_vars.append(("hist_c", c))
            self.log_vars.append(("hist_z", z))
            # Axis-first argument order (legacy pre-1.0 TensorFlow
            # tf.concat signature): concatenates c and z along dim 1.
            fake_images = self.model.get_generator(tf.concat(1, [c, z]))

        # #### Stage I: discriminator_loss and generator_loss ##############
        # Built outside the "g_net" scope.
        # NOTE(review): presumably this keeps discriminator variables out of
        # the generator's variable namespace; confirm against compute_losses.
        discriminator_loss, generator_loss = self.compute_losses(
            self.images,
            self.wrong_images,
            fake_images,
            self.embeddings,
            flag='lr')
        generator_loss += kl_loss
        self.log_vars.append(("g_loss_kl_loss", kl_loss))
        self.log_vars.append(("g_loss", generator_loss))
        self.log_vars.append(("d_loss", discriminator_loss))

        # #### Stage II: hr_g and hr_d #####################################
        with tf.variable_scope("hr_g_net"):
            hr_c, hr_kl_loss = self.sample_encoded_context(self.embeddings)
            self.log_vars.append(("hist_hr_c", hr_c))
            # Stage II is conditioned on the stage-I generator output.
            hr_fake_images = self.model.hr_get_generator(fake_images, hr_c)

        # Stage-II losses are built at the same level as stage I's.
        hr_discriminator_loss, hr_generator_loss = self.compute_losses(
            self.hr_images,
            self.hr_wrong_images,
            hr_fake_images,
            self.embeddings,
            flag='hr')
        hr_generator_loss += hr_kl_loss
        # Log the stage-II KL term too, mirroring "g_loss_kl_loss" above,
        # so its contribution to hr_g_loss can be monitored.
        self.log_vars.append(("hr_g_loss_kl_loss", hr_kl_loss))
        self.log_vars.append(("hr_g_loss", hr_generator_loss))
        self.log_vars.append(("hr_d_loss", hr_discriminator_loss))

        # #### define self.g_sum, self.d_sum, .... ##########################
        self.prepare_trainer(discriminator_loss, generator_loss,
                             hr_discriminator_loss, hr_generator_loss)
        self.define_summaries()

    with pt.defaults_scope(phase=pt.Phase.test):
        self.sampler()
        self.visualization(cfg.TRAIN.NUM_COPY)
        print("success")
| 156 | |
| 157 | def sampler(self): |
| 158 | with tf.variable_scope("g_net", reuse=True): |
no test coverage detected