MCPcopy Index your code
hub / github.com/pytorch/examples / Checkpointer

Class Checkpointer

distributed/FSDP2/checkpoint.py:39–209  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

37
38
39class Checkpointer:
def __init__(self, folder: str, dcp_api: bool):
    """Remember where checkpoints live and which API flavour to use.

    Args:
        folder: root directory under which checkpoints are stored.
        dcp_api: when True, checkpoints are looked up under the
            ``dcp_api`` subfolder; otherwise under ``dtensor_api``.

    ``last_training_time`` is whatever ``get_latest_checkpoint_folder``
    returns for that subfolder. The load methods use it as the name of the
    sub-directory to read from. ``None`` presumably means no checkpoint
    exists yet — TODO confirm against ``get_latest_checkpoint_folder``.
    """
    api_subdir = "dcp_api" if dcp_api else "dtensor_api"
    self.folder = folder
    self.dcp_api = dcp_api
    self.last_training_time = get_latest_checkpoint_folder(f"{folder}/{api_subdir}")
46
47 def is_empty(self):
48 return self.last_training_time is None
49
def load_model(self, model: FSDPModule):
    """Load the most recent model checkpoint into ``model`` in place.

    The file read is
    ``<folder>/<dcp_api|dtensor_api>/<last_training_time>/<MODEL_CHECKPOINT>``,
    loaded with ``torch.load`` (memory-mapped, ``weights_only=True``, on CPU).

    ``self.dcp_api`` selects one of two loading paths:

    * DCP API: the full state dict is passed to ``set_model_state_dict``
      with ``full_state_dict=True`` and ``broadcast_from_rank0=True``.
    * DTensor API: each full tensor is sharded with ``distribute_tensor``,
      using the device mesh and placements of the model's current parameter
      of the same name. The results are loaded with
      ``load_state_dict(..., strict=False, assign=True)``.

    Args:
        model: the FSDP-wrapped module to load weights into.

    Raises:
        KeyError: (DTensor path only) if the checkpoint contains a parameter
            name that is not present in ``model.state_dict()``.
    """
    subfolder = "dcp_api" if self.dcp_api else "dtensor_api"
    last_model_checkpoint = (
        f"{self.folder}/{subfolder}"
        f"/{self.last_training_time}/{MODEL_CHECKPOINT}"
    )
    full_sd = torch.load(
        last_model_checkpoint, mmap=True, weights_only=True, map_location="cpu"
    )
    if self.dcp_api:
        set_model_state_dict(
            model=model,
            model_state_dict=full_sd,
            options=StateDictOptions(
                full_state_dict=True,
                broadcast_from_rank0=True,
            ),
        )
        return

    # The model's current (sharded, possibly meta-device) state dict supplies
    # the device mesh and placements for each parameter.
    meta_sharded_sd = model.state_dict()
    sharded_sd = {}
    for param_name, full_tensor in full_sd.items():
        # Index directly instead of `.get()`. A key missing from the model
        # would otherwise become None and fail later with an unhelpful
        # AttributeError on `.device_mesh`.
        try:
            sharded_meta_param = meta_sharded_sd[param_name]
        except KeyError:
            raise KeyError(
                f"checkpoint parameter {param_name!r} not found in model.state_dict()"
            ) from None
        sharded_tensor = distribute_tensor(
            full_tensor,
            sharded_meta_param.device_mesh,
            sharded_meta_param.placements,
        )
        sharded_sd[param_name] = nn.Parameter(sharded_tensor)
    # choose `assign=True` since we cannot call `copy_` on meta tensor
    model.load_state_dict(sharded_sd, strict=False, assign=True)
80
81 def load_optim(self, model: FSDPModule, opt: torch.optim.Optimizer):
82 last_optim_checkpoint = (
83 f"{self.folder}/{'dcp_api' if self.dcp_api else 'dtensor_api'}"
84 f"/{self.last_training_time}/{OPTIM_CHECKPOINT}"
85 )
86 full_sd = torch.load(
87 last_optim_checkpoint, mmap=True, weights_only=True, map_location="cpu"
88 )
89 if self.dcp_api:
90 set_optimizer_state_dict(
91 model=model,
92 optimizers=opt,
93 optim_state_dict=full_sd,
94 options=StateDictOptions(
95 full_state_dict=True,
96 broadcast_from_rank0=True,

Callers 1

mainFunction · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected