(self, load_dir, tag, load_optimizer_states=True)
| 3966 | return load_path, client_state |
| 3967 | |
def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True):
    """Load ZeRO checkpoint state into ``self.optimizer``.

    Args:
        load_dir: Root checkpoint directory.
        tag: Checkpoint tag. For universal checkpoints the folder handed to
            the optimizer is ``os.path.join(load_dir, tag)``; otherwise both
            values are forwarded to ``self._get_all_zero_checkpoints``.
        load_optimizer_states: Forwarded to ``self.optimizer.load_state_dict``.
            When true (and the checkpoint is not universal), the current
            ``seq_dp_world_size`` must equal ``loaded_checkpoint_dp_world_size``.

    Returns:
        ``False`` when ``self._get_all_zero_checkpoints`` returns ``None``
        (nothing is handed to the optimizer); ``True`` once
        ``self.optimizer.load_state_dict`` has been called.

    Raises:
        AssertionError: pipeline checkpoint loading is enabled but the ZeRO
            stage is not ``ZeroStageEnum.weights``.
        ZeRORuntimeException: optimizer states were requested and the DP world
            size differs from the one recorded for the checkpoint.
    """
    load_serial = None
    # With serialized ("pipeline") checkpoint loading, local rank 0 starts
    # right away; every other local rank blocks until the rank just below it
    # (global rank - 1) sends the notification tensor.
    if self._config.zero_config.pipeline_loading_checkpoint:
        # Raised explicitly rather than via `assert` so the check still runs
        # under `python -O`; without it, non-zero local ranks would proceed to
        # a blocking recv on an unsupported stage. The exception type and
        # message are the same as the former assert.
        if self.zero_optimization_stage() != ZeroStageEnum.weights:
            raise AssertionError("Only stage3 support for pipeline checkpoint loading")
        load_serial = torch.zeros(1).to(self.device)
        if dist.get_local_rank() != 0:
            dist.recv(tensor=load_serial, src=dist.get_rank() - 1)

    # Read the flag once so the load branch and the log branch below always agree.
    is_universal = self.load_universal_checkpoint()

    if is_universal:
        zero_sd_list = None
        checkpoint_folder = os.path.join(load_dir, tag)
    else:
        if load_optimizer_states and self.seq_dp_world_size != self.loaded_checkpoint_dp_world_size:
            raise ZeRORuntimeException(
                "The checkpoint being loaded used a DP "
                f"world size of {self.loaded_checkpoint_dp_world_size} but the "
                f"current world size is {self.seq_dp_world_size}. Automatic adjustment "
                "of ZeRO's optimizer state partitioning with a new world size is not "
                "currently supported.")
        checkpoint_folder = None
        zero_sd_list = self._get_all_zero_checkpoints(load_dir, tag)
        if zero_sd_list is None:
            # NOTE(review): when pipeline loading is enabled, returning here
            # skips the optimizer call that receives `load_serial`. Presumably
            # that call is what notifies the next local rank -- confirm that
            # peers blocked in dist.recv cannot hang on this path.
            return False

    param_shapes = self._get_zero_param_shapes()
    self.optimizer.load_state_dict(state_dict_list=zero_sd_list,
                                   load_optimizer_states=load_optimizer_states,
                                   load_from_fp32_weights=self.zero_load_from_fp32_weights(),
                                   checkpoint_folder=checkpoint_folder,
                                   load_serial=load_serial,
                                   param_shapes=param_shapes)

    if is_universal:
        logger.info(f'loaded universal zero checkpoints from {checkpoint_folder} for rank {self.global_rank}')
    else:
        logger.info(f"loading {len(zero_sd_list)} zero partition checkpoints for rank {self.global_rank}")
    return True
| 4007 | |
| 4008 | def _get_mp_rank_zero_checkpoint_names(self, load_dir, tag, mp_rank, dp_world_size, bf16_mode): |
| 4009 | zero_ckpt_names = [] |
no test coverage detected