| 31 | self.start_epoch = 1 |
| 32 | |
def training(self):
    """Run the epoch loop from ``self.start_epoch`` up to ``self.args.nEpochs - 1``.

    For each epoch, in order:
      1. train via ``self.train_epoch(epoch)``;
      2. validate via ``self.validate_epoch(epoch)`` if ``self.do_validation``;
      3. save a checkpoint when ``self.args.save`` is not None and
         ``epoch + 1`` is a multiple of ``self.save_frequency``;
      4. write the end-of-epoch summary and reset the writer's
         'train' and 'val' accumulators.
    """
    for epoch in range(self.start_epoch, self.args.nEpochs):
        self.train_epoch(epoch)

        if self.do_validation:
            self.validate_epoch(epoch)

        # Mean validation loss = accumulated loss / accumulated count.
        # If nothing was accumulated (presumably the case when validation is
        # disabled), the plain division would raise ZeroDivisionError, so
        # fall back to None.
        # NOTE(review): confirm save_checkpoint tolerates a None loss.
        val_stats = self.writer.data['val']
        val_count = val_stats['count']
        val_loss = val_stats['loss'] / val_count if val_count else None

        # Save on every save_frequency-th epoch (counting epochs 1-based).
        # The previous test, `(epoch + 1) % self.save_frequency`, was truthy
        # for every epoch EXCEPT those, which inverted the schedule.
        if (self.args.save is not None
                and (epoch + 1) % self.save_frequency == 0):
            self.model.save_checkpoint(self.args.save,
                                       epoch, val_loss,
                                       optimizer=self.optimizer)

        self.writer.write_end_of_epoch(epoch)

        self.writer.reset('train')
        self.writer.reset('val')
| 52 | def train_epoch(self, epoch): |
| 53 | self.model.train() |