MCPcopy
hub / github.com/hpcaitech/ColossalAI / recompile

Method recompile

colossalai/fx/graph_module.py:57–99  ·  view source on GitHub ↗

Recompile this GraphModule from its ``graph`` attribute. This should be called after editing the contained ``graph``, otherwise the generated code of this ``GraphModule`` will be out of date.

(self)

Source from the content-addressed store, hash-verified

55 del globals_copy[func]
56
def recompile(self) -> PythonCode:
    """
    Recompile this GraphModule from its ``graph`` attribute. This should be
    called after editing the contained ``graph``, otherwise the generated
    code of this ``GraphModule`` will be out of date.

    The generated source may contain checkpoint ("ckpt") function
    definitions ahead of ``forward``. Every line before the first line
    containing ``"def forward"`` is passed to ``self.bind`` together with the
    generated globals. The remaining lines are compiled into ``forward`` and
    installed on this module's class.

    Returns:
        PythonCode: the object returned by ``self._graph.python_code``. On
        return, ``self._code`` holds its full ``src`` (checkpoint functions
        included).

    Raises:
        RuntimeError: if the generated source contains no line with
            ``"def forward"``, so the checkpoint functions cannot be
            separated from the forward definition.
    """
    if isinstance(self._graph._codegen, _PyTreeCodeGen):
        self._in_spec = self._graph._codegen.pytree_info.in_spec
        self._out_spec = self._graph._codegen.pytree_info.out_spec
    python_code = self._graph.python_code(root_module="self")
    self._code = python_code.src

    # Split the generated source into the ckpt function definitions that
    # precede ``forward`` and the ``forward`` definition itself. A single
    # enumerate pass yields the index directly. It does not collect every
    # match and then search the list a second time with ``list.index``.
    _code_list = self._code.split("\n")
    _fwd_idx = next(
        (idx for idx, line in enumerate(_code_list) if "def forward" in line),
        None,
    )
    if _fwd_idx is None:
        raise RuntimeError(
            "recompile: generated code has no 'def forward' definition; "
            "cannot separate checkpoint functions from forward"
        )
    ckpt_def = _code_list[:_fwd_idx]
    self._code = "\n".join(_code_list[_fwd_idx:])

    # Hand the ckpt function source lines to ``bind`` along with the
    # generated globals. ``forward`` below is compiled from the same globals.
    self.bind(ckpt_def, python_code.globals)

    cls = type(self)
    cls.forward = _forward_from_src(self._code, python_code.globals)

    # Determine whether this class explicitly defines a __call__ implementation
    # to wrap. If it does, save it in order to have wrapped_call invoke it.
    # If it does not, wrapped_call can use a dynamic call to super() instead.
    # In most cases, super().__call__ should be torch.nn.Module.__call__.
    # We do not want to hold a reference to Module.__call__ here; doing so will
    # bypass patching of torch.nn.Module.__call__ done while symbolic tracing.
    cls_call = cls.__call__ if "__call__" in vars(cls) else None

    # Only create the wrapper once per class; later recompiles reuse it.
    if "_wrapped_call" not in vars(cls):
        cls._wrapped_call = _WrappedCall(cls, cls_call)  # type: ignore[attr-defined]

    def call_wrapped(self, *args, **kwargs):
        return self._wrapped_call(self, *args, **kwargs)

    cls.__call__ = call_wrapped

    # reset self._code to original src, otherwise to_folder will be wrong
    self._code = python_code.src
    return python_code
100
101 def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModule"):
102 """Dumps out module to ``folder`` with ``module_name`` so that it can be

Callers 15

initialize_model (Function) · 0.95
test_linear_module (Function) · 0.95
test_conv_module (Function) · 0.95
_run_act_ckpt_codegen (Function) · 0.95
_run_act_ckpt_codegen (Function) · 0.95
_run_offload_codegen (Function) · 0.95
assert_codegen_run (Function) · 0.95
assert_codegen_run (Function) · 0.95

Calls 4

bind (Method) · 0.95
_WrappedCall (Class) · 0.90
split (Method) · 0.80
index (Method) · 0.80

Tested by 15

test_linear_module (Function) · 0.76
test_conv_module (Function) · 0.76
_run_act_ckpt_codegen (Function) · 0.76
_run_act_ckpt_codegen (Function) · 0.76
_run_offload_codegen (Function) · 0.76
assert_codegen_run (Function) · 0.76
assert_codegen_run (Function) · 0.76
assert_codegen_run (Function) · 0.76
assert_codegen_run (Function) · 0.76