MCPcopy Index your code
hub / github.com/pytorch/pytorch / train

Method train

caffe2/python/checkpoint.py:689–759  ·  view source on GitHub ↗

Runs the training flow. Args: session: A Session object. Valid choices are: LocalSession, LocalHostScheduler, and DistributedSession. It is used to execute one TaskGroup at a time.

(self, session)

Source from the content-addressed store, hash-verified

687 self.upload_task_group_builder = upload_task_group_builder
688
689 def train(self, session):
690 """Runs the training flow.
691
692 Args:
693 session: A Session object. Valid choises are: LocalSession,
694 LocalHostScheduler, and DistributedSession. It is used to
695 execute one TaskGroup a time.
696 """
697 # identify the epoch we must resume from
698 if self.checkpoint_manager:
699 self.checkpoint_manager.set_params(nodes=self.job.nodes_to_checkpoint())
700 self.resume_from_epoch = self.checkpoint_manager.\
701 get_resume_from_epoch_id(self.resume_from_epoch)
702 if self.resume_from_epoch is not None:
703 logger.info('Resuming from epoch {}'.format(self.resume_from_epoch))
704
705 # Initialize all the nodes.
706 from_scratch = self.resume_from_epoch is None
707 if from_scratch:
708 session.run(self.job.init_group)
709
710 if self.checkpoint_manager:
711 logger.info('Preparing checkpoints ...')
712 session.run(self.checkpoint_manager.init(
713 self.job.nodes_to_checkpoint(),
714 retrieve_from_epoch=self.resume_from_epoch))
715 # Save the first checkpoint before training starts, or resume from
716 # a previously saved checkpoint.
717 if from_scratch:
718 self.save_checkpoints(0, session)
719 else:
720 logger.info('Loading checkpoints for epoch {} ...'.format(
721 self.resume_from_epoch))
722 session.run(
723 self.checkpoint_manager.load(self.resume_from_epoch))
724 self.checkpoint_manager.report_checkpoint_stats('checkpoint_load')
725 logger.info('Checkpoint loaded')
726
727 logger.info("Finished initializing")
728
729 # Start training.
730 epoch = 1 if from_scratch else self.resume_from_epoch + 1
731 while True:
732 logger.info('Starting epoch %d' % epoch)
733 session.run(self.job.epoch_group)
734 logger.info('Finished epoch %d' % epoch)
735 stop_conditions = [o.fetch() for o in self.job.stop_conditions]
736
737 if self.checkpoint_manager:
738 self.save_checkpoints(epoch, session)
739
740 if any(stop_conditions):
741 logger.info('Stopping')
742 break
743 epoch += 1
744 logger.info('Finished training')
745 # Upload the checkpoints.
746 if (self.upload_task_group_builder):

Calls 13

save_checkpointsMethod · 0.95
nodes_to_checkpointMethod · 0.80
infoMethod · 0.80
anyFunction · 0.50
set_paramsMethod · 0.45
formatMethod · 0.45
runMethod · 0.45
initMethod · 0.45
loadMethod · 0.45
fetchMethod · 0.45