Simple `Action` that saves the outputs passed to `__call__`.
| 30 | |
| 31 | |
| 32 | class _OutputRecorderAction: |
| 33 | """Simple `Action` that saves the outputs passed to `__call__`.""" |
| 34 | |
| 35 | def __init__(self): |
| 36 | self.train_output = {} |
| 37 | |
| 38 | def __call__( |
| 39 | self, |
| 40 | output: Optional[Mapping[str, tf.Tensor]] = None) -> Mapping[str, Any]: |
| 41 | self.train_output = {k: v.numpy() for k, v in output.items() |
| 42 | } if output else {} |
| 43 | |
| 44 | |
| 45 | def run_benchmark( |