(opt)
| 26 | opt = BaseOptions().parse()  # module-level options object from BaseOptions.parse(); presumably parsed from command-line args — confirm |
| 27 | |
| 28 | def train_color(opt): |
| 29 | # set cuda |
| 30 | cuda = torch.device('cuda:%d' % opt.gpu_id) |
| 31 | |
| 32 | train_dataset = TrainDataset(opt, phase='train') |
| 33 | test_dataset = TrainDataset(opt, phase='test') |
| 34 | |
| 35 | projection_mode = train_dataset.projection_mode |
| 36 | |
| 37 | # create data loader |
| 38 | train_data_loader = DataLoader(train_dataset, |
| 39 | batch_size=opt.batch_size, shuffle=not opt.serial_batches, |
| 40 | num_workers=opt.num_threads, pin_memory=opt.pin_memory) |
| 41 | |
| 42 | print('train data size: ', len(train_data_loader)) |
| 43 | |
| 44 | # NOTE: batch size should be 1 and use all the points for evaluation |
| 45 | test_data_loader = DataLoader(test_dataset, |
| 46 | batch_size=1, shuffle=False, |
| 47 | num_workers=opt.num_threads, pin_memory=opt.pin_memory) |
| 48 | print('test data size: ', len(test_data_loader)) |
| 49 | |
| 50 | # create net |
| 51 | netG = HGPIFuNet(opt, projection_mode).to(device=cuda) |
| 52 | |
| 53 | lr = opt.learning_rate |
| 54 | |
| 55 | # Always use resnet for color regression |
| 56 | netC = ResBlkPIFuNet(opt).to(device=cuda) |
| 57 | optimizerC = torch.optim.Adam(netC.parameters(), lr=opt.learning_rate) |
| 58 | |
| 59 | def set_train(): |
| 60 | netG.eval() |
| 61 | netC.train() |
| 62 | |
| 63 | def set_eval(): |
| 64 | netG.eval() |
| 65 | netC.eval() |
| 66 | |
| 67 | print('Using NetworkG: ', netG.name, 'networkC: ', netC.name) |
| 68 | |
| 69 | # load checkpoints |
| 70 | if opt.load_netG_checkpoint_path is not None: |
| 71 | print('loading for net G ...', opt.load_netG_checkpoint_path) |
| 72 | netG.load_state_dict(torch.load(opt.load_netG_checkpoint_path, map_location=cuda)) |
| 73 | else: |
| 74 | model_path_G = '%s/%s/netG_latest' % (opt.checkpoints_path, opt.name) |
| 75 | print('loading for net G ...', model_path_G) |
| 76 | netG.load_state_dict(torch.load(model_path_G, map_location=cuda)) |
| 77 | |
| 78 | if opt.load_netC_checkpoint_path is not None: |
| 79 | print('loading for net C ...', opt.load_netC_checkpoint_path) |
| 80 | netC.load_state_dict(torch.load(opt.load_netC_checkpoint_path, map_location=cuda)) |
| 81 | |
| 82 | if opt.continue_train: |
| 83 | if opt.resume_epoch < 0: |
| 84 | model_path_C = '%s/%s/netC_latest' % (opt.checkpoints_path, opt.name) |
| 85 | else: |
No test coverage detected.