Get a static torch.nn.Module model from the given GeminiDDP module. Note that the original GeminiDDP model is not modified, so you can keep using it in further training. However, you should not train with the returned torch model; doing so can cause unexpected errors.
(
zero_ddp_model, device=torch.device("cpu"), dtype=torch.float32, only_rank_0=True
)
| 62 | |
| 63 | |
def get_static_torch_model(
    zero_ddp_model, device=torch.device("cpu"), dtype=torch.float32, only_rank_0=True
) -> torch.nn.Module:
    """Build a static ``torch.nn.Module`` from a GeminiDDP-wrapped model.

    The GeminiDDP model passed in is left unmodified, so it can still be used
    for further training. The returned torch model is meant for saving
    checkpoints or numeric checks only; training with it can cause unexpected
    errors.

    Args:
        zero_ddp_model (GeminiDDP): a zero ddp model
        device (torch.device): the device of the final torch model
        dtype (torch.dtype): the dtype of the final torch model
        only_rank_0 (bool): if True, only rank 0 has the converted torch model;
            on the other ranks the shallow copy is returned without having its
            parameters replaced

    Returns:
        torch.nn.Module: a static torch model used for saving checkpoints or numeric checks

    Raises:
        AssertionError: if ``zero_ddp_model`` is not a ``GeminiDDP`` instance, or
            if a parameter of the wrapped module is missing from the state dict.
    """
    from colossalai.zero.gemini.gemini_ddp import GeminiDDP

    assert isinstance(zero_ddp_model, GeminiDDP)

    # The state dict is requested on every rank, before the rank check below.
    gathered_state = zero_ddp_model.state_dict(only_rank_0=only_rank_0)
    wrapped_model = zero_ddp_model.module
    torch_model = _get_shallow_copy_model(wrapped_model)

    fills_parameters = not only_rank_0 or dist.get_rank() == 0
    if fills_parameters:
        # Walk both module trees in lockstep; the name comes from the wrapped
        # model's traversal, the copy's own names are ignored.
        source_modules = _get_dfs_module_list(wrapped_model)
        target_modules = _get_dfs_module_list(torch_model)
        for (module_name, source_module), (_, target_module) in zip(source_modules, target_modules):
            # Start from an empty parameter table on the copied module.
            target_module._parameters = OrderedDict()
            for local_name, _param in source_module.named_parameters(recurse=False):
                # Qualify the parameter name with its module path (no dot for the root module).
                separator = "." if module_name else ""
                full_param_name = module_name + separator + local_name
                assert (
                    full_param_name in state_dict_contains(gathered_state)
                ), f"Can not find parameter `{full_param_name}` in the GeminiDDP module"
                converted = gathered_state[full_param_name].data.to(device=device, dtype=dtype)
                setattr(target_module, local_name, torch.nn.Parameter(converted))

    # Reached by all ranks, including those that skipped the conversion above.
    dist.barrier()

    return torch_model


def state_dict_contains(state):
    """Return ``state`` unchanged; exists only to keep the membership test readable."""
    return state
nothing calls this directly
no test coverage detected
searching dependent graphs…