MCPcopy Index your code
hub / github.com/tensorpack/tensorpack / ProgressBar

Class ProgressBar

tensorpack/callbacks/steps.py:46–102  ·  view source on GitHub ↗

A progress bar based on tqdm. This callback is one of the :func:`DEFAULT_CALLBACKS()`.

Source from the content-addressed store, hash-verified

44
45
class ProgressBar(Callback):
    """ A progress bar based on tqdm.

    This callback is one of the :func:`DEFAULT_CALLBACKS()`.

    A new bar is opened at the start of every epoch, advanced once per
    trigger step, and closed at the end of the epoch. Optionally, the values
    of some scalar tensors are shown as the bar's postfix.
    """

    # NOTE(review): presumably tells the base Callback that this callback
    # should not be restricted to the chief worker -- confirm against Callback.
    _chief_only = False

    def __init__(self, names=()):
        """
        Args:
            names(tuple[str]): the names of the tensors to monitor
                on the progress bar.
        """
        super(ProgressBar, self).__init__()
        # Element [1] of the helper's result is what gets fetched; the last
        # "/"-separated component of element [0] is the label shown on the bar.
        self._names = [get_op_tensor_name(n)[1] for n in names]
        self._tags = [get_op_tensor_name(n)[0].split("/")[-1] for n in names]
        # No bar exists until the first epoch starts.
        self._bar = None

    def _before_train(self):
        """Cache the epoch length and tqdm options, and resolve the fetches.

        Raises:
            AssertionError: if a monitored tensor is not a scalar.
        """
        self._last_updated = self.local_step

        self._total = self.trainer.steps_per_epoch
        self._tqdm_args = get_tqdm_kwargs(leave=True)

        # An empty result is normalized to None so that `_before_run`
        # requests no extra fetches when nothing is monitored.
        self._fetches = self.get_tensors_maybe_in_tower(self._names) or None
        if self._fetches:
            for t in self._fetches:
                assert t.shape.ndims == 0, "ProgressBar can only print scalars, not {}".format(t)
            self._fetches = tf.train.SessionRunArgs(self._fetches)
            # Make room in the bar layout for the monitored values.
            self._tqdm_args['bar_format'] = self._tqdm_args['bar_format'] + "{postfix} "

    def _before_epoch(self):
        """Open a fresh progress bar covering one epoch."""
        self._bar = tqdm.trange(self._total, **self._tqdm_args)

    def _after_epoch(self):
        """Close the bar of the epoch that just finished."""
        self._bar.close()

    def _before_run(self, _):
        """Request the monitored tensors only on the first run of a new step.

        Returns:
            The prepared fetches if the local step changed since the last
            request, otherwise None.
        """
        # update progress bar when local step changed (one step is finished)
        if self.local_step == self._last_updated:
            return None
        self._last_updated = self.local_step
        return self._fetches

    def _after_run(self, _, run_values):
        """Show the fetched values, if any, as the bar's postfix."""
        res = run_values.results
        if res:
            self._bar.set_postfix(zip(self._tags, res))

    def _trigger_step(self):
        """Advance the bar by one step."""
        self._bar.update()

    def _after_train(self):
        """Close the bar if one was ever created."""
        # training may get killed before the first step, in which case no bar
        # exists. Compare against None explicitly: the truthiness of a tqdm
        # object is derived from its total/length, not from its existence.
        if self._bar is not None:
            self._bar.close()
103

Callers 2

DEFAULT_CALLBACKSFunction · 0.85
get_configFunction · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected