MCPcopy Index your code
hub / github.com/tensorpack/tensorpack / SeparateGANTrainer

Class SeparateGANTrainer

examples/GAN/GAN.py:152–192  ·  view source on GitHub ↗

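A GAN trainer that runs two separate optimization ops — one for the discriminator and one for the generator — at a configurable ratio (set by `d_period` and `g_period`).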

Source from the content-addressed store, hash-verified

150
151
class SeparateGANTrainer(TowerTrainer):
    """ A GAN trainer which runs two optimization ops with a certain ratio.

    The discriminator op (``d_min``) runs once every ``d_period`` training
    steps and the generator op (``g_min``) runs once every ``g_period``
    training steps, where one training step is one call to :meth:`run_step`.
    """
    def __init__(self, input, model, d_period=1, g_period=1):
        """
        Args:
            input: trainer input. ``input.setup`` is called with the model's
                input signature (its returned callbacks are registered), and
                ``input.get_input_tensors()`` feeds the tower function.
            model: the GAN model. It must provide ``get_input_signature()``,
                ``build_graph``, ``inputs()``, ``get_optimizer()`` and the
                attributes ``d_loss``, ``g_loss``, ``d_vars``, ``g_vars``
                (presumably populated by ``build_graph`` -- confirm per model).
            d_period(int): period of each d_opt run
            g_period(int): period of each g_opt run.
                Both periods must be >= 1 and at least one of them must be 1.

        Raises:
            ValueError: if ``min(int(d_period), int(g_period)) != 1``.
        """
        super(SeparateGANTrainer, self).__init__()
        self._d_period = int(d_period)
        self._g_period = int(g_period)
        # Validate the converted ints (not the raw arguments), and raise rather
        # than assert so the check is not stripped under `python -O`.
        if min(self._d_period, self._g_period) != 1:
            raise ValueError(
                "min(d_period, g_period) must be 1, got d_period={}, g_period={}".format(
                    self._d_period, self._g_period))
        # Number of run_step() calls made so far; drives the d/g schedule.
        # global_step is not used for this: it advances once per
        # hooked_sess.run, so it moves between the two period checks and by 2
        # in any step that runs both ops, which breaks the requested ratio.
        self._step_count = 0

        # Setup input
        cbs = input.setup(model.get_input_signature())
        self.register_callback(cbs)

        # Build the graph
        self.tower_func = TowerFunc(model.build_graph, model.inputs())
        with TowerContext('', is_training=True), \
                argscope(BatchNorm, ema_update='internal'):
            # should not hook the EMA updates to both train_op, it will hurt training speed.
            self.tower_func(*input.get_input_tensors())
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        if update_ops:
            # Any op left in UPDATE_OPS would have to run alongside both train ops.
            logger.warn("Found {} ops in UPDATE_OPS collection!".format(len(update_ops)))
            logger.warn("Using SeparateGANTrainer with UPDATE_OPS may hurt your training speed a lot!")

        opt = model.get_optimizer()
        with tf.name_scope('optimize'):
            self.d_min = opt.minimize(
                model.d_loss, var_list=model.d_vars, name='d_min')
            self.g_min = opt.minimize(
                model.g_loss, var_list=model.g_vars, name='g_min')

    def run_step(self):
        """Run one training iteration.

        Runs ``d_min`` if the current step is a multiple of ``d_period`` and
        ``g_min`` if it is a multiple of ``g_period``. Both are decided from
        the same step value, captured before either op runs.
        """
        step = self._step_count
        if step % self._d_period == 0:
            self.hooked_sess.run(self.d_min)
        if step % self._g_period == 0:
            self.hooked_sess.run(self.g_min)
        self._step_count = step + 1
193
194
195class RandomZData(DataFlow):

Callers 3

DiscoGAN-CelebA.pyFile · 0.90
WGAN.pyFile · 0.90
Improved-WGAN.pyFile · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected