MCPcopy
hub / github.com/pytorch/vision / train_one_epoch

Function train_one_epoch

references/detection/engine.py:12–60  ·  view source on GitHub ↗
(model, optimizer, data_loader, device, epoch, print_freq, scaler=None)

Source from the content-addressed store, hash-verified

10
11
def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
    """Run one training pass of ``model`` over ``data_loader``.

    The model is switched to training mode, every batch is moved to
    ``device``, the summed loss is back-propagated and the optimizer is
    stepped once per batch. During epoch 0 only, a linear learning-rate
    warmup scheduler is created and stepped after every batch.

    Args:
        model: Callable as ``model(images, targets)``; the result is used as
            a mapping whose values are summed into the training loss.
        optimizer: Optimizer that is zeroed and stepped for every batch.
        data_loader: Sized iterable yielding ``(images, targets)`` batches,
            where ``targets`` is an iterable of dicts.
        device: Destination passed to ``.to(...)`` for images and for the
            tensor values inside each target dict.
        epoch: Epoch index; shown in the log header and used to decide
            whether warmup applies (``epoch == 0``).
        print_freq: Forwarded to ``metric_logger.log_every``.
        scaler: Optional gradient scaler. When given, the forward pass runs
            with autocast enabled and the scaler drives backward/step/update.

    Returns:
        The ``utils.MetricLogger`` populated during this epoch.

    Note:
        If the reduced loss is not finite, the loss is printed and the
        process terminates via ``sys.exit(1)``.
    """
    model.train()
    tracker = utils.MetricLogger(delimiter=" ")
    tracker.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    banner = f"Epoch: [{epoch}]"

    # Linear learning-rate warmup is only set up for the very first epoch.
    warmup = None
    if epoch == 0:
        warmup = torch.optim.lr_scheduler.LinearLR(
            optimizer,
            start_factor=1.0 / 1000,
            total_iters=min(1000, len(data_loader) - 1),
        )

    # Mixed precision is active exactly when a scaler was supplied.
    amp_enabled = scaler is not None

    for batch_images, batch_targets in tracker.log_every(data_loader, print_freq, banner):
        batch_images = [img.to(device) for img in batch_images]
        # Only tensor entries are moved; any other value is passed through.
        batch_targets = [
            {
                key: value.to(device) if isinstance(value, torch.Tensor) else value
                for key, value in target.items()
            }
            for target in batch_targets
        ]

        with torch.cuda.amp.autocast(enabled=amp_enabled):
            loss_terms = model(batch_images, batch_targets)
            total_loss = sum(loss_terms.values())

        # reduce losses over all GPUs for logging purposes
        reduced_terms = utils.reduce_dict(loss_terms)
        reduced_total = sum(reduced_terms.values())
        scalar_loss = reduced_total.item()

        # Abort the whole run on a non-finite (NaN/inf) loss.
        if not math.isfinite(scalar_loss):
            print(f"Loss is {scalar_loss}, stopping training")
            print(reduced_terms)
            sys.exit(1)

        optimizer.zero_grad()
        if scaler is None:
            total_loss.backward()
            optimizer.step()
        else:
            scaler.scale(total_loss).backward()
            scaler.step(optimizer)
            scaler.update()

        if warmup is not None:
            warmup.step()

        tracker.update(loss=reduced_total, **reduced_terms)
        tracker.update(lr=optimizer.param_groups[0]["lr"])

    return tracker
61
62
63def _get_iou_types(model):

Callers 1

main · Function · 0.90

Calls 7

add_meter · Method · 0.95
log_every · Method · 0.95
update · Method · 0.95
train · Method · 0.80
to · Method · 0.80
print · Function · 0.70
backward · Method · 0.45

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…