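Function train

train.py:71–199 in github.com/amdegroot/ssd.pytorch

Signature: train(). The function takes no parameters. It reads its configuration from the module-level args (the parsed command-line options) and reports invalid option combinations through parser.error.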

```python
def train():
    if args.dataset == 'COCO':
        if args.dataset_root == VOC_ROOT:
            if not os.path.exists(COCO_ROOT):
                parser.error('Must specify dataset_root if specifying dataset')
            print("WARNING: Using default COCO dataset_root because " +
                  "--dataset_root was not specified.")
            args.dataset_root = COCO_ROOT
        cfg = coco
        dataset = COCODetection(root=args.dataset_root,
                                transform=SSDAugmentation(cfg['min_dim'],
                                                          MEANS))
    elif args.dataset == 'VOC':
        if args.dataset_root == COCO_ROOT:
            parser.error('Must specify dataset if specifying dataset_root')
        cfg = voc
        dataset = VOCDetection(root=args.dataset_root,
                               transform=SSDAugmentation(cfg['min_dim'],
                                                         MEANS))

    if args.visdom:
        import visdom
        viz = visdom.Visdom()

    ssd_net = build_ssd('train', cfg['min_dim'], cfg['num_classes'])
    net = ssd_net

    if args.cuda:
        net = torch.nn.DataParallel(ssd_net)
        cudnn.benchmark = True

    if args.resume:
        print('Resuming training, loading {}...'.format(args.resume))
        ssd_net.load_weights(args.resume)
    else:
        vgg_weights = torch.load(args.save_folder + args.basenet)
        print('Loading base network...')
        ssd_net.vgg.load_state_dict(vgg_weights)

    if args.cuda:
        net = net.cuda()

    if not args.resume:
        print('Initializing weights...')
        # initialize newly added layers' weights with xavier method
        ssd_net.extras.apply(weights_init)
        ssd_net.loc.apply(weights_init)
        ssd_net.conf.apply(weights_init)

    optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum,
                          weight_decay=args.weight_decay)
    criterion = MultiBoxLoss(cfg['num_classes'], 0.5, True, 0, True, 3, 0.5,
                             False, args.cuda)

    net.train()
    # loss counters
    loc_loss = 0
    conf_loss = 0
    # ... (excerpt ends at line 128; train() continues through line 199)
```
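The branching at the top of train() decides which dataset root is actually used. The sketch below restates that logic as a standalone, testable function. resolve_dataset_root, DatasetArgError and the placeholder VOC_ROOT/COCO_ROOT paths are hypothetical names introduced for illustration; raising an exception stands in for argparse's parser.error, which prints a usage message and exits. The listing has no branch for other dataset names, so the explicit error for them is an assumption, not original behavior.

```python
import os

# Hypothetical placeholder defaults; the repository defines its own
# VOC_ROOT and COCO_ROOT elsewhere.
VOC_ROOT = "/data/VOCdevkit"
COCO_ROOT = "/data/coco"


class DatasetArgError(ValueError):
    """Stand-in for parser.error(), which in argparse prints usage and exits."""


def resolve_dataset_root(dataset, dataset_root, coco_exists=None):
    """Return the dataset root train() would end up using (sketch)."""
    if coco_exists is None:
        coco_exists = os.path.exists(COCO_ROOT)
    if dataset == "COCO":
        # COCO requested but the root is still the VOC default:
        # fall back to COCO_ROOT, provided that directory exists.
        if dataset_root == VOC_ROOT:
            if not coco_exists:
                raise DatasetArgError("Must specify dataset_root if specifying dataset")
            print("WARNING: Using default COCO dataset_root because "
                  "--dataset_root was not specified.")
            return COCO_ROOT
        return dataset_root
    if dataset == "VOC":
        # A COCO root combined with the VOC dataset is rejected.
        if dataset_root == COCO_ROOT:
            raise DatasetArgError("Must specify dataset if specifying dataset_root")
        return dataset_root
    # Assumption: the original listing has no branch for other names.
    raise DatasetArgError(f"unknown dataset: {dataset!r}")
```

For example, resolve_dataset_root("COCO", VOC_ROOT, coco_exists=True) prints the warning and returns COCO_ROOT, while an explicitly given COCO root is returned unchanged.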

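The comment in the listing says the extras, loc and conf layers are initialized with the "xavier method". Assuming the uniform variant (the one torch.nn.init.xavier_uniform_ implements), each convolution weight is drawn from U(−a, a) with a = gain · sqrt(6 / (fan_in + fan_out)), where fan_in = in_channels · kh · kw and fan_out = out_channels · kh · kw. weights_init itself is not shown on this page, so which layer types it touches and how it treats biases are not confirmed here. The pure-Python helpers below (conv_fans, xavier_uniform_bound, xavier_uniform) are hypothetical and only illustrate the formula.

```python
import math
import random


def conv_fans(in_channels, out_channels, kernel_h, kernel_w):
    """Fan-in and fan-out of a Conv2d weight of shape (out, in, kh, kw)."""
    receptive_field = kernel_h * kernel_w
    return in_channels * receptive_field, out_channels * receptive_field


def xavier_uniform_bound(fan_in, fan_out, gain=1.0):
    """Half-width a of the Xavier/Glorot uniform distribution U(-a, a)."""
    return gain * math.sqrt(6.0 / (fan_in + fan_out))


def xavier_uniform(shape, gain=1.0, rng=random):
    """Sample a flat list of weights for a conv kernel of shape (out, in, kh, kw)."""
    out_c, in_c, kh, kw = shape
    fan_in, fan_out = conv_fans(in_c, out_c, kh, kw)
    a = xavier_uniform_bound(fan_in, fan_out, gain)
    return [rng.uniform(-a, a) for _ in range(out_c * in_c * kh * kw)]
```

For a hypothetical 3×3 convolution from 256 to 512 channels, fan_in = 2304 and fan_out = 4608, so a = sqrt(6 / 6912) ≈ 0.0295.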
Callers (1)

train.py (file)

Calls (9)

SSDAugmentation (class)
build_ssd (function)
MultiBoxLoss (class)
COCODetection (class)
VOCDetection (class)
create_vis_plot (function)
update_vis_plot (function)
adjust_learning_rate (function)
load_weights (method)
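The criterion is built as MultiBoxLoss(cfg['num_classes'], 0.5, True, 0, True, 3, 0.5, False, args.cuda). Assuming the positional order is (num_classes, overlap_thresh, prior_for_matching, bkg_label, neg_mining, neg_pos, neg_overlap, encode_target, use_gpu), this enables hard negative mining that keeps at most three background priors per matched prior, choosing the ones with the highest confidence loss. The simplified, list-based sketch below shows only that selection step; hard_negative_mining is a hypothetical helper, not the repository's tensor implementation.

```python
def hard_negative_mining(conf_losses, is_positive, neg_pos=3):
    """Return sorted indices of the priors kept for the confidence loss (sketch).

    conf_losses: per-prior background-classification loss (higher = harder).
    is_positive: per-prior bool, True where the prior was matched to a box.
    neg_pos: negatives kept per positive (the listing passes 3).
    """
    positives = [i for i, p in enumerate(is_positive) if p]
    negatives = [i for i, p in enumerate(is_positive) if not p]
    # Keep neg_pos negatives per positive, capped by how many negatives exist.
    num_neg = min(neg_pos * len(positives), len(negatives))
    # Hardest negatives first: sort by loss, descending, then truncate.
    hardest = sorted(negatives, key=lambda i: conf_losses[i], reverse=True)[:num_neg]
    return sorted(positives + hardest)
```

With losses [0.1, 2.0, 0.5, 1.5, 0.3, 0.9, 0.05, 1.2] and prior 0 as the only positive, the call returns [0, 1, 3, 7]: the positive plus the three highest-loss negatives.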

Tested by

no test coverage detected