Repository: github.com/hpcaitech/ColossalAI

Function get_static_torch_model

colossalai/zero/gemini/utils.py:64–107

Get a static torch.nn.Module model from the given GeminiDDP module. The original GeminiDDP model is not modified, so it can still be used for further training. Do not train the returned torch model, because doing so can cause unexpected errors; it is intended for saving checkpoints or running numeric checks.
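One reason the original model is left untouched is visible in the source: the function works on a shallow copy of the module tree and gives each copied module a fresh `_parameters` dict before filling it. The minimal pure-Python sketch below illustrates that pattern. `ToyModule`, `shallow_copy_tree`, and `make_static` are hypothetical stand-ins (not ColossalAI or PyTorch APIs), lists of floats play the role of tensors, and the sketch assumes the copy shares attributes with the original the way a Python shallow copy does.

```python
from collections import OrderedDict
from copy import copy


class ToyModule:
    """Hypothetical stand-in for torch.nn.Module: direct parameters plus named children."""

    def __init__(self, **params):
        self._parameters = OrderedDict(params)
        self._modules = OrderedDict()

    def add(self, name, child):
        # Register a child module and return self so calls can be chained
        self._modules[name] = child
        return self


def shallow_copy_tree(module):
    """Give every module in the tree a new object; other attributes (e.g. _parameters) stay shared."""
    new = copy(module)
    new._modules = OrderedDict(
        (name, shallow_copy_tree(child)) for name, child in module._modules.items()
    )
    return new


def make_static(module, convert):
    """Copy the tree, then swap in fresh, converted parameter dicts on the copy only."""
    static = shallow_copy_tree(module)

    def replace(orig, dup):
        # Assign a brand-new OrderedDict, so orig._parameters is never mutated
        dup._parameters = OrderedDict(
            (name, convert(value)) for name, value in orig._parameters.items()
        )
        for name, child in orig._modules.items():
            replace(child, dup._modules[name])

    replace(module, static)
    return static


root = ToyModule(weight=[1.0, 2.0]).add("head", ToyModule(bias=[0.5]))
static = make_static(root, lambda v: [x * 10 for x in v])
print(static._parameters["weight"])  # [10.0, 20.0]
print(root._parameters["weight"])    # [1.0, 2.0]
```

Before `replace` runs, the copy's `_parameters` is the very same dict object as the original's, so writing into `dup._parameters[name]` instead of assigning a new dict would have modified `root` as well.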

get_static_torch_model(zero_ddp_model, device=torch.device("cpu"), dtype=torch.float32, only_rank_0=True) -> torch.nn.Module
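`only_rank_0` decides which ranks fill in the returned model. According to the source, every rank still calls `zero_ddp_model.state_dict(only_rank_0=only_rank_0)`, builds the shallow copy, waits at `dist.barrier()`, and returns a module, but only ranks that pass the guard replace its parameters with the state-dict values. The sketch below uses a hypothetical helper `ranks_that_convert` and an assumed world size of 4, and simply evaluates that guard for each rank.

```python
from typing import List


def ranks_that_convert(world_size: int, only_rank_0: bool) -> List[int]:
    """Return the ranks that would run the parameter-replacement loop.

    Mirrors the guard `if not only_rank_0 or dist.get_rank() == 0` in the
    source; this helper and `world_size` are hypothetical, for illustration.
    """
    return [rank for rank in range(world_size) if not only_rank_0 or rank == 0]


print(ranks_that_convert(4, only_rank_0=True))   # [0]
print(ranks_that_convert(4, only_rank_0=False))  # [0, 1, 2, 3]
```

With the default `only_rank_0=True`, only the value returned on rank 0 should be treated as the converted model, which matches the docstring ("only rank0 has the converted torch model").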

Source

```python
def get_static_torch_model(
    zero_ddp_model, device=torch.device("cpu"), dtype=torch.float32, only_rank_0=True
) -> torch.nn.Module:
    """Get a static torch.nn.Module model from the given GeminiDDP module.
    You should notice that the original GeminiDDP model is not modified.
    Thus, you can use the original model in further training.
    But you should not use the returned torch model to train, this can cause unexpected errors.

    Args:
        zero_ddp_model (GeminiDDP): a zero ddp model
        device (torch.device): the device of the final torch model
        dtype (torch.dtype): the dtype of the final torch model
        only_rank_0 (bool): if True, only rank0 has the converted torch model

    Returns:
        torch.nn.Module: a static torch model used for saving checkpoints or numeric checks
    """
    from colossalai.zero.gemini.gemini_ddp import GeminiDDP

    assert isinstance(zero_ddp_model, GeminiDDP)

    state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0)
    colo_model = zero_ddp_model.module
    torch_model = _get_shallow_copy_model(colo_model)

    if not only_rank_0 or dist.get_rank() == 0:
        for (name, colo_module), (_, torch_module) in zip(
            _get_dfs_module_list(colo_model), _get_dfs_module_list(torch_model)
        ):
            # clean the parameter list of the new torch module
            torch_module._parameters = OrderedDict()
            for sufix_param_name, param in colo_module.named_parameters(recurse=False):
                # get the full name of the parameter
                full_param_name = name + ("." if name else "") + sufix_param_name
                assert (
                    full_param_name in state_dict
                ), f"Can not find parameter `{full_param_name}` in the GeminiDDP module"
                state_param = state_dict[full_param_name]
                torch_param = torch.nn.Parameter(state_param.data.to(device=device, dtype=dtype))

                setattr(torch_module, sufix_param_name, torch_param)
    dist.barrier()

    return torch_model
```
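Inside the loop, each parameter's key in the flat state dict is rebuilt by joining the module's dotted path with the local parameter name; the value is then cast and wrapped in a new `torch.nn.Parameter`. The pure-Python sketch below mirrors that lookup with hypothetical names (`full_param_name`, `collect_params`) and plain numbers standing in for tensors; `convert` plays the role of `.to(device=device, dtype=dtype)`. In the real function the (module name, module) pairs come from zipping `_get_dfs_module_list` over the original and the copied model, which assumes both traversals visit modules in the same order.

```python
from collections import OrderedDict
from typing import Any, Callable, Dict, List, Tuple


def full_param_name(module_name: str, suffix: str) -> str:
    """Join a module's dotted path and a local parameter name (the root module's path is "")."""
    return module_name + ("." if module_name else "") + suffix


def collect_params(
    modules: List[Tuple[str, List[str]]],
    state_dict: Dict[str, Any],
    convert: Callable[[Any], Any],
) -> Dict[str, Dict[str, Any]]:
    """Rebuild each module's parameter dict from a flat state dict.

    `modules` is a hypothetical list of (module path, local parameter names);
    `convert` stands in for `.to(device=device, dtype=dtype)`.
    """
    rebuilt = {}
    for module_name, local_names in modules:
        fresh = OrderedDict()  # analogous to `torch_module._parameters = OrderedDict()`
        for suffix in local_names:
            key = full_param_name(module_name, suffix)
            # Fail loudly if the module tree and the state-dict keys disagree
            assert key in state_dict, f"Cannot find parameter `{key}` in the state dict"
            fresh[suffix] = convert(state_dict[key])
        rebuilt[module_name] = fresh
    return rebuilt


modules = [("", ["weight"]), ("encoder.layer0", ["weight", "bias"])]
state = {"weight": 1, "encoder.layer0.weight": 2, "encoder.layer0.bias": 3}
params = collect_params(modules, state, float)
print(params["encoder.layer0"]["bias"])  # 3.0
```

The assertion turns a mismatch between the module tree and the state-dict keys into an immediate, named error instead of a silently incomplete model.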

Callers

nothing calls this directly

Calls (7)

_get_shallow_copy_model · Function · 0.85
_get_dfs_module_list · Function · 0.85
get_rank · Method · 0.80
device · Method · 0.45
state_dict · Method · 0.45
named_parameters · Method · 0.45
to · Method · 0.45

Tested by

no test coverage detected
