MCPcopy Index your code
hub / github.com/tensorflow/models / _train_n_steps

Method _train_n_steps

orbit/controller.py:491–541  ·  view source on GitHub ↗

Runs training for `num_steps` steps. Also prints/logs updates about training progress, and summarizes training output (if output is returned from `self.trainer.train()`, and if `self.summary_dir` is set). Args: num_steps: An integer specifying how many steps of training to run.

(self, num_steps: int)

Source from the content-addressed store, hash-verified

489 return self._steps_per_loop
490
def _train_n_steps(self, num_steps: int):
  """Runs training for `num_steps` steps.

  Also prints/logs updates about training progress, and summarizes training
  output (if output is returned from `self.trainer.train()`, and if
  `self.summary_dir` is set).

  If `global_step` has not advanced by exactly `num_steps` once
  `self.trainer.train(num_steps)` returns, a warning describing the mismatch
  is logged and execution continues; no exception is raised.

  Args:
    num_steps: An integer specifying how many steps of training to run.
  """
  # Lazily create the timer on first use, anchored at the current global step.
  if not self.step_timer:
    self.step_timer = StepTimer(self.global_step)
  current_step = self.global_step.numpy()

  with self.summary_manager.summary_writer().as_default():
    should_record = False  # Allows static optimization in no-summary cases.
    if self.summary_interval:
      # Create a predicate to determine when summaries should be written.
      should_record = lambda: (self.global_step % self.summary_interval == 0)
    assert isinstance(self.trainer, runner.AbstractTrainer)
    with tf.summary.record_if(should_record):
      num_steps_tensor = tf.convert_to_tensor(num_steps, dtype=tf.int32)
      train_output = self.trainer.train(num_steps_tensor)

  # Verify that global_step was updated properly. The value is read once so
  # that the comparison and the logged message always agree.
  expected_step = current_step + num_steps
  actual_step = self.global_step.numpy()
  if actual_step != expected_step:
    message = (
        f"`trainer.train({num_steps})` did not update `global_step` by "
        f"{num_steps}. Old value was {current_step}, expected updated value "
        f"to be {expected_step}, but it was {actual_step}.")
    # A mismatch is reported but deliberately does not abort training.
    logging.warning(message)

  # Normalize a missing/empty trainer output to a dict so that the actions and
  # the summary code below always receive a container.
  train_output = train_output or {}
  for action in self.train_actions:
    action(train_output)
  train_output = tf.nest.map_structure(utils.get_value, train_output)

  # Re-read the step rather than reusing the value above: the train actions
  # run in between.
  current_step = self.global_step.numpy()
  steps_per_second = self.step_timer.steps_per_second()
  _log(f"train | step: {current_step: 6d} | "
       f"steps/sec: {steps_per_second: 6.1f} | "
       f"output: {_format_output(train_output)}")

  # Throughput is written alongside the trainer's own outputs.
  train_output["steps_per_second"] = steps_per_second
  self.summary_manager.write_summaries(train_output)
  self.summary_manager.flush()
542
543 def _maybe_save_checkpoint(self, check_interval: bool = True):
544 """Conditionally saves a checkpoint.

Callers 1

trainMethod · 0.95

Calls 8

StepTimerClass · 0.85
_logFunction · 0.85
_format_outputFunction · 0.85
steps_per_secondMethod · 0.80
summary_writerMethod · 0.45
trainMethod · 0.45
write_summariesMethod · 0.45
flushMethod · 0.45

Tested by

no test coverage detected