A GAN trainer that runs the discriminator and generator optimization ops separately, at a configurable ratio.
| 150 | |
| 151 | |
class SeparateGANTrainer(TowerTrainer):
    """A GAN trainer which runs two optimization ops with a certain ratio.

    The discriminator op (``self.d_min``) runs on iterations where
    ``global_step % d_period == 0``, and the generator op (``self.g_min``)
    runs on iterations where ``global_step % g_period == 0``.
    """

    def __init__(self, input, model, d_period=1, g_period=1):
        """
        Args:
            input: input source. ``input.setup()`` is called with
                ``model.get_input_signature()``, the callbacks it returns are
                registered, and ``input.get_input_tensors()`` feeds the tower.
            model: GAN model description. Must provide ``build_graph``,
                ``inputs()``, ``get_input_signature()``, ``get_optimizer()``,
                ``d_loss``/``g_loss`` and ``d_vars``/``g_vars``.
            d_period(int): period of each d_opt run
            g_period(int): period of each g_opt run

        Raises:
            ValueError: if, after ``int()`` conversion, either period is
                below 1 or neither period equals 1.
        """
        super(SeparateGANTrainer, self).__init__()
        self._d_period = int(d_period)
        self._g_period = int(g_period)
        # Validate the converted values (the ones run_step uses) with a real
        # exception: `assert` is stripped under `python -O`, and a period of 0
        # would make run_step fail with ZeroDivisionError.
        if min(self._d_period, self._g_period) != 1:
            raise ValueError(
                "d_period and g_period must both be >= 1 and at least one of "
                "them must be 1; got d_period={}, g_period={}".format(
                    d_period, g_period))

        # Setup input
        cbs = input.setup(model.get_input_signature())
        self.register_callback(cbs)

        # Build the graph.
        self.tower_func = TowerFunc(model.build_graph, model.inputs())
        # should not hook the EMA updates to both train_op, it will hurt
        # training speed -- hence BatchNorm is built with ema_update='internal'.
        with TowerContext('', is_training=True), \
                argscope(BatchNorm, ema_update='internal'):
            self.tower_func(*input.get_input_tensors())
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        if update_ops:
            logger.warning("Found {} ops in UPDATE_OPS collection!".format(len(update_ops)))
            logger.warning("Using SeparateGANTrainer with UPDATE_OPS may hurt your training speed a lot!")

        # Both train ops are created from the same optimizer instance.
        # NOTE(review): optimizers with shared non-slot state (e.g. Adam's beta
        # power accumulators in TF1) may have that state advanced by both ops;
        # confirm this is intended.
        opt = model.get_optimizer()
        with tf.name_scope('optimize'):
            self.d_min = opt.minimize(
                model.d_loss, var_list=model.d_vars, name='d_min')
            self.g_min = opt.minimize(
                model.g_loss, var_list=model.g_vars, name='g_min')

    def run_step(self):
        """Run one training iteration: d_min and/or g_min, chosen by the step."""
        if self.global_step % self._d_period == 0:
            self.hooked_sess.run(self.d_min)
        # global_step is deliberately re-read here rather than cached above.
        # NOTE(review): it may advance during the first hooked_sess.run;
        # confirm before caching it in a local.
        if self.global_step % self._g_period == 0:
            self.hooked_sess.run(self.g_min)
| 194 | |
| 195 | class RandomZData(DataFlow): |
no outgoing calls
no test coverage detected