Compile the module using the specified backend and kwargs. When DeepCompile is enabled and a DeepSpeed compile backend can be resolved, that backend is used in place of the given `backend`.
(self,
backend=get_accelerator().get_compile_backend(),
compile_kwargs={},
schedule=None,
compiled_autograd_enabled=False)
| 4863 | return resolved_backend, schedule |
| 4864 | |
def compile(self,
            backend=get_accelerator().get_compile_backend(),
            compile_kwargs={},
            schedule=None,
            compiled_autograd_enabled=False) -> None:
    """Compile the wrapped module in place via ``self.module.compile``.

    When DeepCompile is enabled, a DeepSpeed compile backend is resolved via
    ``get_deepspeed_compile_backend``; if one is returned it replaces
    ``backend``, otherwise ``backend`` is used unchanged. Calling this method
    on an engine that is already compiled is a no-op.

    Args:
        backend: Backend passed to ``module.compile``. Note that the default
            is evaluated once, when this method is defined, not per call.
        compile_kwargs: Extra keyword arguments for ``module.compile``. The
            dict is copied, so neither the caller's dict nor the shared
            default is aliased by ``self._compile_kwargs``. A ``'backend'``
            key is overridden by ``backend`` (a warning is logged). ``None``
            is accepted and treated as an empty dict.
        schedule: Forwarded to ``get_deepspeed_compile_backend`` when
            DeepCompile is enabled; otherwise unused here.
        compiled_autograd_enabled: When true, enables compiled autograd
            unless DeepCompile ends up active, in which case a warning is
            logged and compiled autograd stays disabled.

    Raises:
        RuntimeError: If the installed PyTorch does not support compile.
        Exception: Any error raised by ``module.compile`` is re-raised after
            the DeepCompile hook state has been rolled back.
    """
    if not is_compile_supported():
        raise RuntimeError("compile is not supported in your version of PyTorch.")

    if self.is_compiled:
        return

    # Avoid graph breaks. Done only once we know compilation will proceed, so a
    # rejected or redundant call does not flip this global switch.
    deepspeed.utils.nvtx.enable_nvtx = False

    # Private copy: the default ``{}`` is shared across calls and the dict is
    # stored on ``self._compile_kwargs`` below, so keeping a reference to it
    # could leak later mutations into the default or back to the caller.
    compile_kwargs = dict(compile_kwargs or {})

    if 'backend' in compile_kwargs:
        logger.warning("The `backend` in `compile_kwargs` will be overridden. Use the `backend` argument instead.")

    logger.info(f"Compiling deepcompile={self.is_deepcompile_enabled()} backend={backend}")

    resolved_backend = None
    if self.is_deepcompile_enabled():
        resolved_backend, schedule = self.get_deepspeed_compile_backend(backend, compile_kwargs, schedule)

    is_deepspeed_compile_backend = resolved_backend is not None

    # Default to the requested backend when no DeepSpeed backend was resolved
    # (e.g. DeepCompile disabled or its config validation failed). Uses the
    # same ``is not None`` test as the flag above so the two never disagree.
    if is_deepspeed_compile_backend:
        backend = resolved_backend

    # Hook state must align with whether DeepCompile is active.
    self._set_deepcompile_active(is_deepspeed_compile_backend)

    try:
        # Build a fresh kwargs dict so the explicit backend wins without
        # mutating compile_kwargs.
        self.module.compile(**{**compile_kwargs, 'backend': backend})
    except Exception:
        if is_deepspeed_compile_backend:
            # Restore default hooks if compilation fails before completing.
            self._set_deepcompile_active(False)
        raise

    self._is_compiled = True
    self._compile_kwargs = compile_kwargs
    if compiled_autograd_enabled:
        if not self._deepcompile_active:
            self._is_compiled_autograd_enabled = compiled_autograd_enabled
        else:
            logger.warning("Compiled autograd is not compatible with DeepCompile, disabling compiled autograd.")
            self._is_compiled_autograd_enabled = False
| 4917 | def _set_deepcompile_active(self, active: bool) -> None: |
| 4918 | """Toggle DeepCompile runtime state and manage forward hooks accordingly.""" |