(net, label, epoch, opt)
| 193 | |
| 194 | |
def save_network(net, label, epoch, opt):
    """Write ``net``'s state dict to the experiment's checkpoint directory.

    The checkpoint path is
    ``<opt.checkpoints_dir>/<opt.name>/<epoch>_net_<label>.pth``. Every tensor
    in the saved state dict is on the CPU, so the file can be loaded without a
    GPU.

    Args:
        net: module whose ``state_dict()`` is serialized. Its device
            placement is left unchanged.
        label: tag inserted into the filename after ``_net_``.
        epoch: value used as the filename prefix.
        opt: options object; ``checkpoints_dir`` and ``name`` are read here.
    """
    save_filename = '%s_net_%s.pth' % (epoch, label)
    save_path = os.path.join(opt.checkpoints_dir, opt.name, save_filename)

    # Copy tensors to host memory inside the state dict instead of calling
    # net.cpu(). That approach moved the live module off its device and then
    # back with net.cuda(), which targets the *current default* CUDA device
    # and even pushed CPU-resident nets onto the GPU.
    # state_dict() returns a newly built OrderedDict, so rewriting its entries
    # leaves `net` untouched and keeps the dict's `_metadata` attribute.
    state = net.state_dict()
    for key, value in list(state.items()):
        # Extra-state entries may be arbitrary objects; only tensors move.
        if torch.is_tensor(value):
            state[key] = value.cpu()
    torch.save(state, save_path)
| 201 | |
| 202 | |
| 203 | def load_network(net, label, epoch, opt): |