MCPcopy
hub / github.com/deepspeedai/DeepSpeed / _save_checkpoint

Method _save_checkpoint

deepspeed/runtime/engine.py:4438–4482  ·  view source on GitHub ↗
(self, save_dir, tag, client_state={}, exclude_frozen_parameters=False)

Source from the content-addressed store, hash-verified

4436 return success
4437
def _save_checkpoint(self, save_dir, tag, client_state=None, exclude_frozen_parameters=False):
    """Build this engine's checkpoint dict and hand it to the checkpoint engine.

    Args:
        save_dir: Root checkpoint directory. It is passed to ``_get_ckpt_name``
            together with ``tag`` to compute the output path.
        tag: Checkpoint tag. While ``module_state_dict()`` runs,
            ``os.path.join(save_dir, tag)`` is exposed through
            ``self._curr_ckpt_path`` (see the pipeline-parallel note below).
        client_state: Optional mapping of extra entries. They are merged into
            the saved dict *last*, so any key here overrides a built-in entry
            of the same name (including ``UNIVERSAL_CHECKPOINT_INFO``).
            ``None`` (the default) means no extra entries.
        exclude_frozen_parameters: Forwarded to ``module_state_dict()``. When
            true, the frozen-parameter metadata entries
            (``frozen_param_shapes`` / ``frozen_param_fragments``) are also
            stored as ``None``.

    The dict is always built and the "Saving model checkpoint" message is
    always logged. The write through ``self.checkpoint_engine.save`` only
    happens when ``self.save_non_zero_checkpoint`` is truthy.
    """
    # Avoid a shared mutable default. Callers passing an explicit dict are unaffected.
    if client_state is None:
        client_state = {}

    save_path = self._get_ckpt_name(save_dir, tag)

    zero_optimizer_state = self.zero_optimization()

    # Frozen-parameter metadata is only recorded when gradient partitioning is
    # active and the caller did not ask to exclude frozen parameters.
    save_frozen_param = self.zero_optimization_partition_gradients() and not exclude_frozen_parameters

    # A hack to save the checkpointing directory. Pipeline parallelism overrides
    # module_state_dict() and uses this path to save the model. module_state_dict()
    # then instead just returns None. The module_state_dict() implementation in
    # PipelineEngine expects the save path to be set in self._curr_ckpt_path.
    self._curr_ckpt_path = os.path.join(save_dir, tag)
    try:
        module = self.module_state_dict(exclude_frozen_parameters=exclude_frozen_parameters)
    finally:
        # Clear the path even if module_state_dict() raises, so a stale path
        # cannot leak into a later, unrelated call.
        self._curr_ckpt_path = None

    # When zero_optimization() is truthy, the full optimizer state_dict is NOT
    # stored here; only parameter shapes / shared-param info are.
    # NOTE(review): presumably the ZeRO optimizer state is saved by a separate
    # path in save_checkpoint — verify.
    state = dict(module=module,
                 buffer_names=self._get_buffer_names(),
                 optimizer=self.optimizer.state_dict() if self.optimizer and not zero_optimizer_state else None,
                 param_shapes=self._get_zero_param_shapes() if self.optimizer and zero_optimizer_state else None,
                 frozen_param_shapes=self._get_zero_frozen_param_attributes(self._get_param_shape_func)
                 if save_frozen_param else None,
                 shared_params=self._get_shared_params() if self.optimizer and zero_optimizer_state else None,
                 frozen_param_fragments=self._get_zero_frozen_param_attributes(self._get_param_fragment_func)
                 if save_frozen_param else None,
                 lr_scheduler=self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
                 data_sampler=self.training_dataloader.data_sampler.state_dict() if
                 (self.training_dataloader is not None and self.curriculum_learning_enabled()) else None,
                 random_ltd=self.random_ltd_scheduler.state_dict() if self.random_ltd_enabled() else None,
                 sparse_tensor_module_names=self.sparse_tensor_module_names,
                 skipped_steps=self.skipped_steps,
                 global_steps=self.global_steps,
                 global_samples=self.global_samples,
                 dp_world_size=self.seq_dp_world_size,
                 mp_world_size=self.mp_world_size,
                 ds_config=self.config,
                 ds_version=version)

    # Attach universal-checkpoint info only if the wrapped module carries it.
    autotp_uc_info = getattr(self.module, UNIVERSAL_CHECKPOINT_INFO, None)
    if autotp_uc_info is not None:
        state[UNIVERSAL_CHECKPOINT_INFO] = autotp_uc_info

    # Client entries go in last and win on key collisions.
    state.update(client_state)
    log_dist(message=f'Saving model checkpoint: {save_path}', ranks=[0])

    if self.save_non_zero_checkpoint:
        self.checkpoint_engine.save(state_dict=state, path=save_path)
4483
4484 def _get_buffer_names(self):
4485 buffer_names = []

Callers 1

save_checkpoint · Method · 0.95

Calls 14

_get_ckpt_name · Method · 0.95
zero_optimization · Method · 0.95
module_state_dict · Method · 0.95
_get_buffer_names · Method · 0.95
_get_shared_params · Method · 0.95
random_ltd_enabled · Method · 0.95
log_dist · Function · 0.90
state_dict · Method · 0.45

Tested by

no test coverage detected