Swap out the graph depending on the selected training mode. NOTE: this is also safe to reach via model.eval(), because that simply calls this method with mode=False.
(self, mode=True)
| 318 | self._tracer_cls = self.graph._tracer_cls |
| 319 | |
def train(self, mode=True):
    """
    Swap out the graph depending on the selected training mode.

    The parent ``train`` implementation is called first. The graph is swapped
    only after it returns successfully, so an exception raised by the parent
    leaves ``self.graph`` untouched instead of out of sync with
    ``self.training``. (Recent PyTorch ``nn.Module.train`` raises
    ``ValueError`` for a non-bool ``mode`` -- confirm for the parent in use.)

    NOTE: this is safe to reach via ``model.eval()``, because that simply
    calls this method with ``mode=False``.

    Args:
        mode: truthy to select the training graph, falsy to select the
            evaluation graph.

    Returns:
        Whatever the parent ``train`` returns.
    """
    # Capture the current mode before the parent call, which presumably
    # updates ``self.training`` to ``mode`` (nn.Module behaviour -- confirm).
    was_training = self.training
    result = super().train(mode=mode)
    # NOTE: Only set self.graph if the current graph is not the desired
    # one. This saves us from recompiling the graph where not necessary.
    if mode and not was_training:
        self.graph = self.train_graph
    elif not mode and was_training:
        self.graph = self.eval_graph
    return result
| 333 | |
| 334 | def _deepcopy_init(self): |
| 335 | # See __deepcopy__ below |
no outgoing calls