github.com/hanzhanggit/StackGAN

Method init_opt(self)

stageI/trainer.py, lines 91–123

```python
import prettytensor as pt
import tensorflow as tf

from misc.config import cfg


class CondGANTrainer(object):
    # ... (excerpt of stageI/trainer.py; other methods elided)

    # Line 89, the end of sample_encoded_context: the sampled context c is
    # returned together with the KL loss, already scaled by the KL weight.
    #     return c, cfg.TRAIN.COEFF.KL * kl_loss

    def init_opt(self):
        # Create the input placeholders used below
        # (self.images, self.wrong_images, self.embeddings, ...).
        self.build_placeholder()

        with pt.defaults_scope(phase=pt.Phase.train):
            with tf.variable_scope("g_net"):
                # Get the output of the G network: the conditioning vector c
                # (with its KL penalty) and noise z feed the generator.
                c, kl_loss = self.sample_encoded_context(self.embeddings)
                z = tf.random_normal([self.batch_size, cfg.Z_DIM])
                self.log_vars.append(("hist_c", c))
                self.log_vars.append(("hist_z", z))
                # Pre-1.0 TensorFlow signature tf.concat(axis, values);
                # TF >= 1.0 expects tf.concat([c, z], 1) instead.
                fake_images = self.model.get_generator(tf.concat(1, [c, z]))

            # Get discriminator_loss and generator_loss from the real,
            # mismatched ("wrong") and generated images.
            discriminator_loss, generator_loss = \
                self.compute_losses(self.images,
                                    self.wrong_images,
                                    fake_images,
                                    self.embeddings)
            generator_loss += kl_loss
            self.log_vars.append(("g_loss_kl_loss", kl_loss))
            self.log_vars.append(("g_loss", generator_loss))
            self.log_vars.append(("d_loss", discriminator_loss))

            # Build the optimizers from the total losses.
            self.prepare_trainer(generator_loss, discriminator_loss)
            # Define self.g_sum, self.d_sum, ...
            self.define_summaries()

        # Test-phase graph: share the g_net variables built above for sampling.
        with pt.defaults_scope(phase=pt.Phase.test):
            with tf.variable_scope("g_net", reuse=True):
                self.sampler()
            self.visualization(cfg.TRAIN.NUM_COPY)
            print("success")

    def sampler(self):
        c, _ = self.sample_encoded_context(self.embeddings)
        # ... (continues beyond line 126)
```
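The return statement at the end of sample_encoded_context (line 89) shows that it yields a context c together with a KL loss already weighted by cfg.TRAIN.COEFF.KL, and init_opt then concatenates c with Gaussian noise z before calling the generator. The plain-Python sketch below illustrates that data flow for a single example, assuming the usual conditioning-augmentation scheme (c drawn as mu + sigma * eps, with a KL penalty toward N(0, 1)). The function names, and the idea that mu and log_sigma come from a learned projection of the text embedding, are assumptions for illustration, not code from the repository.

```python
import math
import random
from typing import List, Tuple


def sample_context_sketch(mu: List[float], log_sigma: List[float],
                          kl_coeff: float,
                          rng: random.Random) -> Tuple[List[float], float]:
    """Hypothetical stand-in for sample_encoded_context on one example.

    Draws c = mu + sigma * eps with eps ~ N(0, 1) (reparameterization) and
    returns it with KL(N(mu, sigma^2) || N(0, 1)), summed over dimensions
    and scaled by kl_coeff, mirroring `cfg.TRAIN.COEFF.KL * kl_loss`.
    """
    c = [m + math.exp(ls) * rng.gauss(0.0, 1.0)
         for m, ls in zip(mu, log_sigma)]
    # Per-dimension KL: -log(sigma) + 0.5 * (sigma^2 + mu^2 - 1)
    kl = sum(-ls + 0.5 * (math.exp(2.0 * ls) + m * m - 1.0)
             for m, ls in zip(mu, log_sigma))
    return c, kl_coeff * kl


def generator_input_sketch(c: List[float], z_dim: int,
                           rng: random.Random) -> List[float]:
    """Append noise z ~ N(0, 1) to c, as tf.concat(1, [c, z]) does per row."""
    z = [rng.gauss(0.0, 1.0) for _ in range(z_dim)]
    return c + z


rng = random.Random(0)
c, kl = sample_context_sketch([1.0, 0.0], [0.0, 0.0], kl_coeff=2.0, rng=rng)
g_in = generator_input_sketch(c, z_dim=3, rng=rng)
print(len(g_in), kl)  # → 5 1.0
```

The scaled KL value plays the role of kl_loss that init_opt adds to generator_loss; with mu = 0 and log_sigma = 0 it is exactly zero, so the penalty only pushes the context distribution back toward the standard normal.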

Callers (2)

build_model (method) · 0.95
evaluate (method) · 0.95

Calls (8)

build_placeholder (method) · 0.95
compute_losses (method) · 0.95
prepare_trainer (method) · 0.95
define_summaries (method) · 0.95
sampler (method) · 0.95
visualization (method) · 0.95
get_generator (method) · 0.45
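compute_losses receives the real images, mismatched ("wrong") images, generated images and the text embeddings. A common way to combine the three resulting discriminator scores is the matching-aware GAN-CLS loss of Reed et al. (2016). The sketch below assumes that formulation with sigmoid cross-entropy; it does not reproduce the repository's compute_losses, whose body is not shown on this page, and it takes per-image discriminator logits as already computed.

```python
import math
from typing import Sequence, Tuple


def sigmoid_xent(logit: float, target: float) -> float:
    """Numerically stable sigmoid cross-entropy for one logit."""
    # max(x, 0) - x * t + log(1 + exp(-|x|)), the form TensorFlow documents
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def mean_xent(logits: Sequence[float], target: float) -> float:
    return sum(sigmoid_xent(x, target) for x in logits) / len(logits)


def matching_aware_losses(real_logits: Sequence[float],
                          wrong_logits: Sequence[float],
                          fake_logits: Sequence[float]) -> Tuple[float, float]:
    """GAN-CLS style losses (assumed formulation, not the repository's code).

    real_logits:  D(real image, matching text)          -> target 1
    wrong_logits: D(real image that does not match text) -> target 0
    fake_logits:  D(generated image, its text)           -> target 0
    """
    real = mean_xent(real_logits, 1.0)
    wrong = mean_xent(wrong_logits, 0.0)
    fake = mean_xent(fake_logits, 0.0)
    d_loss = real + (wrong + fake) / 2.0
    # The generator tries to make D score its images as real.
    g_loss = mean_xent(fake_logits, 1.0)
    return d_loss, g_loss


d_loss, g_loss = matching_aware_losses([0.0, 0.0], [0.0], [0.0, 0.0])
print(round(d_loss / math.log(2), 6), round(g_loss / math.log(2), 6))  # → 2.0 1.0
```

Averaging the wrong and fake terms keeps the two "negative" cases from outweighing the real pair; dropping the wrong term gives a plain conditional GAN loss instead.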

Tested by

no test coverage detected