MCP · copy
hub / github.com/PRIME-RL/PRIME / save_checkpoint

Method save_checkpoint

training/verl/workers/fsdp_workers.py:629–654  ·  view source on GitHub ↗
(self, local_path, hdfs_path=None)

Source retrieved from the content-addressed store (hash-verified)

627
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def save_checkpoint(self, local_path, hdfs_path=None):
    """Gather the full critic weights and write a checkpoint from rank 0.

    Every rank must enter this method: the full-state-dict gather inside the
    ``FSDP.state_dict_type`` context and the trailing ``barrier`` are
    collective operations. Only rank 0 touches the filesystem.

    Args:
        local_path: Local directory that receives the model files (via
            ``save_pretrained`` on the FSDP-wrapped module) and the tokenizer
            files. Created if missing.
        hdfs_path: Optional remote destination. When not ``None``, rank 0
            creates it with ``hdfs_io.makedirs`` and copies ``local_path``
            there with ``hdfs_io.copy`` after the local save.
    """
    import torch
    import torch.distributed
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType, FullStateDictConfig

    module = self.critic_module

    # If parameters are offloaded, move them (and grads, when grad offload is
    # also enabled) onto the current CUDA device before gathering.
    if self._is_offload_param:
        load_fsdp_param_and_grad(module=module,
                                 device_id=torch.cuda.current_device(),
                                 load_grad=self._is_offload_grad)

    # TODO: support DCP and save sharded checkpoints
    # Unsharded state dict, moved to CPU; with rank0_only=True the full
    # dict is only meant to be materialised on rank 0.
    gather_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(module, StateDictType.FULL_STATE_DICT, gather_cfg):
        full_state = module.state_dict()

    is_rank0 = self.rank == 0
    if is_rank0:
        print(f'Saving critic checkpoint to {local_path}')
        os.makedirs(local_path, exist_ok=True)
        # Pass the gathered state dict explicitly to the module FSDP wraps.
        inner_model = module._fsdp_wrapped_module
        inner_model.save_pretrained(local_path, state_dict=full_state)
        self.tokenizer.save_pretrained(local_path)

        should_upload = hdfs_path is not None
        if should_upload:
            print(f'Uploading critic checkpoint to {hdfs_path}')
            hdfs_io.makedirs(hdfs_path, exist_ok=True)
            hdfs_io.copy(src=local_path, dst=hdfs_path)

    # Hold every rank here until rank 0 has finished writing/uploading.
    torch.distributed.barrier()

    # Restore the offloaded state that was in effect before this call.
    if self._is_offload_param:
        offload_fsdp_param_and_grad(module=module, offload_grad=self._is_offload_grad)
655
656
657class RewardModelWorker(Worker):

Callers

nothing calls this directly

Calls 2

load_fsdp_param_and_grad · Function · 0.90

Tested by

no test coverage detected