MCPcopy Index your code
hub / github.com/deepspeedai/DeepSpeed / _load_zero_checkpoint

Method _load_zero_checkpoint

deepspeed/runtime/engine.py:3968–4006  ·  view source on GitHub ↗
(self, load_dir, tag, load_optimizer_states=True)

Source from the content-addressed store, hash-verified

3966 return load_path, client_state
3967
3968 def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True):
3969
3970 load_serial = None
3971 # When use loading checkpoint serial, checkpoint loading start from local rank 0,
3972 # all other local rank would be paused, waiting for its rank-1 peer ready and its notification.
3973 if self._config.zero_config.pipeline_loading_checkpoint:
3974 assert self.zero_optimization_stage(
3975 ) == ZeroStageEnum.weights, "Only stage3 support for pipeline checkpoint loading"
3976 load_serial = torch.zeros(1).to(self.device)
3977 if dist.get_local_rank() != 0:
3978 dist.recv(tensor=load_serial, src=dist.get_rank() - 1)
3979 if self.load_universal_checkpoint():
3980 zero_sd_list = None
3981 checkpoint_folder = f'{os.path.join(load_dir, tag)}'
3982 else:
3983 if load_optimizer_states and self.seq_dp_world_size != self.loaded_checkpoint_dp_world_size:
3984 raise ZeRORuntimeException("The checkpoint being loaded used a DP " \
3985 f"world size of {self.loaded_checkpoint_dp_world_size} but the " \
3986 f"current world size is {self.seq_dp_world_size}. Automatic adjustment " \
3987 "of ZeRO's optimizer state partitioning with a new world size is not " \
3988 "currently supported.")
3989 checkpoint_folder = None
3990 zero_sd_list = self._get_all_zero_checkpoints(load_dir, tag)
3991 if zero_sd_list is None:
3992 return False
3993
3994 param_shapes = self._get_zero_param_shapes()
3995 self.optimizer.load_state_dict(state_dict_list=zero_sd_list,
3996 load_optimizer_states=load_optimizer_states,
3997 load_from_fp32_weights=self.zero_load_from_fp32_weights(),
3998 checkpoint_folder=checkpoint_folder,
3999 load_serial=load_serial,
4000 param_shapes=param_shapes)
4001
4002 if self.load_universal_checkpoint():
4003 logger.info(f'loaded universal zero checkpoints from {checkpoint_folder} for rank {self.global_rank}')
4004 else:
4005 logger.info(f"loading {len(zero_sd_list)} zero partition checkpoints for rank {self.global_rank}")
4006 return True
4007
4008 def _get_mp_rank_zero_checkpoint_names(self, load_dir, tag, mp_rank, dp_world_size, bf16_mode):
4009 zero_ckpt_names = []

Callers 1

load_checkpoint · Method · 0.95

Calls 10

to · Method · 0.45
recv · Method · 0.45
get_rank · Method · 0.45
load_state_dict · Method · 0.45

Tested by

no test coverage detected