(ckpt_path: str, sampler)
| 223 | |
| 224 | |
def load_sampler(ckpt_path: str, sampler):
    """Restore a sampler's state from ``<ckpt_path>/sampler.pt``.

    Args:
        ckpt_path: Checkpoint directory that contains ``sampler.pt``.
        sampler: Object exposing ``load_state_dict``; it is handed the state
            dict loaded from the checkpoint.

    When ``gpc.is_rank_for_log()`` is true, the loaded state is logged with
    its ``indices`` and ``rng_state`` entries omitted. Afterwards
    ``torch.cuda.empty_cache()`` is called on every rank.
    """
    sampler_states = llm_load(os.path.join(ckpt_path, "sampler.pt"))
    sampler.load_state_dict(sampler_states)
    if gpc.is_rank_for_log():
        # Build a filtered view for logging instead of deep-copying the whole
        # state and popping keys: ``indices`` was never logged, so copying it
        # (presumably a large index list) only to discard it is wasted work.
        # Filtering also tolerates checkpoints missing either key, where
        # ``dict.pop`` without a default would raise KeyError.
        hidden_keys = ("indices", "rng_state")
        pstate = {key: value for key, value in sampler_states.items() if key not in hidden_keys}
        logger.info(f"reload sampler_states:{pstate}")
    torch.cuda.empty_cache()
| 234 | |
| 235 | |
| 236 | def load_context(ckpt_path: str, train_dl, train_state: TrainState): |
no test coverage detected