MCPcopy
hub / github.com/shunsukesaito/PIFu / train_color

Function train_color

apps/train_color.py:28–188  ·  view source on GitHub ↗
(opt)

Source from the content-addressed store, hash-verified

26opt = BaseOptions().parse()
27
28def train_color(opt):
29 # set cuda
30 cuda = torch.device('cuda:%d' % opt.gpu_id)
31
32 train_dataset = TrainDataset(opt, phase='train')
33 test_dataset = TrainDataset(opt, phase='test')
34
35 projection_mode = train_dataset.projection_mode
36
37 # create data loader
38 train_data_loader = DataLoader(train_dataset,
39 batch_size=opt.batch_size, shuffle=not opt.serial_batches,
40 num_workers=opt.num_threads, pin_memory=opt.pin_memory)
41
42 print('train data size: ', len(train_data_loader))
43
44 # NOTE: batch size should be 1 and use all the points for evaluation
45 test_data_loader = DataLoader(test_dataset,
46 batch_size=1, shuffle=False,
47 num_workers=opt.num_threads, pin_memory=opt.pin_memory)
48 print('test data size: ', len(test_data_loader))
49
50 # create net
51 netG = HGPIFuNet(opt, projection_mode).to(device=cuda)
52
53 lr = opt.learning_rate
54
55 # Always use resnet for color regression
56 netC = ResBlkPIFuNet(opt).to(device=cuda)
57 optimizerC = torch.optim.Adam(netC.parameters(), lr=opt.learning_rate)
58
59 def set_train():
60 netG.eval()
61 netC.train()
62
63 def set_eval():
64 netG.eval()
65 netC.eval()
66
67 print('Using NetworkG: ', netG.name, 'networkC: ', netC.name)
68
69 # load checkpoints
70 if opt.load_netG_checkpoint_path is not None:
71 print('loading for net G ...', opt.load_netG_checkpoint_path)
72 netG.load_state_dict(torch.load(opt.load_netG_checkpoint_path, map_location=cuda))
73 else:
74 model_path_G = '%s/%s/netG_latest' % (opt.checkpoints_path, opt.name)
75 print('loading for net G ...', model_path_G)
76 netG.load_state_dict(torch.load(model_path_G, map_location=cuda))
77
78 if opt.load_netC_checkpoint_path is not None:
79 print('loading for net C ...', opt.load_netC_checkpoint_path)
80 netC.load_state_dict(torch.load(opt.load_netC_checkpoint_path, map_location=cuda))
81
82 if opt.continue_train:
83 if opt.resume_epoch < 0:
84 model_path_C = '%s/%s/netC_latest' % (opt.checkpoints_path, opt.name)
85 else:

Callers 1

train_color.py · File · 0.85

Calls 13

TrainDataset · Class · 0.85
HGPIFuNet · Class · 0.85
ResBlkPIFuNet · Class · 0.85
save_samples_rgb · Function · 0.85
get_im_feat · Method · 0.80
set_train · Function · 0.70
set_eval · Function · 0.70
reshape_sample_tensor · Function · 0.50
calc_error_color · Function · 0.50
gen_mesh_color · Function · 0.50
filter · Method · 0.45

Tested by

no test coverage detected