(net,path,gpu_id)
| 16 | |
| 17 | ################################## IO ################################## |
def save(net, path, gpu_id):
    """Save the weights of ``net`` to ``path`` as a CPU-resident state dict.

    Args:
        net: the model to save. An ``nn.DataParallel`` wrapper is unwrapped
            first, so the saved keys carry no ``module.`` prefix.
        path: destination passed straight to ``torch.save``.
        gpu_id: device-selection string; ``'-1'`` denotes CPU. For any other
            value ``net.cuda()`` is called after saving, matching the
            previous behaviour of this helper.
    """
    model = net.module if isinstance(net, nn.DataParallel) else net

    # Copy tensors to CPU instead of relocating the live model with .cpu():
    # the model's parameters never leave their device, so there is no
    # GPU->CPU->GPU round trip and a failing torch.save cannot leave the
    # model stranded on the CPU. Values are replaced in place so the
    # OrderedDict's _metadata (used by load_state_dict) is preserved;
    # non-tensor entries (e.g. extra state) are passed through unchanged.
    state = model.state_dict()
    for key, value in list(state.items()):
        if isinstance(value, torch.Tensor):
            state[key] = value.cpu()
    torch.save(state, path)

    # Keep the original post-condition: callers received a CUDA model back
    # whenever a GPU was selected.
    if gpu_id != '-1':
        net.cuda()
| 25 | |
| 26 | def todevice(net,gpu_id): |
| 27 | if gpu_id != '-1' and len(gpu_id) == 1: |