| 248 | } |
| 249 | |
def step(self):
    """Run one learning step, then record counts, save state and log.

    The result of `self._step()` is augmented in place with the counts
    returned by `self._counter.increment` and handed to
    `self._logger.write`. Nothing is returned.
    """
    # Perform the actual update; the result is merged into below via
    # `.update`, so it is presumably a dict of metrics (confirm in `_step`).
    results = self._step()

    # Wall-clock time since the previous call. A falsy stored timestamp
    # (e.g. unset before the first call) yields an elapsed time of 0.
    now = time.time()
    if self._timestamp:
        delta = now - self._timestamp
    else:
        delta = 0
    self._timestamp = now

    # Bump the counter by one step and fold its counts into the results.
    results.update(self._counter.increment(steps=1, walltime=delta))

    # Persist state through whichever savers are configured.
    checkpointer = self._checkpointer
    if checkpointer is not None:
        checkpointer.save()
    snapshotter = self._snapshotter
    if snapshotter is not None:
        snapshotter.save()

    # Finally, emit the merged results to the logger.
    self._logger.write(results)
| 269 | |
def get_variables(self, names: List[str]) -> List[List[np.ndarray]]:
    """Look up each requested entry in `self._variables` and convert it.

    Args:
        names: Keys to look up in `self._variables`, in the order wanted.

    Returns:
        A list with one item per name, each being the result of passing
        `self._variables[name]` through `tf2_utils.to_numpy`.
    """
    converted = []
    for key in names:
        converted.append(tf2_utils.to_numpy(self._variables[key]))
    return converted