MCPcopy
hub / github.com/dbolya/yolact / train

Function train

train.py:172–385  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

170 return out
171
172def train():
173 if not os.path.exists(args.save_folder):
174 os.mkdir(args.save_folder)
175
176 dataset = COCODetection(image_path=cfg.dataset.train_images,
177 info_file=cfg.dataset.train_info,
178 transform=SSDAugmentation(MEANS))
179
180 if args.validation_epoch > 0:
181 setup_eval()
182 val_dataset = COCODetection(image_path=cfg.dataset.valid_images,
183 info_file=cfg.dataset.valid_info,
184 transform=BaseTransform(MEANS))
185
186 # Parallel wraps the underlying module, but when saving and loading we don't want that
187 yolact_net = Yolact()
188 net = yolact_net
189 net.train()
190
191 if args.log:
192 log = Log(cfg.name, args.log_folder, dict(args._get_kwargs()),
193 overwrite=(args.resume is None), log_gpu_stats=args.log_gpu)
194
195 # I don't use the timer during training (I use a different timing method).
196 # Apparently there's a race condition with multiple GPUs, so disable it just to be safe.
197 timer.disable_all()
198
199 # Both of these can set args.resume to None, so do them before the check
200 if args.resume == 'interrupt':
201 args.resume = SavePath.get_interrupt(args.save_folder)
202 elif args.resume == 'latest':
203 args.resume = SavePath.get_latest(args.save_folder, cfg.name)
204
205 if args.resume is not None:
206 print('Resuming training, loading {}...'.format(args.resume))
207 yolact_net.load_weights(args.resume)
208
209 if args.start_iter == -1:
210 args.start_iter = SavePath.from_str(args.resume).iteration
211 else:
212 print('Initializing weights...')
213 yolact_net.init_weights(backbone_path=args.save_folder + cfg.backbone.path)
214
215 optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum,
216 weight_decay=args.decay)
217 criterion = MultiBoxLoss(num_classes=cfg.num_classes,
218 pos_threshold=cfg.positive_iou_threshold,
219 neg_threshold=cfg.negative_iou_threshold,
220 negpos_ratio=cfg.ohem_negpos_ratio)
221
222 if args.batch_alloc is not None:
223 args.batch_alloc = [int(x) for x in args.batch_alloc.split(',')]
224 if sum(args.batch_alloc) != args.batch_size:
225 print('Error: Batch allocation (%s) does not sum to batch size (%s).' % (args.batch_alloc, args.batch_size))
226 exit(-1)
227
228 net = CustomDataParallel(NetLoss(net, criterion))
229 if args.cuda:

Callers 1

train.py · File · 0.85

Calls 15

load_weights · Method · 0.95
init_weights · Method · 0.95
freeze_bn · Method · 0.95
add · Method · 0.95
get_avg · Method · 0.95
log · Method · 0.95
save_weights · Method · 0.95
SSDAugmentation · Class · 0.90
BaseTransform · Class · 0.90
Yolact · Class · 0.90
Log · Class · 0.90
MultiBoxLoss · Class · 0.90

Tested by

no test coverage detected