Toggle DeepCompile runtime state and manage forward hooks accordingly.
(self, active: bool)
| 4915 | self._is_compiled_autograd_enabled = False |
| 4916 | |
def _set_deepcompile_active(self, active: bool) -> None:
    """Switch DeepCompile on or off, dropping or reinstalling forward hooks.

    Turning DeepCompile on removes any registered module forward pre/post
    hook handles and resets the corresponding attributes to ``None``.
    Turning it off recreates whichever of those hooks is currently missing,
    via ``_create_module_forward_pre_hook`` / ``_create_module_forward_post_hook``.
    Requesting the state that is already recorded is a no-op.

    Args:
        active: Desired DeepCompile state.
    """
    if self._deepcompile_active == active:
        # Already in the requested state: leave the hooks exactly as they are.
        return

    removing = bool(active)
    # (hook attribute, factory method) pairs; the pre-hook is handled first.
    hook_slots = (
        ("module_forward_pre_hook", "_create_module_forward_pre_hook"),
        ("module_forward_post_hook", "_create_module_forward_post_hook"),
    )
    for attr_name, factory_name in hook_slots:
        handle = getattr(self, attr_name)
        if removing:
            if handle is not None:
                handle.remove()
                setattr(self, attr_name, None)
        elif handle is None:
            setattr(self, attr_name, getattr(self, factory_name)())

    self._deepcompile_active = active
| 4936 | |
| 4937 | def get_compile_time(self): |
| 4938 | from deepspeed.compile.backend import opt_pass_times |
no test coverage detected