Dumps out module to ``folder`` with ``module_name`` so that it can be imported with ``from <folder> import <module_name>``. Args: folder (Union[str, os.PathLike]): The folder to write the code out to. module_name (str): Top-level name to use for the ``Module`` while writing out the code.
(self, folder: Union[str, os.PathLike], module_name: str = "FxModule")
| 174 | return python_code |
| 175 | |
| 176 | def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModule"): |
| 177 | """Dumps out module to ``folder`` with ``module_name`` so that it can be |
| 178 | imported with ``from <folder> import <module_name>`` |
| 179 | |
| 180 | Args: |
| 181 | |
| 182 | folder (Union[str, os.PathLike]): The folder to write the code out to |
| 183 | |
| 184 | module_name (str): Top-level name to use for the ``Module`` while |
| 185 | writing out the code |
| 186 | """ |
| 187 | folder = Path(folder) |
| 188 | Path(folder).mkdir(exist_ok=True) |
| 189 | torch.save(self.state_dict(), folder / "state_dict.pt") |
| 190 | tab = " " * 4 |
| 191 | |
| 192 | # we add import colossalai here |
| 193 | model_str = f""" |
| 194 | import torch |
| 195 | from torch.nn import * |
| 196 | import colossalai |
| 197 | |
| 198 | |
| 199 | class {module_name}(torch.nn.Module): |
| 200 | def __init__(self): |
| 201 | super().__init__() |
| 202 | """ |
| 203 | |
| 204 | def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]: |
| 205 | safe_reprs = [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d] |
| 206 | if type(module) in safe_reprs: |
| 207 | return f"{module.__repr__()}" |
| 208 | else: |
| 209 | return None |
| 210 | |
| 211 | blobified_modules = [] |
| 212 | for module_name, module in self.named_children(): |
| 213 | module_str = _gen_model_repr(module_name, module) |
| 214 | if module_str is None: |
| 215 | module_file = folder / f"{module_name}.pt" |
| 216 | torch.save(module, module_file) |
| 217 | blobified_modules.append(module_name) |
| 218 | module_repr = module.__repr__().replace("\r", " ").replace("\n", " ") |
| 219 | module_str = f"torch.load(r'{module_file}') # {module_repr}" |
| 220 | model_str += f"{tab*2}self.{module_name} = {module_str}\n" |
| 221 | |
| 222 | for buffer_name, buffer in self._buffers.items(): |
| 223 | if buffer is None: |
| 224 | continue |
| 225 | model_str += f"{tab*2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n" |
| 226 | |
| 227 | for param_name, param in self._parameters.items(): |
| 228 | if param is None: |
| 229 | continue |
| 230 | model_str += f"{tab*2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n" |
| 231 | |
| 232 | model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n" |
| 233 | model_str += f"{_addindent(self.code, 4)}\n" |
nothing calls this directly
no test coverage detected