(ckpt_path: str, train_dl, train_state: TrainState)
| 234 | |
| 235 | |
def load_context(ckpt_path: str, train_dl, train_state: TrainState):
    """Restore ``train_state`` from the ``context.pt`` file inside ``ckpt_path``.

    The saved context is read with ``llm_load``. It is then passed, together
    with ``train_dl``, to ``train_state.load_state_dict``. The restored state is
    logged only on the rank selected for logging, and the CUDA caching
    allocator's unused memory is released at the end.

    Args:
        ckpt_path: Checkpoint directory that contains ``context.pt``.
        train_dl: Training dataloader, forwarded to ``load_state_dict``.
        train_state: Training state object that is restored in place.
    """
    context_file = os.path.join(ckpt_path, "context.pt")
    saved_context = llm_load(context_file)
    train_state.load_state_dict(saved_context, train_dl)

    # Only one rank emits the log line, to avoid duplicated output.
    if gpc.is_rank_for_log():
        logger.info(f"reload train_state:{train_state}")

    # Return cached-but-unused CUDA memory to the device after the load.
    torch.cuda.empty_cache()
| 242 | |
| 243 | |
| 244 | def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, learning_rate, train_state: TrainState): |
no test coverage detected