(self, images, wrong_images, fake_images, embeddings)
| 131 | self.fake_images = self.model.get_generator(tf.concat(1, [c, z])) |
| 132 | |
def compute_losses(self, images, wrong_images, fake_images, embeddings):
    """Build the discriminator and generator losses for one batch.

    The discriminator is evaluated on three image/embedding pairings, all
    conditioned on the same ``embeddings``:

    * ``images``       -- scored against a target of 1 ("real").
    * ``wrong_images`` -- scored against a target of 0.
    * ``fake_images``  -- scored against a target of 0.

    NOTE(review): the names suggest ``wrong_images`` are real images that do
    not match ``embeddings`` and ``fake_images`` are generator output --
    confirm against the caller.

    Args:
        images: images paired with ``embeddings`` as positives.
        wrong_images: images paired with ``embeddings`` as negatives.
        fake_images: images paired with ``embeddings`` as negatives, and
            also used for the generator loss.
        embeddings: conditioning input passed to every discriminator call.

    Returns:
        ``(discriminator_loss, generator_loss)``, each the ``reduce_mean``
        of a sigmoid cross-entropy, i.e. a scalar tensor.

    Side effects:
        Appends ``("d_loss_wrong", ...)`` (only when ``cfg.TRAIN.B_WRONG``),
        then ``("d_loss_real", ...)`` and ``("d_loss_fake", ...)`` to
        ``self.log_vars``.
    """
    discriminate = self.model.get_discriminator

    def mean_xent(logit, make_target):
        # Average sigmoid cross-entropy of `logit` against a constant target
        # (tf.ones_like / tf.zeros_like) of the same shape.
        # NOTE(review): the positional (logits, targets) order follows the
        # TF 0.x signature this file is written against; TF >= 1.0 requires
        # keyword arguments (labels=..., logits=...).
        per_element = tf.nn.sigmoid_cross_entropy_with_logits(
            logit, make_target(logit))
        return tf.reduce_mean(per_element)

    real_logit = discriminate(images, embeddings)
    wrong_logit = discriminate(wrong_images, embeddings)
    fake_logit = discriminate(fake_images, embeddings)

    d_real = mean_xent(real_logit, tf.ones_like)
    d_wrong = mean_xent(wrong_logit, tf.zeros_like)
    d_fake = mean_xent(fake_logit, tf.zeros_like)

    if not cfg.TRAIN.B_WRONG:
        d_total = d_real + d_fake
    else:
        # The two negative terms share the weight a single negative term
        # would otherwise have.
        d_total = d_real + (d_wrong + d_fake) / 2.
        self.log_vars.append(("d_loss_wrong", d_wrong))
    for entry in (("d_loss_real", d_real), ("d_loss_fake", d_fake)):
        self.log_vars.append(entry)

    # The generator is scored on the fake logits against a target of 1.
    g_total = mean_xent(fake_logit, tf.ones_like)

    return d_total, g_total
| 165 | |
| 166 | def prepare_trainer(self, generator_loss, discriminator_loss): |
| 167 | '''Helper function for init_opt''' |
no test coverage detected