()
| 170 | return out |
| 171 | |
| 172 | def train(): |
| 173 | if not os.path.exists(args.save_folder): |
| 174 | os.mkdir(args.save_folder) |
| 175 | |
| 176 | dataset = COCODetection(image_path=cfg.dataset.train_images, |
| 177 | info_file=cfg.dataset.train_info, |
| 178 | transform=SSDAugmentation(MEANS)) |
| 179 | |
| 180 | if args.validation_epoch > 0: |
| 181 | setup_eval() |
| 182 | val_dataset = COCODetection(image_path=cfg.dataset.valid_images, |
| 183 | info_file=cfg.dataset.valid_info, |
| 184 | transform=BaseTransform(MEANS)) |
| 185 | |
| 186 | # Parallel wraps the underlying module, but when saving and loading we don't want that |
| 187 | yolact_net = Yolact() |
| 188 | net = yolact_net |
| 189 | net.train() |
| 190 | |
| 191 | if args.log: |
| 192 | log = Log(cfg.name, args.log_folder, dict(args._get_kwargs()), |
| 193 | overwrite=(args.resume is None), log_gpu_stats=args.log_gpu) |
| 194 | |
| 195 | # I don't use the timer during training (I use a different timing method). |
| 196 | # Apparently there's a race condition with multiple GPUs, so disable it just to be safe. |
| 197 | timer.disable_all() |
| 198 | |
| 199 | # Both of these can set args.resume to None, so do them before the check |
| 200 | if args.resume == 'interrupt': |
| 201 | args.resume = SavePath.get_interrupt(args.save_folder) |
| 202 | elif args.resume == 'latest': |
| 203 | args.resume = SavePath.get_latest(args.save_folder, cfg.name) |
| 204 | |
| 205 | if args.resume is not None: |
| 206 | print('Resuming training, loading {}...'.format(args.resume)) |
| 207 | yolact_net.load_weights(args.resume) |
| 208 | |
| 209 | if args.start_iter == -1: |
| 210 | args.start_iter = SavePath.from_str(args.resume).iteration |
| 211 | else: |
| 212 | print('Initializing weights...') |
| 213 | yolact_net.init_weights(backbone_path=args.save_folder + cfg.backbone.path) |
| 214 | |
| 215 | optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum, |
| 216 | weight_decay=args.decay) |
| 217 | criterion = MultiBoxLoss(num_classes=cfg.num_classes, |
| 218 | pos_threshold=cfg.positive_iou_threshold, |
| 219 | neg_threshold=cfg.negative_iou_threshold, |
| 220 | negpos_ratio=cfg.ohem_negpos_ratio) |
| 221 | |
| 222 | if args.batch_alloc is not None: |
| 223 | args.batch_alloc = [int(x) for x in args.batch_alloc.split(',')] |
| 224 | if sum(args.batch_alloc) != args.batch_size: |
| 225 | print('Error: Batch allocation (%s) does not sum to batch size (%s).' % (args.batch_alloc, args.batch_size)) |
| 226 | exit(-1) |
| 227 | |
| 228 | net = CustomDataParallel(NetLoss(net, criterion)) |
| 229 | if args.cuda: |
no test coverage detected