
Class JobRunner

caffe2/python/checkpoint.py:655–816

Implement the runtime logic for jobs with checkpointing at the level of epoch. Can be used to run either single-host or distributed jobs. Job runner is a callable to be called once from the master, passing a session as an argument. This call will block until the Job execution is complete.


class JobRunner:
    """
    Implement the runtime logic for jobs with checkpointing at the level of
    epoch. Can be used to run either single-host or distributed jobs. Job
    runner is a callable to be called once from the master, passing a session
    as an argument. This call will block until the Job execution is complete.

    If a checkpoint_manager is passed, checkpoints will be taken after
    initialization and after each epoch execution. If, in addition,
    `resume_from_epoch` is an epoch number, the corresponding checkpoint will
    be loaded and job execution will continue from the given epoch. In
    this case, the job's init_group will not be run.

    Refer to checkpoint_test.py for an example.
    """
    def __init__(self, job, checkpoint_manager=None, resume_from_epoch=None,
                 upload_task_group_builder=None):
        """Initializes the JobRunner.

        Args:
            job: A Job object. The job to be executed.
            checkpoint_manager: Can be a CheckpointManager for single machine
                or a MultiNodeCheckpointManager for multi-machine. The manager
                that initializes/saves/loads checkpoints.
            resume_from_epoch: An integer. The epoch to resume from.
            upload_task_group_builder: A subclass of the
                UploadTaskGroupBuilder. Creates a task group to upload
                checkpoints.
        """
        self.resume_from_epoch = resume_from_epoch
        self.checkpoint_manager = checkpoint_manager
        self.job = job
        self.upload_task_group_builder = upload_task_group_builder

    def train(self, session):
        """Runs the training flow.

        Args:
            session: A Session object. Valid choices are: LocalSession,
                LocalHostScheduler, and DistributedSession. It is used to
                execute one TaskGroup at a time.
        """
        # identify the epoch we must resume from
        if self.checkpoint_manager:
            self.checkpoint_manager.set_params(nodes=self.job.nodes_to_checkpoint())
            self.resume_from_epoch = self.checkpoint_manager.\
                get_resume_from_epoch_id(self.resume_from_epoch)
            if self.resume_from_epoch is not None:
                logger.info('Resuming from epoch {}'.format(self.resume_from_epoch))

        # Initialize all the nodes.
        from_scratch = self.resume_from_epoch is None
        if from_scratch:
            session.run(self.job.init_group)

        if self.checkpoint_manager:
            logger.info('Preparing checkpoints ...')
            session.run(self.checkpoint_manager.init(
