Repository: github.com/allenai/open-instruct

Function _load_checkpoint

Defined in open_instruct/model_utils.py, lines 250–258.
Signature: _load_checkpoint(path: str, dev: torch.device)


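import torch  # module-level import in model_utils.py

# Nested helper: `model`, `rank`, and `logger` are not parameters; they are
# free variables resolved from the enclosing scope. Right after this
# definition, the surrounding code branches on `throw_on_error`.
def _load_checkpoint(path: str, dev: torch.device):
    # Deserialize the state dict, remapping every tensor onto `dev`.
    state_dict = torch.load(path, map_location=dev)
    if hasattr(model, "module"):
        # Needed if wrapped by DeepSpeed: the engine exposes the real model as `.module`.
        model.module.load_state_dict(state_dict)
    else:
        # If a vanilla HF model.
        model.load_state_dict(state_dict)
    logger.info(f"{rank=}: Loaded checkpoint from {path}")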
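The unwrap-then-load pattern above can be exercised in isolation. The sketch below is a hypothetical standalone variant, not code from open-instruct: `load_checkpoint_into`, `ModuleWrapper`, and `demo` are illustrative names, and `ModuleWrapper` only mimics the one property of a DeepSpeed engine that matters here, namely that it exposes the wrapped model as `.module`. Assuming the checkpoint was saved from the unwrapped model, it also shows why the unwrap matters: the wrapper's own state dict keys carry a `module.` prefix, so the checkpoint does not load into the wrapper directly under the default strict key matching.

```python
import os
import tempfile

import torch
from torch import nn


class ModuleWrapper(nn.Module):
    """Hypothetical stand-in for a DeepSpeed engine: keeps the real model in `.module`."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)


def load_checkpoint_into(model: nn.Module, path: str, dev: torch.device) -> nn.Module:
    """Hypothetical variant of _load_checkpoint that takes `model` as a parameter.

    Loads an unprefixed state dict, unwrapping `.module` first if present.
    """
    state_dict = torch.load(path, map_location=dev)  # remap tensors onto `dev`
    target = model.module if hasattr(model, "module") else model
    target.load_state_dict(state_dict)  # strict=True by default: keys must match exactly
    return target


def demo() -> dict:
    torch.manual_seed(0)
    source = nn.Linear(4, 2)
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "ckpt.pt")
        torch.save(source.state_dict(), path)  # saved keys: "weight", "bias"

        wrapped = ModuleWrapper(nn.Linear(4, 2))
        # The wrapper's own keys are prefixed with "module." ...
        wrapper_keys = sorted(wrapped.state_dict().keys())
        # ... so loading the unprefixed checkpoint into the wrapper itself fails.
        try:
            wrapped.load_state_dict(torch.load(path, map_location="cpu"))
            direct_load_ok = True
        except RuntimeError:
            direct_load_ok = False

        # Unwrapping first loads cleanly into the inner model.
        inner = load_checkpoint_into(wrapped, path, torch.device("cpu"))
        weights_match = torch.equal(inner.weight, source.weight)
    return {
        "wrapper_keys": wrapper_keys,
        "direct_load_ok": direct_load_ok,
        "weights_match": weights_match,
    }


if __name__ == "__main__":
    print(demo())
```

Passing `model` explicitly instead of closing over it, as this sketch does, lets the helper be called and tested on its own; the original nested form instead relies on whatever `model`, `rank`, and `logger` the enclosing function has in scope.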

Callers (1)

maybe_load_checkpoint · Function · 0.85

Calls (2)

load · Method · 0.80
load_state_dict · Method · 0.45

Tested by

no test coverage detected