(self, local_path, hdfs_path=None)
| 627 | |
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def save_checkpoint(self, local_path, hdfs_path=None):
    """Write the critic model and tokenizer to ``local_path`` from rank 0.

    If ``self._is_offload_param`` is set, the FSDP parameters (and gradients,
    when ``self._is_offload_grad`` is set) are loaded onto the current CUDA
    device first and offloaded again at the end. The state dict is requested
    under ``StateDictType.FULL_STATE_DICT`` with ``offload_to_cpu=True`` and
    ``rank0_only=True``; rank 0 then saves it through ``save_pretrained`` of
    the wrapped module, saves the tokenizer next to it and, when requested,
    copies the directory to ``hdfs_path``.

    Args:
        local_path: Directory that rank 0 creates (if missing) and saves into.
        hdfs_path: Optional destination; when not ``None``, rank 0 creates it
            through ``hdfs_io`` and copies ``local_path`` there.
    """
    # NOTE(review): the source this was taken from had lost its indentation;
    # the nesting below (upload inside the rank-0 branch, barrier executed by
    # every rank) is reconstructed - confirm against the original file.
    import torch
    import torch.distributed
    from torch.distributed.fsdp import FullStateDictConfig
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import StateDictType

    if self._is_offload_param:
        load_fsdp_param_and_grad(
            module=self.critic_module,
            device_id=torch.cuda.current_device(),
            load_grad=self._is_offload_grad,
        )

    # TODO: support DCP and save sharded checkpoints
    gather_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(self.critic_module, StateDictType.FULL_STATE_DICT, gather_cfg):
        # Every rank makes this call; it is not guarded by the rank check.
        full_state = self.critic_module.state_dict()

    if self.rank == 0:
        print(f'Saving critic checkpoint to {local_path}')
        os.makedirs(local_path, exist_ok=True)
        wrapped_model = self.critic_module._fsdp_wrapped_module
        wrapped_model.save_pretrained(local_path, state_dict=full_state)
        self.tokenizer.save_pretrained(local_path)
        if hdfs_path is not None:
            print(f'Uploading critic checkpoint to {hdfs_path}')
            hdfs_io.makedirs(hdfs_path, exist_ok=True)
            hdfs_io.copy(src=local_path, dst=hdfs_path)

    # All ranks wait here until rank 0 has finished writing.
    torch.distributed.barrier()

    if self._is_offload_param:
        offload_fsdp_param_and_grad(
            module=self.critic_module,
            offload_grad=self._is_offload_grad,
        )
| 655 | |
| 656 | |
| 657 | class RewardModelWorker(Worker): |
nothing calls this directly
no test coverage detected