Repository: github.com/FlagAI-Open/FlagAI

Function load_rng

flagai/utils.py:323–339
load_rng(sd)

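import os
import random

import numpy as np
import torch

# `mpu` (model-parallel utilities) and `log_dist` are FlagAI helpers imported
# elsewhere in flagai/utils.py; they are not shown here.


def load_rng(sd):
    # Restore random-number-generator states from the checkpoint dict `sd`.
    env_type = os.getenv('ENV_TYPE')
    try:
        # Python, NumPy, CPU torch and current-device CUDA generators, in order.
        random.setstate(sd['random_rng_state'])
        np.random.set_state(sd['np_rng_state'])
        torch.set_rng_state(sd['torch_rng_state'])
        torch.cuda.set_rng_state(sd['cuda_rng_state'])
        # Named model-parallel CUDA streams exist only under DeepSpeed + MPU.
        if env_type == 'deepspeed+mpu':
            mpu.get_cuda_rng_tracker().set_states(sd['rng_tracker_states'])
        log_dist('global rank 0 is loading rng states')
    except KeyError:
        # A key is missing. Despite the message, nothing exits: states set
        # before the missing key stay restored and execution continues.
        log_dist('Unable to load random state from checkpoint, exiting. '
                 'Specify --no-load-rng or --finetune to prevent '
                 'attempting to load the random '
                 'state.')
    else:
        # Report success only when every state above was restored.
        log_dist(' successfully loaded rng checkpoints')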
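load_rng sets each generator in turn, so a checkpoint that lacks, say, np_rng_state still has its Python random state restored before the KeyError is caught. The sketch below is a hypothetical alternative, not FlagAI code: save_rng_sketch, load_rng_sketch and the extra_rng_state key are illustrative names. It checks every key before setting any state, so a missing key leaves all generators untouched. To run with the standard library alone, it assumes one extra random.Random instance can stand in for the NumPy, torch and CUDA generators.

```python
import random
from typing import Any, Callable, Dict

# Hypothetical helpers (not part of FlagAI): a save/load pair using the same
# key-based layout as load_rng, restricted to Python's `random` module plus
# one extra named generator standing in for the NumPy/torch/CUDA ones.


def save_rng_sketch(extra: random.Random) -> Dict[str, Any]:
    """Capture generator states under fixed checkpoint keys."""
    return {
        'random_rng_state': random.getstate(),
        'extra_rng_state': extra.getstate(),
    }


def load_rng_sketch(sd: Dict[str, Any], extra: random.Random) -> bool:
    """Restore all states, or none of them if any key is missing.

    Unlike load_rng, every key is checked before any state is set, so a
    missing key cannot leave the generators partially restored.
    """
    setters: Dict[str, Callable[[Any], None]] = {
        'random_rng_state': random.setstate,
        'extra_rng_state': extra.setstate,
    }
    missing = [key for key in setters if key not in sd]
    if missing:
        print(f'missing rng keys {missing}; leaving generator states unchanged')
        return False
    for key, setter in setters.items():
        setter(sd[key])
    return True
```

In a real checkpoint the same pattern would add np.random.set_state, torch.set_rng_state and torch.cuda.set_rng_state to the setters dict under the keys load_rng reads.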

Callers (4)

do_train · method · 0.90
do_train · method · 0.90
train · method · 0.90
train · method · 0.90

Calls (2)

log_dist · function · 0.90
set_states · method · 0.80
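set_states is called on the object returned by mpu.get_cuda_rng_tracker(). Assuming FlagAI's mpu follows the Megatron-LM design (an assumption not confirmed by this page), the tracker keeps a dictionary from stream name to CUDA generator state, and set_states replaces that dictionary wholesale. The class below is a hypothetical CPU-only analogue, named NamedRngTracker for illustration, that uses random.Random states instead of CUDA generators to show that save/replace behaviour.

```python
import random
from typing import Dict


class NamedRngTracker:
    """Hypothetical stand-in for a model-parallel RNG tracker.

    Keeps one independent generator state per name, loosely mirroring a
    get_states()/set_states() pair, but with random.Random states instead
    of CUDA generator states so it runs on CPU with the standard library.
    """

    def __init__(self) -> None:
        self._states: Dict[str, object] = {}

    def add(self, name: str, seed: int) -> None:
        # Register a new named stream seeded independently of the others.
        if name in self._states:
            raise ValueError(f'rng state {name!r} already exists')
        self._states[name] = random.Random(seed).getstate()

    def get_states(self) -> Dict[str, object]:
        # Shallow copy: a checkpoint writer gets its own dict.
        return dict(self._states)

    def set_states(self, states: Dict[str, object]) -> None:
        # Replace the whole name -> state mapping, as a checkpoint load would.
        self._states = dict(states)

    def draw(self, name: str) -> float:
        """Draw one number from the named stream and advance its state."""
        rng = random.Random()
        rng.setstate(self._states[name])
        value = rng.random()
        self._states[name] = rng.getstate()
        return value
```

Saving with get_states, drawing, then calling set_states with the saved dict makes the next draw repeat the first one, which is the property a resumed training run relies on.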

Tested by

no test coverage detected