MCPcopy
hub / github.com/hpcaitech/ColossalAI / to_folder

Method to_folder

colossalai/_analyzer/fx/graph_module.py:176–245  ·  view source on GitHub ↗

Dumps out module to ``folder`` with ``module_name`` so that it can be imported with ``from <folder> import <module_name>``. Args: folder (Union[str, os.PathLike]): The folder to write the code out to. module_name (str): Top-level name to use for the ``Module`` while writing out the code.

(self, folder: Union[str, os.PathLike], module_name: str = "FxModule")

Source from the content-addressed store, hash-verified

174 return python_code
175
176 def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModule"):
177 """Dumps out module to ``folder`` with ``module_name`` so that it can be
178 imported with ``from <folder> import <module_name>``
179
180 Args:
181
182 folder (Union[str, os.PathLike]): The folder to write the code out to
183
184 module_name (str): Top-level name to use for the ``Module`` while
185 writing out the code
186 """
187 folder = Path(folder)
188 Path(folder).mkdir(exist_ok=True)
189 torch.save(self.state_dict(), folder / "state_dict.pt")
190 tab = " " * 4
191
192 # we add import colossalai here
193 model_str = f"""
194import torch
195from torch.nn import *
196import colossalai
197
198
199class {module_name}(torch.nn.Module):
200 def __init__(self):
201 super().__init__()
202"""
203
204 def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]:
205 safe_reprs = [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d]
206 if type(module) in safe_reprs:
207 return f"{module.__repr__()}"
208 else:
209 return None
210
211 blobified_modules = []
212 for module_name, module in self.named_children():
213 module_str = _gen_model_repr(module_name, module)
214 if module_str is None:
215 module_file = folder / f"{module_name}.pt"
216 torch.save(module, module_file)
217 blobified_modules.append(module_name)
218 module_repr = module.__repr__().replace("\r", " ").replace("\n", " ")
219 module_str = f"torch.load(r'{module_file}') # {module_repr}"
220 model_str += f"{tab*2}self.{module_name} = {module_str}\n"
221
222 for buffer_name, buffer in self._buffers.items():
223 if buffer is None:
224 continue
225 model_str += f"{tab*2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n"
226
227 for param_name, param in self._parameters.items():
228 if param is None:
229 continue
230 model_str += f"{tab*2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n"
231
232 model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
233 model_str += f"{_addindent(self.code, 4)}\n"

Callers

nothing calls this directly

Calls 6

saveMethod · 0.45
state_dictMethod · 0.45
named_childrenMethod · 0.45
appendMethod · 0.45
replaceMethod · 0.45
__repr__Method · 0.45

Tested by

no test coverage detected