MCPcopy
hub / github.com/mosaicml/composer / _save_checkpoint

Method _save_checkpoint

composer/callbacks/checkpoint_saver.py:436–573  ·  view source on GitHub ↗
(self, state: State, logger: Logger)

Source from the content-addressed store, hash-verified

434 )
435
436 def _save_checkpoint(self, state: State, logger: Logger):
437 self.last_checkpoint_batch = state.timestamp.batch
438
439 # save the checkpoint to the filename
440 filename_with_placeholders = self.filename.format(state, keep_placeholders=True)
441 save_filename = checkpoint.get_save_filename(state, filename_with_placeholders)
442 # Store before saving so state_dict in checkpoint has reference to latest checkpoint (itself)
443 self.all_saved_checkpoints_to_timestamp[save_filename] = state.timestamp
444
445 saved_path = checkpoint.save_checkpoint(
446 state=state,
447 filename=filename_with_placeholders,
448 weights_only=self.weights_only,
449 ignore_keys=self.ignore_keys,
450 )
451 log.debug(f'Checkpoint locally saved to {saved_path}')
452
453 self.symlink_count += 1
454 # Remote checkpoint file names on this rank
455 local_remote_file_names = []
456 all_remote_filenames = []
457
458 if not saved_path: # not all ranks save
459 if self.remote_file_name is not None and self.remote_uploader is not None:
460 all_remote_filenames = dist.all_gather_object(local_remote_file_names)
461 return
462
463 metadata_local_file_path = None
464 if dist.get_global_rank() == 0 and state.fsdp_sharded_state_dict_enabled:
465 metadata_local_file_path = format_name_with_dist_and_time(
466 os.path.join(Path(saved_path).parent, _TORCH_DISTRIBUTED_CHECKPOINTS_METADATA_FILENAME),
467 state.run_name,
468 state.timestamp,
469 )
470
471 self.rank_saves_symlinks = dist.get_global_rank() == 0 or not state.fsdp_sharded_state_dict_enabled
472 if self.latest_filename is not None and self.num_checkpoints_to_keep != 0:
473 symlink = self.latest_filename.format(state)
474 os.makedirs(os.path.dirname(symlink), exist_ok=True)
475 try:
476 os.remove(symlink)
477 except FileNotFoundError:
478 pass
479 # Sharded checkpoints for torch >2.0 use directories not files for load_paths
480 if state.fsdp_sharded_state_dict_enabled:
481 src_path = str(pathlib.Path(saved_path).parent)
482 else:
483 src_path = saved_path
484 if self.rank_saves_symlinks:
485 os.symlink(os.path.relpath(src_path, os.path.dirname(symlink)), symlink)
486
487 # if remote file name provided, upload the checkpoint
488 if self.remote_file_name is not None:
489 if state.fsdp_sharded_state_dict_enabled:
490 remote_file_name = self.remote_file_name.format(
491 state,
492 keep_placeholders=True,
493 ).lstrip('/')

Callers 5

batch_checkpoint — Method · 0.95
epoch_checkpoint — Method · 0.95
iteration_checkpoint — Method · 0.95
test_autoload — Function · 0.80

Calls 8

_upload_checkpoint — Method · 0.95
_rotate_checkpoints — Method · 0.95
create_symlink_file — Function · 0.90
format — Method · 0.80
save_checkpoint — Method · 0.80
upload_file — Method · 0.45

Tested by 1

test_autoload — Function · 0.64