(self)
| 387 | } |
| 388 | |
def step(self):
    """Run one learner step, then update counters, checkpoint, and log.

    The replicated step is executed inside the accelerator strategy's scope.
    Its result is extended in place with the values returned by the counter
    and then handed to the logger. Nothing is returned.
    """
    # Execute the replicated update under the accelerator strategy.
    with self._accelerator_strategy.scope():
        results = self._replicated_step()

    # Measure wall time elapsed since the previously stored timestamp and
    # remember the new timestamp for the next call.
    now = time.time()
    elapsed = now - self._walltime_timestamp
    self._walltime_timestamp = now

    # Advance the counter by one step and merge what it reports into the
    # values that will be logged.
    counter_values = self._counter.increment(steps=1, wall_time=elapsed)
    results.update(counter_values)

    # Persist state only when a checkpointer is configured.
    # NOTE(review): the original nesting was lost in extraction; the
    # snapshotter save is assumed to sit under the same guard as the
    # checkpointer save -- confirm against the original file.
    checkpointer = self._checkpointer
    if checkpointer is not None:
        checkpointer.save()
        self._snapshotter.save()

    # Logging happens on every step, regardless of checkpointing.
    self._logger.write(results)
| 406 | |
def get_variables(self, names: List[str]) -> List[List[np.ndarray]]:
    """Return the converted variables for each requested name.

    Args:
        names: Keys into ``self._variables``; an unknown key raises
            ``KeyError`` from the lookup.

    Returns:
        One entry per name, in the same order as ``names``, each being the
        result of ``tf2_utils.to_numpy`` applied to the stored variables
        (presumably a list of numpy arrays, per the return annotation).
    """
    converted = []
    for variable_name in names:
        stored = self._variables[variable_name]
        converted.append(tf2_utils.to_numpy(stored))
    return converted
nothing calls this directly
no test coverage detected