MCPcopy Index your code
hub / github.com/hanzhanggit/StackGAN / compute_losses

Method compute_losses

stageII/trainer.py:167–221  ·  view source on GitHub ↗
(self, images, wrong_images,
                       fake_images, embeddings, flag='lr')

Source from the content-addressed store, hash-verified

165 self.model.hr_get_generator(self.fake_images, hr_c)
166
def compute_losses(self, images, wrong_images,
                   fake_images, embeddings, flag='lr'):
    """Build the discriminator and generator losses for one GAN stage.

    The stage's discriminator is applied, with the same ``embeddings``, to
    three batches in this order: ``images``, ``wrong_images`` and
    ``fake_images``. Each loss is a sigmoid cross-entropy averaged with
    ``tf.reduce_mean``. The targets are:

    * 1 for ``images``;
    * 0 for ``wrong_images`` and ``fake_images`` (discriminator side);
    * 1 for ``fake_images`` (generator side).

    Args:
        images: batch scored with target 1 by the discriminator.
        wrong_images: batch scored with target 0. Presumably real images
            paired with mismatched embeddings -- TODO confirm at caller.
        fake_images: batch scored with target 0 for the discriminator
            loss and target 1 for the generator loss.
        embeddings: conditioning input passed to every discriminator call.
        flag: ``'lr'`` selects ``self.model.get_discriminator`` and
            un-prefixed log tags. Any other value selects
            ``self.model.hr_get_discriminator`` and ``hr_``-prefixed tags.

    Returns:
        Tuple ``(discriminator_loss, generator_loss)``. If
        ``cfg.TRAIN.B_WRONG`` is truthy, the discriminator loss is
        ``real + (wrong + fake) / 2.``; otherwise it is ``real + fake``.

    Side effects:
        Appends ``(tag, tensor)`` pairs to ``self.log_vars``. The
        ``d_loss_wrong`` entry is appended only when ``cfg.TRAIN.B_WRONG``
        is truthy.
    """
    # Resolve the stage-specific discriminator and log tags once, so the
    # loss construction below is shared by both stages.
    if flag == 'lr':
        discriminate = self.model.get_discriminator
        tag_real, tag_fake, tag_wrong, tag_gen = (
            "d_loss_real", "d_loss_fake", "d_loss_wrong", "g_loss_fake")
    else:
        discriminate = self.model.hr_get_discriminator
        tag_real, tag_fake, tag_wrong, tag_gen = (
            "hr_d_loss_real", "hr_d_loss_fake", "hr_d_loss_wrong",
            "hr_g_loss_fake")

    # Call order (real, wrong, fake) is kept deliberately: the
    # discriminator builds graph ops on each call.
    logit_real = discriminate(images, embeddings)
    logit_wrong = discriminate(wrong_images, embeddings)
    logit_fake = discriminate(fake_images, embeddings)

    def mean_xent(logits, make_targets):
        # Positional (logits, targets) matches the pre-1.0 TensorFlow
        # signature. TF >= 1.0 rejects positional args and needs
        # labels=/logits= keywords.
        per_element = tf.nn.sigmoid_cross_entropy_with_logits(
            logits, make_targets(logits))
        return tf.reduce_mean(per_element)

    loss_real = mean_xent(logit_real, tf.ones_like)
    loss_wrong = mean_xent(logit_wrong, tf.zeros_like)
    loss_fake = mean_xent(logit_fake, tf.zeros_like)

    use_wrong = cfg.TRAIN.B_WRONG
    if use_wrong:
        # The two negative terms share the weight of the single positive term.
        discriminator_loss = loss_real + (loss_wrong + loss_fake) / 2.
    else:
        discriminator_loss = loss_real + loss_fake

    self.log_vars.append((tag_real, loss_real))
    self.log_vars.append((tag_fake, loss_fake))
    if use_wrong:
        self.log_vars.append((tag_wrong, loss_wrong))

    # Generator side: fakes are scored against target 1, i.e. the generator
    # is pushed to make the discriminator call its samples real.
    generator_loss = mean_xent(logit_fake, tf.ones_like)
    self.log_vars.append((tag_gen, generator_loss))

    return discriminator_loss, generator_loss
222
223 def define_one_trainer(self, loss, learning_rate, key_word):
224 '''Helper function for init_opt'''

Callers 1

init_opt (Method) · 0.95

Calls 2

hr_get_discriminator (Method) · 0.80
get_discriminator (Method) · 0.45

Tested by

no test coverage detected