| 434 | ) |
| 435 | |
| 436 | def _save_checkpoint(self, state: State, logger: Logger): |
| 437 | self.last_checkpoint_batch = state.timestamp.batch |
| 438 | |
| 439 | # save the checkpoint to the filename |
| 440 | filename_with_placeholders = self.filename.format(state, keep_placeholders=True) |
| 441 | save_filename = checkpoint.get_save_filename(state, filename_with_placeholders) |
| 442 | # Store before saving so state_dict in checkpoint has reference to latest checkpoint (itself) |
| 443 | self.all_saved_checkpoints_to_timestamp[save_filename] = state.timestamp |
| 444 | |
| 445 | saved_path = checkpoint.save_checkpoint( |
| 446 | state=state, |
| 447 | filename=filename_with_placeholders, |
| 448 | weights_only=self.weights_only, |
| 449 | ignore_keys=self.ignore_keys, |
| 450 | ) |
| 451 | log.debug(f'Checkpoint locally saved to {saved_path}') |
| 452 | |
| 453 | self.symlink_count += 1 |
| 454 | # Remote checkpoint file names on this rank |
| 455 | local_remote_file_names = [] |
| 456 | all_remote_filenames = [] |
| 457 | |
| 458 | if not saved_path: # not all ranks save |
| 459 | if self.remote_file_name is not None and self.remote_uploader is not None: |
| 460 | all_remote_filenames = dist.all_gather_object(local_remote_file_names) |
| 461 | return |
| 462 | |
| 463 | metadata_local_file_path = None |
| 464 | if dist.get_global_rank() == 0 and state.fsdp_sharded_state_dict_enabled: |
| 465 | metadata_local_file_path = format_name_with_dist_and_time( |
| 466 | os.path.join(Path(saved_path).parent, _TORCH_DISTRIBUTED_CHECKPOINTS_METADATA_FILENAME), |
| 467 | state.run_name, |
| 468 | state.timestamp, |
| 469 | ) |
| 470 | |
| 471 | self.rank_saves_symlinks = dist.get_global_rank() == 0 or not state.fsdp_sharded_state_dict_enabled |
| 472 | if self.latest_filename is not None and self.num_checkpoints_to_keep != 0: |
| 473 | symlink = self.latest_filename.format(state) |
| 474 | os.makedirs(os.path.dirname(symlink), exist_ok=True) |
| 475 | try: |
| 476 | os.remove(symlink) |
| 477 | except FileNotFoundError: |
| 478 | pass |
| 479 | # Sharded checkpoints for torch >2.0 use directories not files for load_paths |
| 480 | if state.fsdp_sharded_state_dict_enabled: |
| 481 | src_path = str(pathlib.Path(saved_path).parent) |
| 482 | else: |
| 483 | src_path = saved_path |
| 484 | if self.rank_saves_symlinks: |
| 485 | os.symlink(os.path.relpath(src_path, os.path.dirname(symlink)), symlink) |
| 486 | |
| 487 | # if remote file name provided, upload the checkpoint |
| 488 | if self.remote_file_name is not None: |
| 489 | if state.fsdp_sharded_state_dict_enabled: |
| 490 | remote_file_name = self.remote_file_name.format( |
| 491 | state, |
| 492 | keep_placeholders=True, |
| 493 | ).lstrip('/') |