(train_state, milestone=False)
| 314 | ) |
| 315 | |
def save_checkpoint(train_state, milestone=False):
    """Save the current training state together with run metadata.

    Builds a metadata dict (step, variant, flags, model config) and hands it,
    the train state, the gather functions and the dataset's state dict to
    ``checkpointer.save_all``.

    Args:
        train_state: Training state to checkpoint; its ``step`` attribute is
            read to record the step number in the metadata.
        milestone: Forwarded unchanged to ``checkpointer.save_all``.
            NOTE(review): presumably marks a checkpoint that should be kept
            permanently -- confirm against the checkpointer implementation.
    """
    # Convert the step counter into a plain Python int so it can be stored
    # in the metadata alongside the other run settings.
    current_step = int(jax.device_get(train_state.step))

    checkpoint_metadata = {
        "step": current_step,
        "variant": variant,
        "flags": flags_config_dict,
        "llama_config": llama_config.to_dict(),
    }

    checkpointer.save_all(
        train_state=train_state,
        gather_fns=gather_fns,
        metadata=checkpoint_metadata,
        dataset=dataset.get_state_dict(),
        milestone=milestone,
    )
| 331 | |
| 332 | with mesh: |
| 333 | train_state, restored_params = None, None |