Run the main training loop. Args: steps_per_epoch, starting_epoch, max_epoch (int): the number of steps per epoch, and the first and last (inclusive) epoch numbers to run.
(self, steps_per_epoch, starting_epoch, max_epoch)
| 254 | |
@call_only_once
def main_loop(self, steps_per_epoch, starting_epoch, max_epoch):
    """
    Run the main training loop.

    Runs epochs ``starting_epoch`` through ``max_epoch`` (inclusive). Each
    epoch makes ``steps_per_epoch`` calls to ``run_step`` (implemented by
    subclasses) and fires the registered callbacks around every step and
    every epoch.

    Args:
        steps_per_epoch (int): number of ``run_step`` calls per epoch.
        starting_epoch (int): number of the first epoch to run.
        max_epoch (int): number of the last epoch to run (inclusive).

    Raises:
        KeyboardInterrupt: re-raised after logging when the user hits Ctrl-C.
        Exception: any other error raised during training is logged together
            with the current global step, then re-raised.

    ``StopTraining`` and ``tf.errors.OutOfRangeError`` end training early
    without propagating. On every exit path, the callbacks' ``after_train``
    runs and the hooked session is closed.
    """
    with self.sess.as_default():
        self.loop.config(steps_per_epoch, starting_epoch, max_epoch)
        self.loop.update_global_step()
        try:
            self._callbacks.before_train()
            # refresh global step (might have changed by callbacks) TODO ugly
            # what if gs is changed later?
            self.loop.update_global_step()
            for self.loop._epoch_num in range(
                    self.loop.starting_epoch, self.loop.max_epoch + 1):
                logger.info("Start Epoch {} ...".format(self.loop.epoch_num))
                self._callbacks.before_epoch()
                start_time = time.time()
                for self.loop._local_step in range(self.loop.steps_per_epoch):
                    # Leave early when the hooked session reports it should stop;
                    # the finally clause below still performs cleanup.
                    if self.hooked_sess.should_stop():
                        return
                    self.run_step()  # implemented by subclass
                    self._callbacks.trigger_step()
                self._callbacks.after_epoch()
                logger.info("Epoch {} (global_step {}) finished, time:{}.".format(
                    self.loop.epoch_num, self.loop.global_step, humanize_time_delta(time.time() - start_time)))

                # trigger epoch outside the timing region.
                self._callbacks.trigger_epoch()
            logger.info("Training has finished!")
        except (StopTraining, tf.errors.OutOfRangeError) as e:
            logger.info("Training was stopped by exception {}.".format(str(e)))
        except KeyboardInterrupt:
            logger.info("Detected Ctrl-C and exiting main loop.")
            raise
        except Exception:
            # Format the step into the message itself: passing it as an extra
            # logging argument with no placeholder breaks record formatting,
            # and the step would never be shown.
            logger.error("Training failed at global_step={}".format(self.loop.global_step))
            raise
        finally:
            # Make sure the session is released even if after_train() raises.
            try:
                self._callbacks.after_train()
            finally:
                self.hooked_sess.close()
| 299 | |
| 300 | def train(self, |
| 301 | callbacks, monitors, |
no test coverage detected