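Before the listing, here is a minimal, self-contained sketch of the start-up decisions that `JobRunner.train` makes before any epoch runs. First it asks the checkpoint manager which epoch to resume from. It runs the job's `init_group` only when there is nothing to resume. Then it prepares checkpoints. Everything in the sketch is hypothetical. `RecordingSession`, `StubCheckpointManager` and `start_job` are illustrative stand-ins, not Caffe2 APIs, and task groups are plain strings so the sketch runs without Caffe2. The stub's resume policy (honour a requested epoch only if a checkpoint for it exists, otherwise start from scratch) is an assumption, not the behaviour of the real `CheckpointManager`. The stub's `init(nodes)` signature is likewise invented and need not match the real call.

```python
import logging

logger = logging.getLogger(__name__)


class RecordingSession:
    """Hypothetical stand-in for a Session: records every task group it runs."""

    def __init__(self):
        self.ran = []

    def run(self, task_group):
        self.ran.append(task_group)


class StubCheckpointManager:
    """Hypothetical manager that only knows which epochs have saved checkpoints."""

    def __init__(self, saved_epochs):
        self.saved_epochs = set(saved_epochs)
        self.nodes = None

    def set_params(self, nodes):
        self.nodes = nodes

    def get_resume_from_epoch_id(self, user_epoch):
        # Assumed policy for this sketch: resume from the requested epoch only
        # if a checkpoint for it exists; otherwise return None (start fresh).
        if user_epoch is not None and user_epoch in self.saved_epochs:
            return user_epoch
        return None

    def init(self, nodes):
        # Invented signature: returns a token describing the "prepare
        # checkpoints" task group for the given nodes.
        return ('checkpoint_init', tuple(nodes))


def start_job(session, init_group, nodes, checkpoint_manager=None,
              resume_from_epoch=None):
    """Hypothetical helper following the opening steps of JobRunner.train."""
    # 1. Let the checkpoint manager decide which epoch (if any) to resume from.
    if checkpoint_manager:
        checkpoint_manager.set_params(nodes=nodes)
        resume_from_epoch = checkpoint_manager.get_resume_from_epoch_id(
            resume_from_epoch)
        if resume_from_epoch is not None:
            logger.info('Resuming from epoch %s', resume_from_epoch)

    # 2. Run the initialization group only when starting from scratch.
    if resume_from_epoch is None:
        session.run(init_group)

    # 3. Prepare checkpoints whenever a manager is present.
    if checkpoint_manager:
        session.run(checkpoint_manager.init(nodes))
    return resume_from_epoch
```

With no manager, `init_group` always runs. With a manager that finds a checkpoint for the requested epoch, `init_group` is skipped and only the checkpoint preparation runs, which is the behaviour the class docstring describes.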
class JobRunner:
    """
    Implement the runtime logic for jobs with checkpointing at the level of
    epoch. Can be used to run either single-host or distributed jobs. Job
    runner is a callable to be called once from the master, passing a session
    as an argument. This call will block until the Job execution is complete.

    If a checkpoint_manager is passed, checkpoints will be taken after
    initialization and after each epoch execution. If, in addition,
    `resume_from_epoch` is an epoch number, the corresponding checkpoint will
    be loaded and job execution will continue from the given epoch. In
    this case, the job's init_group will not be run.

    Refer to checkpoint_test.py for an example.
    """
    def __init__(self, job, checkpoint_manager=None, resume_from_epoch=None,
                 upload_task_group_builder=None):
        """Initializes the JobRunner.

        Args:
            job: A Job object. The job to be executed.
            checkpoint_manager: Can be a CheckpointManager for single machine
                or a MultiNodeCheckpointManager for multi-machine. The manager
                that initializes/saves/loads checkpoints.
            resume_from_epoch: An integer. The epoch to resume from.
            upload_task_group_builder: A subclass of the
                UploadTaskGroupBuilder. Creates a task group to upload
                checkpoints.
        """
        self.resume_from_epoch = resume_from_epoch
        self.checkpoint_manager = checkpoint_manager
        self.job = job
        self.upload_task_group_builder = upload_task_group_builder

    def train(self, session):
        """Runs the training flow.

        Args:
            session: A Session object. Valid choices are: LocalSession,
                LocalHostScheduler, and DistributedSession. It is used to
                execute one TaskGroup at a time.
        """
        # identify the epoch we must resume from
        if self.checkpoint_manager:
            self.checkpoint_manager.set_params(nodes=self.job.nodes_to_checkpoint())
            self.resume_from_epoch = self.checkpoint_manager.\
                get_resume_from_epoch_id(self.resume_from_epoch)
            if self.resume_from_epoch is not None:
                logger.info('Resuming from epoch {}'.format(self.resume_from_epoch))

        # Initialize all the nodes.
        from_scratch = self.resume_from_epoch is None
        if from_scratch:
            session.run(self.job.init_group)

        if self.checkpoint_manager:
            logger.info('Preparing checkpoints ...')
            session.run(self.checkpoint_manager.init(