Runs the training flow. Args: session: A Session object. Valid choices are: LocalSession, LocalHostScheduler, and DistributedSession. It is used to execute one TaskGroup at a time.
(self, session)
| 687 | self.upload_task_group_builder = upload_task_group_builder |
| 688 | |
| 689 | def train(self, session): |
| 690 | """Runs the training flow. |
| 691 | |
| 692 | Args: |
| 693 | session: A Session object. Valid choises are: LocalSession, |
| 694 | LocalHostScheduler, and DistributedSession. It is used to |
| 695 | execute one TaskGroup a time. |
| 696 | """ |
| 697 | # identify the epoch we must resume from |
| 698 | if self.checkpoint_manager: |
| 699 | self.checkpoint_manager.set_params(nodes=self.job.nodes_to_checkpoint()) |
| 700 | self.resume_from_epoch = self.checkpoint_manager.\ |
| 701 | get_resume_from_epoch_id(self.resume_from_epoch) |
| 702 | if self.resume_from_epoch is not None: |
| 703 | logger.info('Resuming from epoch {}'.format(self.resume_from_epoch)) |
| 704 | |
| 705 | # Initialize all the nodes. |
| 706 | from_scratch = self.resume_from_epoch is None |
| 707 | if from_scratch: |
| 708 | session.run(self.job.init_group) |
| 709 | |
| 710 | if self.checkpoint_manager: |
| 711 | logger.info('Preparing checkpoints ...') |
| 712 | session.run(self.checkpoint_manager.init( |
| 713 | self.job.nodes_to_checkpoint(), |
| 714 | retrieve_from_epoch=self.resume_from_epoch)) |
| 715 | # Save the first checkpoint before training starts, or resume from |
| 716 | # a previously saved checkpoint. |
| 717 | if from_scratch: |
| 718 | self.save_checkpoints(0, session) |
| 719 | else: |
| 720 | logger.info('Loading checkpoints for epoch {} ...'.format( |
| 721 | self.resume_from_epoch)) |
| 722 | session.run( |
| 723 | self.checkpoint_manager.load(self.resume_from_epoch)) |
| 724 | self.checkpoint_manager.report_checkpoint_stats('checkpoint_load') |
| 725 | logger.info('Checkpoint loaded') |
| 726 | |
| 727 | logger.info("Finished initializing") |
| 728 | |
| 729 | # Start training. |
| 730 | epoch = 1 if from_scratch else self.resume_from_epoch + 1 |
| 731 | while True: |
| 732 | logger.info('Starting epoch %d' % epoch) |
| 733 | session.run(self.job.epoch_group) |
| 734 | logger.info('Finished epoch %d' % epoch) |
| 735 | stop_conditions = [o.fetch() for o in self.job.stop_conditions] |
| 736 | |
| 737 | if self.checkpoint_manager: |
| 738 | self.save_checkpoints(epoch, session) |
| 739 | |
| 740 | if any(stop_conditions): |
| 741 | logger.info('Stopping') |
| 742 | break |
| 743 | epoch += 1 |
| 744 | logger.info('Finished training') |
| 745 | # Upload the checkpoints. |
| 746 | if (self.upload_task_group_builder): |