A progress bar based on tqdm. This callback is one of the :func:`DEFAULT_CALLBACKS()`.
| 44 | |
| 45 | |
class ProgressBar(Callback):
    """ A progress bar based on tqdm.

    This callback is one of the :func:`DEFAULT_CALLBACKS()`.
    """

    _chief_only = False

    def __init__(self, names=()):
        """
        Args:
            names(tuple[str]): the names of the tensors to monitor
                on the progress bar. Any iterable of names is accepted;
                it is materialized once, so one-shot iterators such as
                generators work as well.
        """
        super(ProgressBar, self).__init__()
        # Materialize once: ``names`` is consumed by two comprehensions
        # below, and a generator would otherwise leave ``_tags`` empty
        # (which silently hides every value in the postfix).
        names = list(names)
        self._names = [get_op_tensor_name(n)[1] for n in names]
        # Display label: the last "/"-separated component of element [0]
        # of get_op_tensor_name(n) (presumably the op name -- confirm).
        self._tags = [get_op_tensor_name(n)[0].split("/")[-1] for n in names]
        self._bar = None

    def _before_train(self):
        # Step value at which fetches were last requested; _before_run uses
        # it so the tensors are fetched at most once per local step.
        self._last_updated = self.local_step

        self._total = self.trainer.steps_per_epoch
        self._tqdm_args = get_tqdm_kwargs(leave=True)

        # An empty lookup result is normalized to None (nothing to fetch).
        self._fetches = self.get_tensors_maybe_in_tower(self._names) or None
        if self._fetches:
            for t in self._fetches:
                assert t.shape.ndims == 0, "ProgressBar can only print scalars, not {}".format(t)
            self._fetches = tf.train.SessionRunArgs(self._fetches)
            # Append tqdm's {postfix} field so the monitored values appear.
            self._tqdm_args['bar_format'] = self._tqdm_args['bar_format'] + "{postfix} "

    def _before_epoch(self):
        # A fresh bar per epoch, sized to the number of steps in one epoch.
        self._bar = tqdm.trange(self._total, **self._tqdm_args)

    def _after_epoch(self):
        self._bar.close()

    def _before_run(self, _):
        # update progress bar when local step changed (one step is finished)
        if self.local_step != self._last_updated:
            self._last_updated = self.local_step
            return self._fetches
        else:
            return None

    def _after_run(self, _, run_values):
        # ``results`` is None when _before_run requested nothing.
        res = run_values.results
        if res:
            self._bar.set_postfix(zip(self._tags, res))

    def _trigger_step(self):
        self._bar.update()

    def _after_train(self):
        # Test identity, not truthiness: tqdm bars define __len__/__bool__
        # from ``total``, so a bar with zero steps would be falsy and never
        # closed. None means no epoch ever started (e.g. training was killed
        # before the first epoch).
        if self._bar is not None:
            self._bar.close()
no outgoing calls
no test coverage detected