Function load_sampler

internlm/utils/model_checkpoint.py:225–233 (github.com/InternLM/InternLM)
(ckpt_path: str, sampler)

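```python
import copy
import os

import torch

# Imports as used in internlm/utils/model_checkpoint.py; exact paths may vary by version.
from internlm.core.context import global_context as gpc
from internlm.utils.logger import get_logger
from internlm.utils.storage_manager import llm_load

logger = get_logger(__file__)


def load_sampler(ckpt_path: str, sampler):
    # Read the sampler state that was saved alongside the checkpoint.
    sampler_states = llm_load(os.path.join(ckpt_path, "sampler.pt"))
    # Restore the sampler's saved state (e.g. its position in the data order).
    sampler.load_state_dict(sampler_states)
    if gpc.is_rank_for_log():
        # Log a copy without the bulky "indices" and "rng_state" entries,
        # leaving the dict handed to the sampler untouched.
        pstate = copy.deepcopy(sampler_states)
        pstate.pop("indices")
        pstate.pop("rng_state")
        logger.info(f"reload sampler_states:{pstate}")
    # Return unused cached GPU memory held by PyTorch's caching allocator.
    torch.cuda.empty_cache()
```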
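To see why both the sampler's position and its RNG state are restored on resume, here is a torch-free sketch of the same save/load round trip. `ToySampler`, `save_sampler_state`, and `load_sampler_state` are hypothetical stand-ins: pickle replaces `llm_load`, a boolean flag replaces `gpc.is_rank_for_log()`, and only the two popped key names ("indices", "rng_state") come from the function above. InternLM's actual sampler state may contain other fields.

```python
import logging
import os
import pickle
import random
import tempfile

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("sampler_demo")


class ToySampler:
    """Hypothetical sampler; its state uses the same bulky keys load_sampler pops."""

    def __init__(self, num_samples: int, seed: int):
        self.rng = random.Random(seed)
        self.indices = list(range(num_samples))
        self.consumed = 0
        self.epoch = 0
        self.rng.shuffle(self.indices)

    def next_index(self) -> int:
        idx = self.indices[self.consumed]
        self.consumed += 1
        return idx

    def start_new_epoch(self) -> None:
        # Reshuffling draws from self.rng, so resuming must restore rng_state too.
        self.rng.shuffle(self.indices)
        self.consumed = 0
        self.epoch += 1

    def state_dict(self) -> dict:
        return {
            "epoch": self.epoch,
            "consumed": self.consumed,
            "indices": list(self.indices),
            "rng_state": self.rng.getstate(),
        }

    def load_state_dict(self, state: dict) -> None:
        self.epoch = state["epoch"]
        self.consumed = state["consumed"]
        self.indices = list(state["indices"])
        self.rng.setstate(state["rng_state"])


def save_sampler_state(ckpt_path: str, sampler: ToySampler) -> None:
    with open(os.path.join(ckpt_path, "sampler.pt"), "wb") as f:
        pickle.dump(sampler.state_dict(), f)


def load_sampler_state(ckpt_path: str, sampler: ToySampler, is_log_rank: bool = True) -> None:
    with open(os.path.join(ckpt_path, "sampler.pt"), "rb") as f:
        states = pickle.load(f)
    sampler.load_state_dict(states)
    if is_log_rank:
        # Same effect as deepcopy + pop, without copying the bulky entries first.
        summary = {k: v for k, v in states.items() if k not in ("indices", "rng_state")}
        logger.info("reload sampler_states:%s", summary)


with tempfile.TemporaryDirectory() as ckpt_dir:
    original = ToySampler(num_samples=10, seed=42)
    seen = [original.next_index() for _ in range(4)]
    save_sampler_state(ckpt_dir, original)

    # Continue the original run to get the reference order.
    rest_original = [original.next_index() for _ in range(6)]
    original.start_new_epoch()
    next_epoch_original = list(original.indices)

    # A fresh sampler built with a different seed, then restored from the file.
    resumed = ToySampler(num_samples=10, seed=0)
    load_sampler_state(ckpt_dir, resumed)
    rest_resumed = [resumed.next_index() for _ in range(6)]
    resumed.start_new_epoch()
    next_epoch_resumed = list(resumed.indices)
```

In this toy, dropping `rng_state` from the restored state would still replay the rest of the current epoch (the saved `indices` fix that order) but would generally diverge at the next reshuffle, which is why it is saved even though it is too noisy to log.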

Callers (1)

try_resume_training (method) · 0.85

Calls (3)

llm_load (function) · 0.90
is_rank_for_log (method) · 0.80
load_state_dict (method) · 0.45

Tested by

no test coverage detected
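One way to cover this logic without GPUs or a real checkpoint is to inject its dependencies and use `unittest.mock`. The function below is a hypothetical dependency-injected replica, not InternLM's code: `llm_load`, `is_rank_for_log`, and `logger` become parameters, and `torch.cuda.empty_cache()` is omitted. Testing the real function would instead mean patching those names inside `internlm.utils.model_checkpoint`.

```python
import copy
import os
from unittest import mock


def load_sampler(ckpt_path, sampler, *, llm_load, is_rank_for_log, logger):
    """Hypothetical replica of load_sampler with its dependencies passed in."""
    sampler_states = llm_load(os.path.join(ckpt_path, "sampler.pt"))
    sampler.load_state_dict(sampler_states)
    if is_rank_for_log():
        pstate = copy.deepcopy(sampler_states)
        pstate.pop("indices")
        pstate.pop("rng_state")
        logger.info(f"reload sampler_states:{pstate}")


def make_states():
    # Illustrative state; only "indices" and "rng_state" mirror the real keys.
    return {"indices": [3, 1, 2], "rng_state": (7, 8, 9), "epoch": 5}


def test_restores_state_and_logs_trimmed_copy():
    states = make_states()
    fake_load = mock.Mock(return_value=states)
    sampler, logger = mock.Mock(), mock.Mock()

    load_sampler("/ckpt", sampler, llm_load=fake_load,
                 is_rank_for_log=lambda: True, logger=logger)

    fake_load.assert_called_once_with(os.path.join("/ckpt", "sampler.pt"))
    sampler.load_state_dict.assert_called_once_with(states)
    message = logger.info.call_args.args[0]
    assert "'epoch': 5" in message
    assert "indices" not in message and "rng_state" not in message
    # Trimming works on a copy, so the dict handed to the sampler is untouched.
    assert states == make_states()


def test_non_logging_rank_stays_silent():
    sampler, logger = mock.Mock(), mock.Mock()
    load_sampler("/ckpt", sampler, llm_load=mock.Mock(return_value=make_states()),
                 is_rank_for_log=lambda: False, logger=logger)
    sampler.load_state_dict.assert_called_once()
    logger.info.assert_not_called()
```

Both functions follow pytest naming, so `pytest` will collect them; they can also be called directly.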