save classifier to checkpoint
(self, classifier, optimizer, epoch)
| 111 | return start_epoch |
| 112 | |
def save(self, classifier, optimizer, epoch):
    """Save classifier and optimizer state to checkpoint files.

    Only the process with ``self.args.local_rank == 0`` writes anything;
    every other rank returns immediately.

    Two files may be written inside ``self.args.model_folder``:

    * ``current.pth`` -- overwritten on every call.
    * ``ckpt_epoch_<epoch>.pth`` -- written additionally when ``epoch``
      is a multiple of ``self.args.save_freq``.

    Args:
        classifier: object exposing ``state_dict()``; its result is
            stored under the ``'classifier'`` key.
        optimizer: object exposing ``state_dict()``; its result is
            stored under the ``'optimizer'`` key.
        epoch: epoch number, stored under ``'epoch'`` and used to name
            the periodic snapshot.
    """
    args = self.args
    if args.local_rank != 0:
        return

    print('==> Saving...')
    state = {
        'epoch': epoch,
        'classifier': classifier.state_dict(),
        'optimizer': optimizer.state_dict(),
    }

    # torch.save does not create missing parent directories, so make sure
    # the target folder exists before writing into it.
    os.makedirs(args.model_folder, exist_ok=True)

    save_file = os.path.join(args.model_folder, 'current.pth')
    torch.save(state, save_file)

    # A falsy save_freq (0 / None) disables periodic snapshots instead of
    # raising on the modulo; any non-zero value behaves as before.
    if args.save_freq and epoch % args.save_freq == 0:
        save_file = os.path.join(
            args.model_folder, 'ckpt_epoch_{}.pth'.format(epoch))
        torch.save(state, save_file)

    # Drop the local reference to the collected state dicts
    # (original intent: help release GPU memory).
    del state
| 132 | |
| 133 | def train(self, epoch, train_loader, model, classifier, |
| 134 | criterion, optimizer): |