MCPcopy Index your code
hub / github.com/hanzhanggit/StackGAN / init_opt

Method init_opt

stageII/trainer.py:107–155  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

105 return c, cfg.TRAIN.COEFF.KL * kl_loss
106
def init_opt(self):
    """Build the stage-I + stage-II training graph, then the sampling graph.

    Steps, in order:
      1. ``build_placeholder()`` creates the graph inputs.
      2. Under prettytensor's train phase, the ``g_net`` scope builds the
         low-resolution generator from conditioning ``c`` and noise ``z``.
         Its losses come from ``compute_losses(..., flag='lr')``.
      3. The ``hr_g_net`` scope builds the high-resolution generator on top
         of the low-res output. Its losses come from
         ``compute_losses(..., flag='hr')``.
      4. All four losses go to ``prepare_trainer``, then
         ``define_summaries()`` is called.
      5. Under the test phase, ``sampler()`` and ``visualization()`` build
         the sampling graph.

    Side effects: appends ``(name, tensor)`` pairs to ``self.log_vars`` and
    prints "success" when finished.
    """
    def _log(name, tensor):
        # Re-read self.log_vars on each call (no cached bound method), so the
        # entries land in whatever list the instance currently holds.
        self.log_vars.append((name, tensor))

    self.build_placeholder()

    with pt.defaults_scope(phase=pt.Phase.train):
        # ---- stage I: conditioning + noise -> low-resolution generator ----
        with tf.variable_scope("g_net"):
            lr_cond, lr_kl = self.sample_encoded_context(self.embeddings)
            noise = tf.random_normal([self.batch_size, cfg.Z_DIM])
            _log("hist_c", lr_cond)
            _log("hist_z", noise)
            # Pre-1.0 TensorFlow signature: tf.concat(axis, values).
            latent = tf.concat(1, [lr_cond, noise])
            lr_fake = self.model.get_generator(latent)

        # Stage-I losses are built outside the g_net variable scope.
        lr_d_loss, lr_g_loss = self.compute_losses(
            self.images,
            self.wrong_images,
            lr_fake,
            self.embeddings,
            flag='lr',
        )
        # The generator objective includes the KL term returned by
        # sample_encoded_context (presumably already weighted by
        # cfg.TRAIN.COEFF.KL -- confirm against that method).
        lr_g_loss += lr_kl
        _log("g_loss_kl_loss", lr_kl)
        _log("g_loss", lr_g_loss)
        _log("d_loss", lr_d_loss)

        # ---- stage II: high-resolution generator fed the stage-I output ----
        with tf.variable_scope("hr_g_net"):
            hr_cond, hr_kl = self.sample_encoded_context(self.embeddings)
            _log("hist_hr_c", hr_cond)
            hr_fake = self.model.hr_get_generator(lr_fake, hr_cond)

        # Stage-II losses, likewise built outside hr_g_net.
        hr_d_loss, hr_g_loss = self.compute_losses(
            self.hr_images,
            self.hr_wrong_images,
            hr_fake,
            self.embeddings,
            flag='hr',
        )
        # NOTE(review): unlike stage I, the stage-II KL term is folded into
        # hr_g_loss without being logged as a separate entry.
        hr_g_loss += hr_kl
        _log("hr_g_loss", hr_g_loss)
        _log("hr_d_loss", hr_d_loss)

        # Hand the losses over and build summaries
        # (self.g_sum, self.d_sum, ...).
        self.prepare_trainer(lr_d_loss, lr_g_loss, hr_d_loss, hr_g_loss)
        self.define_summaries()

    # Sampling graph in test phase; sampler() re-enters g_net with reuse=True.
    with pt.defaults_scope(phase=pt.Phase.test):
        self.sampler()
        self.visualization(cfg.TRAIN.NUM_COPY)
        print("success")
156
157 def sampler(self):
158 with tf.variable_scope("g_net", reuse=True):

Callers 2

build_model (method) · 0.95
evaluate (method) · 0.95

Calls 9

build_placeholder (method) · 0.95
compute_losses (method) · 0.95
prepare_trainer (method) · 0.95
define_summaries (method) · 0.95
sampler (method) · 0.95
visualization (method) · 0.95
hr_get_generator (method) · 0.80
get_generator (method) · 0.45

Tested by

no test coverage detected