| 37 | |
| 38 | |
| 39 | class Checkpointer: |
def __init__(self, folder: str, dcp_api: bool):
    """Remember where checkpoints live and resolve the one to resume from.

    Args:
        folder: Root directory under which checkpoints are stored.
        dcp_api: When true, checkpoints are looked up in the ``dcp_api``
            sub-directory of ``folder``; otherwise in ``dtensor_api``.
    """
    self.folder = folder
    self.dcp_api = dcp_api
    if dcp_api:
        api_subdir = "dcp_api"
    else:
        api_subdir = "dtensor_api"
    checkpoint_root = f"{folder}/{api_subdir}"
    # The result is used as a sub-folder name when loading; ``is_empty``
    # treats ``None`` here as "nothing to resume from".
    self.last_training_time = get_latest_checkpoint_folder(checkpoint_root)
| 46 | |
def is_empty(self):
    """Report whether there is no checkpoint to resume from.

    Returns:
        ``True`` if ``last_training_time`` is ``None``, ``False`` otherwise.
    """
    latest = self.last_training_time
    return latest is None
| 49 | |
def load_model(self, model: FSDPModule):
    """Load the latest model checkpoint into a sharded ``model``.

    The checkpoint is read with ``torch.load`` (mmap, weights only, on CPU)
    from ``{folder}/{dcp_api|dtensor_api}/{last_training_time}/{MODEL_CHECKPOINT}``
    as a full (unsharded) state dict.

    * ``dcp_api=True``: the full state dict is handed to
      ``set_model_state_dict`` with ``full_state_dict=True`` and
      ``broadcast_from_rank0=True``.
    * ``dcp_api=False``: every checkpoint tensor is sharded with
      ``distribute_tensor`` using the ``device_mesh`` / ``placements`` of
      the matching entry in ``model.state_dict()`` and then assigned into
      the model (``assign=True``, ``strict=False``).

    Args:
        model: The sharded model to load into. In the manual path, its
            current state-dict entries must expose ``device_mesh`` and
            ``placements`` (i.e. be DTensors) — presumably still on the meta
            device, hence ``assign=True``.

    Raises:
        FileNotFoundError: If no checkpoint was found when this object was
            constructed (``last_training_time`` is ``None``).
        KeyError: (manual path) if the checkpoint contains a key that is
            not present in ``model.state_dict()``.
    """
    api_subdir = "dcp_api" if self.dcp_api else "dtensor_api"
    if self.last_training_time is None:
        # Fail clearly instead of trying to torch.load a ".../None/..." path.
        raise FileNotFoundError(
            f"No checkpoint to load under {self.folder}/{api_subdir}"
        )
    last_model_checkpoint = (
        f"{self.folder}/{api_subdir}"
        f"/{self.last_training_time}/{MODEL_CHECKPOINT}"
    )
    full_sd = torch.load(
        last_model_checkpoint, mmap=True, weights_only=True, map_location="cpu"
    )
    if self.dcp_api:
        set_model_state_dict(
            model=model,
            model_state_dict=full_sd,
            options=StateDictOptions(
                full_state_dict=True,
                broadcast_from_rank0=True,
            ),
        )
        return
    meta_sharded_sd = model.state_dict()
    sharded_sd = {}
    for param_name, full_tensor in full_sd.items():
        try:
            sharded_meta_param = meta_sharded_sd[param_name]
        except KeyError:
            # Without this, a None entry would surface later as an opaque
            # AttributeError on `.device_mesh`.
            raise KeyError(
                f"Checkpoint key {param_name!r} not found in the model's state dict"
            ) from None
        sharded_tensor = distribute_tensor(
            full_tensor,
            sharded_meta_param.device_mesh,
            sharded_meta_param.placements,
        )
        sharded_sd[param_name] = nn.Parameter(sharded_tensor)
    # choose `assign=True` since we cannot call `copy_` on meta tensor
    model.load_state_dict(sharded_sd, strict=False, assign=True)
| 80 | |
| 81 | def load_optim(self, model: FSDPModule, opt: torch.optim.Optimizer): |
| 82 | last_optim_checkpoint = ( |
| 83 | f"{self.folder}/{'dcp_api' if self.dcp_api else 'dtensor_api'}" |
| 84 | f"/{self.last_training_time}/{OPTIM_CHECKPOINT}" |
| 85 | ) |
| 86 | full_sd = torch.load( |
| 87 | last_optim_checkpoint, mmap=True, weights_only=True, map_location="cpu" |
| 88 | ) |
| 89 | if self.dcp_api: |
| 90 | set_optimizer_state_dict( |
| 91 | model=model, |
| 92 | optimizers=opt, |
| 93 | optim_state_dict=full_sd, |
| 94 | options=StateDictOptions( |
| 95 | full_state_dict=True, |
| 96 | broadcast_from_rank0=True, |