Trigger a callback every k global steps or every k epochs through its :meth:`trigger()` method. Most existing callbacks that do something every epoch are implemented with the :meth:`trigger()` method. By default the :meth:`trigger()` method is called every epoch. This wrapper can make the callback run at a different frequency.
class PeriodicTrigger(ProxyCallback):
    """
    Trigger a callback every k global steps or every k epochs through its :meth:`trigger()` method.

    Most existing callbacks that do something every epoch are implemented
    with the :meth:`trigger()` method. By default the :meth:`trigger()` method is called every epoch.
    This wrapper can make the callback run at a different frequency.

    All other methods (``before/after_run``, ``trigger_step``, etc.) of the given callback
    are unaffected. They will still be called as-is.
    """

    def __init__(self, triggerable, every_k_steps=None, every_k_epochs=None, before_train=False):
        """
        Args:
            triggerable (Callback): a Callback instance with a trigger method to be called.
            every_k_steps (int): trigger when ``global_step % k == 0``. Set to
                None to ignore.
            every_k_epochs (int): trigger when ``epoch_num % k == 0``. Set to
                None to ignore.
            before_train (bool): trigger in the :meth:`before_train` method.

        every_k_steps and every_k_epochs can both be set, but they cannot both be None
        unless before_train is True.
        """
        assert isinstance(triggerable, Callback), type(triggerable)
        super(PeriodicTrigger, self).__init__(triggerable)
        # Use truthiness here so the check matches the one in _before_train().
        if not before_train:
            assert (every_k_epochs is not None) or (every_k_steps is not None), \
                "Arguments to PeriodicTrigger have disabled the triggerable!"
        self._step_k = every_k_steps
        self._epoch_k = every_k_epochs
        self._do_before_train = before_train

    def _before_train(self):
        # Always forward before_train(); optionally fire the trigger once up front.
        self.cb.before_train()
        if self._do_before_train:
            self.cb.trigger()

    def _trigger_step(self):
        # trigger_step() is always forwarded; trigger() fires only every k steps.
        self.cb.trigger_step()
        if self._step_k is None:
            return
        if self.global_step % self._step_k == 0:
            self.cb.trigger()

    def _trigger_epoch(self):
        # Replace the default every-epoch trigger with an every-k-epochs trigger.
        if self._epoch_k is None:
            return
        if self.epoch_num % self._epoch_k == 0:
            self.cb.trigger()

    def __str__(self):
        return "PeriodicTrigger-" + str(self.cb)
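
The interaction between ``every_k_steps``, ``every_k_epochs`` and ``before_train`` is easier to see with a small simulation. The following is only a sketch, not part of the library: ``RecordingCallback``, ``PeriodicTriggerSketch`` and ``simulate`` are hypothetical stand-ins that copy the modulo checks above without importing the real ``Callback`` or ``ProxyCallback`` classes. It also assumes that the counters live on the wrapped callback, that both start counting from 1, and that they are advanced before the corresponding hook runs. The real trainer may order these differently.

```python
class RecordingCallback:
    """Hypothetical stand-in for a Callback; records when trigger() fires."""

    def __init__(self):
        self.fired = []        # list of (global_step, epoch_num) tuples
        self.global_step = 0   # assumption: counters are stored on the callback
        self.epoch_num = 0

    def before_train(self):
        pass

    def trigger_step(self):
        pass

    def trigger(self):
        self.fired.append((self.global_step, self.epoch_num))


class PeriodicTriggerSketch:
    """Copies the step/epoch checks of PeriodicTrigger in plain Python."""

    def __init__(self, cb, every_k_steps=None, every_k_epochs=None, before_train=False):
        if not before_train:
            assert every_k_steps is not None or every_k_epochs is not None
        self.cb = cb
        self._step_k = every_k_steps
        self._epoch_k = every_k_epochs
        self._do_before_train = before_train

    def before_train(self):
        self.cb.before_train()
        if self._do_before_train:
            self.cb.trigger()

    def trigger_step(self):
        self.cb.trigger_step()
        if self._step_k is not None and self.cb.global_step % self._step_k == 0:
            self.cb.trigger()

    def trigger_epoch(self):
        if self._epoch_k is not None and self.cb.epoch_num % self._epoch_k == 0:
            self.cb.trigger()


def simulate(wrapper, epochs, steps_per_epoch):
    """Hypothetical training loop: each counter advances before its hook is called."""
    cb = wrapper.cb
    wrapper.before_train()
    for epoch in range(1, epochs + 1):
        cb.epoch_num = epoch
        for _ in range(steps_per_epoch):
            cb.global_step += 1
            wrapper.trigger_step()
        wrapper.trigger_epoch()
    return cb.fired


fired = simulate(
    PeriodicTriggerSketch(RecordingCallback(), every_k_steps=4, every_k_epochs=2),
    epochs=2,
    steps_per_epoch=5,
)
print(fired)  # [(4, 1), (8, 2), (10, 2)]
```

With both periods set, the two conditions are checked independently: steps 4 and 8 fire from the step hook, and the end of epoch 2 fires from the epoch hook. If a step boundary and an epoch boundary coincide, the wrapped callback's ``trigger()`` would run twice.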


class EnableCallbackIf(ProxyCallback):