(path: str, dev: torch.device)
| 248 | """ |
| 249 | |
| 250 | def _load_checkpoint(path: str, dev: torch.device): |
|  | # Load the state dict stored at `path` (tensors mapped onto `dev`) into `model`; `model`, `logger` and `rank` are not defined in this function. |
|  | # NOTE(review): torch.load unpickles the file; if checkpoints can come from untrusted sources, consider passing weights_only=True. |
| 251 | state_dict = torch.load(path, map_location=dev) |
| 252 | if hasattr(model, "module"): |
| 253 | # Wrapped model (e.g. a DeepSpeed engine): load into the inner `.module`. |
| 254 | model.module.load_state_dict(state_dict) |
| 255 | else: |
| 256 | # Unwrapped model (e.g. a plain HF model): load the state dict directly. |
| 257 | model.load_state_dict(state_dict) |
| 258 | logger.info(f"{rank=}: Loaded checkpoint from {path}") |
| 259 | |
| 260 | if not throw_on_error: |
| 261 | try: |
no test coverage detected