| 321 | |
| 322 | |
def load_rng(sd):
    """Restore random-number-generator states from the checkpoint dict ``sd``.

    The Python ``random``, NumPy, torch CPU and torch CUDA generator states
    are restored, in that order, from the ``'random_rng_state'``,
    ``'np_rng_state'``, ``'torch_rng_state'`` and ``'cuda_rng_state'``
    entries of ``sd``. When the ``ENV_TYPE`` environment variable equals
    ``'deepspeed+mpu'``, the ``mpu`` CUDA RNG tracker is additionally
    restored from ``sd['rng_tracker_states']``.

    A missing key is logged and swallowed, not raised. States are applied
    one after another, so any generator restored before the missing key
    was hit keeps its restored state.

    Args:
        sd: Mapping holding the saved RNG states under the keys listed above.

    Returns:
        None.
    """
    env_type = os.getenv('ENV_TYPE')
    try:
        random.setstate(sd['random_rng_state'])
        np.random.set_state(sd['np_rng_state'])
        torch.set_rng_state(sd['torch_rng_state'])
        torch.cuda.set_rng_state(sd['cuda_rng_state'])
        if env_type == 'deepspeed+mpu':
            mpu.get_cuda_rng_tracker().set_states(sd['rng_tracker_states'])
        log_dist('global rank 0 is loading rng states')
    except KeyError:
        # NOTE(review): the message says "exiting", but this function has
        # never terminated the process; confirm whether callers rely on
        # training continuing before turning this into a hard failure.
        log_dist('Unable to load random state from checkpoint, exiting. '
                 'Specify --no-load-rng or --finetune to prevent '
                 'attempting to load the random '
                 'state.')
    else:
        # Report success only when every state above was restored; this used
        # to be logged unconditionally, even right after the failure message.
        log_dist(' successfully loaded rng checkpoints')