Recompile this GraphModule from its ``graph`` attribute. This should be called after editing the contained ``graph``, otherwise the generated code of this ``GraphModule`` will be out of date.
(self)
| 55 | del globals_copy[func] |
| 56 | |
def recompile(self) -> PythonCode:
    """
    Regenerate this module's Python code from its ``graph`` attribute.

    Call this after mutating the contained ``graph``; until then the
    code held by this ``GraphModule`` does not reflect the edits.

    Returns:
        The ``PythonCode`` object produced by ``self._graph.python_code``.

    Raises:
        IndexError: if no line of the generated source contains
            ``"def forward"``.
    """
    codegen = self._graph._codegen
    if isinstance(codegen, _PyTreeCodeGen):
        # Mirror the pytree specs of the code generator onto the module.
        pytree_info = codegen.pytree_info
        self._in_spec = pytree_info.in_spec
        self._out_spec = pytree_info.out_spec

    python_code = self._graph.python_code(root_module="self")
    self._code = python_code.src

    # The generated source holds the ckpt function definitions first and
    # the forward definition after them; cut it at the first line that
    # mentions "def forward".
    src_lines = self._code.split("\n")
    forward_start = [
        pos for pos, text in enumerate(src_lines) if "def forward" in text
    ][0]
    ckpt_def = src_lines[:forward_start]
    forward_lines = src_lines[forward_start:]
    self._code = "\n".join(forward_lines)

    # NOTE(review): presumably ``bind`` makes the ckpt helpers available
    # to the module -- confirm against its definition.
    self.bind(ckpt_def, python_code.globals)

    cls = type(self)
    cls.forward = _forward_from_src(self._code, python_code.globals)

    # Only remember a __call__ that this class defines itself. When it
    # defines none, pass None so the wrapped call can dispatch through
    # super() dynamically (normally torch.nn.Module.__call__). Holding a
    # direct reference to Module.__call__ would bypass the patching of
    # torch.nn.Module.__call__ performed during symbolic tracing.
    if "__call__" in vars(cls):
        cls_call = cls.__call__
    else:
        cls_call = None

    if "_wrapped_call" not in vars(cls):
        cls._wrapped_call = _WrappedCall(cls, cls_call)  # type: ignore[attr-defined]

    def call_wrapped(self, *args, **kwargs):
        return self._wrapped_call(self, *args, **kwargs)

    cls.__call__ = call_wrapped

    # Restore the complete source (ckpt functions included); otherwise
    # to_folder will be wrong.
    self._code = python_code.src
    return python_code
| 100 | |
| 101 | def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModule"): |
| 102 | """Dumps out module to ``folder`` with ``module_name`` so that it can be |