| 134 | ) |
| 135 | |
| 136 | def _get_full_model_state_dict(self, model: FSDPModule): |
| 137 | if self.dcp_api: |
| 138 | return get_model_state_dict( |
| 139 | model=model, |
| 140 | options=StateDictOptions( |
| 141 | full_state_dict=True, |
| 142 | cpu_offload=True, |
| 143 | ), |
| 144 | ) |
| 145 | |
| 146 | sharded_sd = model.state_dict() |
| 147 | cpu_state_dict = {} |
| 148 | for param_name, sharded_param in sharded_sd.items(): |
| 149 | full_param = sharded_param.full_tensor() |
| 150 | if torch.distributed.get_rank() == 0: |
| 151 | cpu_state_dict[param_name] = full_param.cpu() |
| 152 | else: |
| 153 | del full_param |
| 154 | return cpu_state_dict |
| 155 | |
| 156 | def _get_full_optimizer_state_dict( |
| 157 | self, |