Repository: github.com/LargeWorldModel/LWM

Function save_checkpoint

lwm/train.py:316–330
(train_state, milestone=False)


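```python
# Defined in the body of main(): variant, flags_config_dict, llama_config,
# checkpointer, gather_fns and dataset come from the enclosing scope.
def save_checkpoint(train_state, milestone=False):
    # Copy the step counter from device to host and convert it to a Python int.
    step = int(jax.device_get(train_state.step))
    # Run metadata stored next to the weights: step, variant, flags, model config.
    metadata = dict(
        step=step,
        variant=variant,
        flags=flags_config_dict,
        llama_config=llama_config.to_dict(),
    )
    # Write the train state (collected with gather_fns), the metadata and the
    # dataset's state dict; milestone=True marks this as a milestone checkpoint.
    checkpointer.save_all(
        train_state=train_state,
        gather_fns=gather_fns,
        metadata=metadata,
        dataset=dataset.get_state_dict(),
        milestone=milestone,
    )
```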
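The checkpointer's implementation is not shown on this page. As a minimal sketch, assuming EasyLM-style semantics for the `milestone` flag (a regular save overwrites one fixed set of files, while a milestone save writes step-suffixed copies that later saves leave alone), the hypothetical `ToyCheckpointer` below mimics the `save_all` call with pickle files in a temporary directory. The class name, file names and the `SimpleNamespace` train state are illustrative assumptions, not LWM's API; the sketch also omits `gather_fns` and uses a plain int step instead of a device array.

```python
import os
import pickle
import tempfile
from types import SimpleNamespace


class ToyCheckpointer:
    """Hypothetical checkpointer that pickles each artifact into save_dir."""

    def __init__(self, save_dir):
        self.save_dir = save_dir

    def _dump(self, obj, name):
        with open(os.path.join(self.save_dir, name), "wb") as f:
            pickle.dump(obj, f)

    def save_all(self, train_state, metadata=None, dataset=None, milestone=False):
        # Assumed semantics: a milestone gets a step suffix and is never
        # overwritten; a regular save reuses fixed names and replaces the last one.
        suffix = f"_{int(train_state.step)}" if milestone else ""
        self._dump(metadata, f"metadata{suffix}.pkl")
        self._dump(dataset, f"dataset{suffix}.pkl")
        self._dump(vars(train_state), f"train_state{suffix}.pkl")


save_dir = tempfile.mkdtemp()
ckpt = ToyCheckpointer(save_dir)
for step, milestone in [(100, False), (200, True), (300, False)]:
    state = SimpleNamespace(step=step, params={"w": [0.0]})
    ckpt.save_all(state, metadata={"step": step}, dataset={"offset": step},
                  milestone=milestone)

print(sorted(os.listdir(save_dir)))
# ['dataset.pkl', 'dataset_200.pkl', 'metadata.pkl', 'metadata_200.pkl',
#  'train_state.pkl', 'train_state_200.pkl']
```

After three saves, the step-200 milestone survives while the regular files hold the step-300 state.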

Callers 1

main (Function) · 0.85
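This page only shows that `main` calls `save_checkpoint`, not when. A common pattern in training loops is to save a regular checkpoint every N steps and a milestone every M steps; the sketch below shows that decision logic under that assumption. `checkpoint_action`, `run_loop`, `save_model_freq` and `save_milestone_freq` are hypothetical names, and the callback receives the step number in place of a real `train_state`.

```python
def checkpoint_action(step, save_model_freq, save_milestone_freq):
    """Hypothetical policy: 'milestone', 'regular' or None for a 1-based step."""
    # A milestone takes precedence when both frequencies divide the step.
    if save_milestone_freq > 0 and step % save_milestone_freq == 0:
        return "milestone"
    if save_model_freq > 0 and step % save_model_freq == 0:
        return "regular"
    return None


def run_loop(total_steps, save_checkpoint, save_model_freq=100, save_milestone_freq=1000):
    for step in range(1, total_steps + 1):
        # ... one optimizer step would run here and advance the train state ...
        action = checkpoint_action(step, save_model_freq, save_milestone_freq)
        if action is not None:
            save_checkpoint(step, milestone=(action == "milestone"))


calls = []
run_loop(2500, lambda step, milestone=False: calls.append((step, milestone)))
print(len(calls), [s for s, m in calls if m])  # → 25 [1000, 2000]
```

Giving milestones precedence means a step divisible by both frequencies produces one milestone save rather than two writes.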

Calls 1

get_state_dict (Method) · 0.45
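`save_checkpoint` stores `dataset.get_state_dict()` alongside the weights, which presumably lets a resumed run continue from the same position in the data stream. A minimal sketch of that idea, assuming a hypothetical `ToyDataset` whose only resumable state is a read offset; LWM's dataset class, its state layout and the name of its restore method are not shown on this page, so `load_state_dict` here is illustrative.

```python
class ToyDataset:
    """Hypothetical iterable dataset whose resumable state is a read offset."""

    def __init__(self, examples):
        self.examples = examples
        self.offset = 0

    def __iter__(self):
        # Yield from the current offset, advancing it as examples are consumed.
        while self.offset < len(self.examples):
            example = self.examples[self.offset]
            self.offset += 1
            yield example

    def get_state_dict(self):
        # A fresh dict, so later iteration does not change the saved state.
        return {"offset": self.offset}

    def load_state_dict(self, state):
        self.offset = state["offset"]


data = ToyDataset(list("abcdef"))
it = iter(data)
seen = [next(it), next(it)]
state = data.get_state_dict()       # saved with the checkpoint
resumed = ToyDataset(list("abcdef"))
resumed.load_state_dict(state)      # restored when training restarts
rest = list(resumed)
print(seen, rest)  # → ['a', 'b'] ['c', 'd', 'e', 'f']
```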

Tested by

no test coverage detected
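Because `save_checkpoint` is a closure inside `main`, it is hard to test directly. One option, assuming a small refactor that passes its dependencies in explicitly, is to inject `unittest.mock.Mock` objects and assert on the `save_all` call. The `make_save_checkpoint` factory is hypothetical, not part of LWM, and `jax.device_get` is replaced by a plain `int` so the sketch runs without JAX.

```python
import unittest
from types import SimpleNamespace
from unittest import mock


def make_save_checkpoint(checkpointer, gather_fns, dataset, variant, flags, llama_config):
    """Hypothetical factory binding what main() supplies through its closure."""

    def save_checkpoint(train_state, milestone=False):
        step = int(train_state.step)  # LWM uses int(jax.device_get(...)) here
        metadata = dict(
            step=step,
            variant=variant,
            flags=flags,
            llama_config=llama_config.to_dict(),
        )
        checkpointer.save_all(
            train_state=train_state,
            gather_fns=gather_fns,
            metadata=metadata,
            dataset=dataset.get_state_dict(),
            milestone=milestone,
        )

    return save_checkpoint


class SaveCheckpointTest(unittest.TestCase):
    def test_forwards_metadata_and_milestone(self):
        checkpointer = mock.Mock()
        dataset = mock.Mock()
        dataset.get_state_dict.return_value = {"offset": 7}
        llama_config = mock.Mock()
        llama_config.to_dict.return_value = {"hidden_size": 8}
        save = make_save_checkpoint(
            checkpointer, "gather", dataset, "toy-variant", {"lr": 1e-4}, llama_config
        )
        state = SimpleNamespace(step=42)

        save(state, milestone=True)

        checkpointer.save_all.assert_called_once_with(
            train_state=state,
            gather_fns="gather",
            metadata=dict(step=42, variant="toy-variant", flags={"lr": 1e-4},
                          llama_config={"hidden_size": 8}),
            dataset={"offset": 7},
            milestone=True,
        )


if __name__ == "__main__":
    unittest.main(argv=["save_checkpoint_test"], exit=False)
```

Mocks keep the test independent of JAX devices and the filesystem; a separate integration test would still be needed to check that a saved checkpoint can be restored.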