MCP · copy · Index your code
hub / github.com/pytorch/examples / _get_full_model_state_dict

Method _get_full_model_state_dict

distributed/FSDP2/checkpoint.py:136–154  ·  view source on GitHub ↗
(self, model: FSDPModule)

Source from the content-addressed store, hash-verified

134 )
135
136 def _get_full_model_state_dict(self, model: FSDPModule):
137 if self.dcp_api:
138 return get_model_state_dict(
139 model=model,
140 options=StateDictOptions(
141 full_state_dict=True,
142 cpu_offload=True,
143 ),
144 )
145
146 sharded_sd = model.state_dict()
147 cpu_state_dict = {}
148 for param_name, sharded_param in sharded_sd.items():
149 full_param = sharded_param.full_tensor()
150 if torch.distributed.get_rank() == 0:
151 cpu_state_dict[param_name] = full_param.cpu()
152 else:
153 del full_param
154 return cpu_state_dict
155
156 def _get_full_optimizer_state_dict(
157 self,

Callers 1

save · Method · 0.95

Calls

no outgoing calls

Tested by

no test coverage detected