(self, images, wrong_images,
fake_images, embeddings, flag='lr')
| 165 | self.model.hr_get_generator(self.fake_images, hr_c) |
| 166 | |
def compute_losses(self, images, wrong_images,
                   fake_images, embeddings, flag='lr'):
    """Build the discriminator and generator losses for one GAN stage.

    Args:
        images: real images; scored against ``embeddings`` with label 1.
        wrong_images: images scored against ``embeddings`` with label 0.
            Their term only enters the loss (and the logs) when
            ``cfg.TRAIN.B_WRONG`` is set.
        fake_images: generated images; label 0 for the discriminator loss,
            label 1 for the generator loss.
        embeddings: conditioning input passed to every discriminator call.
        flag: ``'lr'`` uses ``self.model.get_discriminator``; any other
            value uses ``self.model.hr_get_discriminator`` and prefixes the
            logged names with ``'hr_'``.

    Returns:
        ``(discriminator_loss, generator_loss)``, each the mean sigmoid
        cross-entropy over its logits.

    Side effects:
        Appends ``(name, tensor)`` pairs to ``self.log_vars``.
    """

    def _mean_sigmoid_xent(logits, labels):
        # Always pass by keyword: TF 1.x rejects positional arguments here,
        # and TF 2.x's signature is (labels, logits), so a positional call
        # would silently swap them and compute the wrong loss.
        try:
            loss = tf.nn.sigmoid_cross_entropy_with_logits(
                logits=logits, labels=labels)
        except TypeError:
            # TensorFlow < 1.0 named the label argument ``targets``.
            loss = tf.nn.sigmoid_cross_entropy_with_logits(
                logits=logits, targets=labels)
        return tf.reduce_mean(loss)

    if flag == 'lr':
        get_discriminator = self.model.get_discriminator
        prefix = ''
    else:
        get_discriminator = self.model.hr_get_discriminator
        prefix = 'hr_'

    real_logit = get_discriminator(images, embeddings)
    wrong_logit = get_discriminator(wrong_images, embeddings)
    fake_logit = get_discriminator(fake_images, embeddings)

    real_d_loss = _mean_sigmoid_xent(real_logit, tf.ones_like(real_logit))
    # Computed even when cfg.TRAIN.B_WRONG is off; it is then neither added
    # to the loss nor logged.
    wrong_d_loss = _mean_sigmoid_xent(wrong_logit, tf.zeros_like(wrong_logit))
    fake_d_loss = _mean_sigmoid_xent(fake_logit, tf.zeros_like(fake_logit))

    use_wrong = cfg.TRAIN.B_WRONG
    if use_wrong:
        # The real term keeps full weight; the wrong and fake terms are
        # averaged so together they weigh as much as the real term.
        discriminator_loss = real_d_loss + (wrong_d_loss + fake_d_loss) / 2.
    else:
        discriminator_loss = real_d_loss + fake_d_loss

    self.log_vars.append((prefix + "d_loss_real", real_d_loss))
    self.log_vars.append((prefix + "d_loss_fake", fake_d_loss))
    if use_wrong:
        self.log_vars.append((prefix + "d_loss_wrong", wrong_d_loss))

    # Generator objective: have the discriminator score fakes with label 1.
    generator_loss = _mean_sigmoid_xent(fake_logit, tf.ones_like(fake_logit))
    self.log_vars.append((prefix + "g_loss_fake", generator_loss))

    return discriminator_loss, generator_loss
| 222 | |
| 223 | def define_one_trainer(self, loss, learning_rate, key_word): |
| 224 | '''Helper function for init_opt''' |
no test coverage detected