(self, save_dir, tag, client_state={}, exclude_frozen_parameters=False)
| 4436 | return success |
| 4437 | |
def _save_checkpoint(self, save_dir, tag, client_state=None, exclude_frozen_parameters=False):
    """Build this engine's checkpoint state dict and save it via the checkpoint engine.

    Args:
        save_dir: Root directory for checkpoints. It is passed, together with
            ``tag``, to ``self._get_ckpt_name`` to obtain the output path.
        tag: Checkpoint tag. ``os.path.join(save_dir, tag)`` is exposed as
            ``self._curr_ckpt_path`` for the duration of the
            ``module_state_dict()`` call.
        client_state: Optional mapping of extra entries merged into the saved
            state. It is applied last, so its keys override engine-provided keys
            of the same name. ``None`` (the default) adds nothing. Defaults to
            ``None`` rather than ``{}`` to avoid a shared mutable default.
        exclude_frozen_parameters: Forwarded to ``module_state_dict()``. When
            true, the frozen-parameter metadata entries are also left as ``None``.

    The state dict is always assembled and the "Saving model checkpoint" message
    is logged (rank 0 only, via ``log_dist``). The file itself is written only
    when ``self.save_non_zero_checkpoint`` is truthy.
    """
    save_path = self._get_ckpt_name(save_dir, tag)

    zero_optimizer_state = self.zero_optimization()

    # Frozen-parameter metadata is recorded only when ZeRO gradient partitioning
    # is active and the caller did not ask to exclude frozen parameters.
    save_frozen_param = self.zero_optimization_partition_gradients() and not exclude_frozen_parameters

    # A hack to save the checkpointing directory. Pipeline parallelism overrides
    # module_state_dict() and uses this path to save the model. module_state_dict()
    # then instead just returns None. The module_state_dict() implementation in
    # PipelineEngine expects the save path to be set in self._curr_ckpt_path.
    self._curr_ckpt_path = os.path.join(save_dir, tag)
    try:
        module = self.module_state_dict(exclude_frozen_parameters=exclude_frozen_parameters)
    finally:
        # Always clear the temporary path, even if module_state_dict() raises,
        # so a failed save does not leave a stale path on the engine.
        self._curr_ckpt_path = None

    # Without ZeRO the full optimizer state_dict is stored here. With ZeRO it is
    # not stored in this file; only param_shapes / shared_params metadata are
    # recorded (the ZeRO optimizer state is presumably saved elsewhere; not
    # visible in this method).
    state = dict(module=module,
                 buffer_names=self._get_buffer_names(),
                 optimizer=self.optimizer.state_dict() if self.optimizer and not zero_optimizer_state else None,
                 param_shapes=self._get_zero_param_shapes() if self.optimizer and zero_optimizer_state else None,
                 frozen_param_shapes=self._get_zero_frozen_param_attributes(self._get_param_shape_func)
                 if save_frozen_param else None,
                 shared_params=self._get_shared_params() if self.optimizer and zero_optimizer_state else None,
                 frozen_param_fragments=self._get_zero_frozen_param_attributes(self._get_param_fragment_func)
                 if save_frozen_param else None,
                 lr_scheduler=self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
                 data_sampler=self.training_dataloader.data_sampler.state_dict() if
                 (self.training_dataloader is not None and self.curriculum_learning_enabled()) else None,
                 random_ltd=self.random_ltd_scheduler.state_dict() if self.random_ltd_enabled() else None,
                 sparse_tensor_module_names=self.sparse_tensor_module_names,
                 skipped_steps=self.skipped_steps,
                 global_steps=self.global_steps,
                 global_samples=self.global_samples,
                 dp_world_size=self.seq_dp_world_size,
                 mp_world_size=self.mp_world_size,
                 ds_config=self.config,
                 ds_version=version)

    # If the wrapped module carries universal-checkpoint info as an attribute,
    # store it alongside the rest of the state.
    autotp_uc_info = getattr(self.module, UNIVERSAL_CHECKPOINT_INFO, None)
    if autotp_uc_info is not None:
        state[UNIVERSAL_CHECKPOINT_INFO] = autotp_uc_info

    # Client entries are merged last and therefore win on key collisions.
    if client_state is not None:
        state.update(client_state)

    log_dist(message=f'Saving model checkpoint: {save_path}', ranks=[0])

    if self.save_non_zero_checkpoint:
        self.checkpoint_engine.save(state_dict=state, path=save_path)
| 4484 | def _get_buffer_names(self): |
| 4485 | buffer_names = [] |
no test coverage detected