MCPcopy
hub / github.com/tensorflow/models / train

Method train

orbit/controller.py:257–289  ·  view source on GitHub ↗

Runs training until the specified global step count has been reached. This method makes calls to `self.trainer.train()` until the global step count is equal to `steps`. It will additionally save checkpoints (if a `CheckpointManager` was passed to `Controller.__init__`) and summarize training output (if `summary_dir` is set).

(self, steps: int, checkpoint_at_completion: bool = True)

Source from the content-addressed store, hash-verified

255 _orbit_api_gauge.get_cell().set(True)
256
def train(self, steps: int, checkpoint_at_completion: bool = True):
  """Trains until the global step count reaches `steps`.

  Training runs in loops of at most `self.steps_per_loop` steps, repeatedly
  invoking `self.trainer.train()` (through `self._train_n_steps`) until the
  global step count is equal to `steps`. After each loop a checkpoint may be
  saved (if a `CheckpointManager` was passed to `Controller.__init__`), and
  training output is summarized (if `summary_dir` is set).

  Before returning, this method syncs on async checkpointing, so that when
  async checkpointing is enabled any in-flight checkpoint save has finished.

  Args:
    steps: The global step count to train up to. If the global step is already
      at or past this value, no training loop is run.
    checkpoint_at_completion: Whether to save a checkpoint when this method
      returns (regardless of the checkpointing interval). Defaults to `True`.
  """
  self._require("trainer", for_method="train")

  # TODO(momernick): Support steps=None or -1 (training to exhaustion).
  # Reading the global step is expensive, so read it once per loop iteration
  # and reuse the cached value for both the loop test and the step budget.
  current_step = self.global_step.numpy()
  _log(f"train | step: {current_step: 6d} | training until step {steps}...")

  while current_step < steps:
    # Never run past `steps`: clamp the final loop to the steps that remain.
    remaining = steps - current_step
    loop_steps = min(remaining, self.steps_per_loop)
    self._train_n_steps(loop_steps)
    self._maybe_save_checkpoint()
    current_step = self.global_step.numpy()

  # Force a final save irrespective of the checkpoint interval when requested.
  if checkpoint_at_completion:
    self._maybe_save_checkpoint(check_interval=False)

  # Block until any outstanding async checkpoint write has completed.
  self._sync_on_async_checkpointing()
290
291 def evaluate(self, steps: int = -1) -> Optional[runner.Output]:
292 """Runs evaluation for the given number of steps.

Callers 15

run_experimentFunction · 0.95
run_experimentFunction · 0.95
runFunction · 0.95
run_benchmarkFunction · 0.95
mainFunction · 0.95
test_train_onlyMethod · 0.95
train_and_evaluateMethod · 0.95

Calls 5

_requireMethod · 0.95
_train_n_stepsMethod · 0.95
_logFunction · 0.85