| 10 | |
| 11 | |
def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
    """Train ``model`` for one pass over ``data_loader`` and return the metric logger.

    During the first epoch (``epoch == 0``) a linear learning-rate warmup is
    applied per iteration, ramping from 1/1000 of the optimizer's LR over
    ``min(1000, len(data_loader) - 1)`` iterations. Warmup is skipped when
    that count is not positive (loader with fewer than two batches).

    Args:
        model: Called as ``model(images, targets)``; must return a dict whose
            values are the individual loss terms (they are summed).
        optimizer: Optimizer stepped once per batch.
        data_loader: Sized iterable yielding ``(images, targets)`` pairs, where
            ``images`` is an iterable of objects with ``.to(device)`` and
            ``targets`` is an iterable of dicts.
        device: Device that images and tensor-valued target entries are moved to.
        epoch: Current epoch index; used in the log header and to enable warmup.
        print_freq: Logging frequency forwarded to ``MetricLogger.log_every``.
        scaler: Optional gradient scaler. When not ``None``, autocast is enabled
            and the scaler drives backward/step/update (presumably a
            ``torch.cuda.amp.GradScaler`` -- confirm against the caller).

    Returns:
        The ``utils.MetricLogger`` populated with the loss terms and LR.

    Raises:
        SystemExit: Via ``sys.exit(1)`` when the reduced total loss is not finite.
    """
    model.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = f"Epoch: [{epoch}]"

    lr_scheduler = None
    if epoch == 0:
        warmup_factor = 1.0 / 1000
        warmup_iters = min(1000, len(data_loader) - 1)

        # LinearLR scales the optimizer LR by start_factor as soon as it is
        # constructed. With total_iters <= 0 it never ramps back up, which
        # would leave the LR at 1/1000 of its value after this function
        # returns. Only warm up when there is at least one iteration to ramp.
        if warmup_iters > 0:
            lr_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=warmup_factor, total_iters=warmup_iters
            )

    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
        images = [image.to(device) for image in images]
        # Only tensor values are moved; other target entries pass through unchanged.
        targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
        with torch.cuda.amp.autocast(enabled=scaler is not None):
            loss_dict = model(images, targets)
            losses = sum(loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        losses_reduced = sum(loss_dict_reduced.values())

        loss_value = losses_reduced.item()

        # Abort before the optimizer step so a NaN/inf loss is never applied.
        if not math.isfinite(loss_value):
            print(f"Loss is {loss_value}, stopping training")
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(losses).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            losses.backward()
            optimizer.step()

        # The warmup scheduler advances per iteration, not per epoch.
        if lr_scheduler is not None:
            lr_scheduler.step()

        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    return metric_logger
| 61 | |
| 62 | |
| 63 | def _get_iou_types(model): |