Repository: github.com/microsoft/Cream

Function train_one_epoch

EfficientViT/classification/engine.py:21–73
(model: torch.nn.Module, criterion: DistillationLoss,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler,
                    clip_grad: float = 0,
                    clip_mode: str = 'norm',
                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
                    set_training_mode=True,
                    set_bn_eval=False,)
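The optional `mixup_fn` (typed as timm's `Mixup`) is applied to every batch before the forward pass, replacing hard labels with mixed soft targets. As a rough, dependency-free illustration of the basic mixup rule, the sketch below blends each sample with its partner in the reversed batch (an assumption modeled on timm's batch-flip pairing) and mixes the one-hot labels with the same coefficient `lam`. The function names are hypothetical; timm additionally samples `lam` from a Beta distribution, applies label smoothing, and can switch to CutMix, none of which is modeled here.

```python
from typing import List, Sequence, Tuple


def one_hot(label: int, num_classes: int) -> List[float]:
    """Convert an integer class label into a one-hot probability vector."""
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec


def mixup_batch(samples: Sequence[Sequence[float]], labels: Sequence[int],
                num_classes: int,
                lam: float) -> Tuple[List[List[float]], List[List[float]]]:
    """Hypothetical mixup: pair sample i with sample n-1-i and blend both
    the inputs and the one-hot targets with weight lam."""
    n = len(samples)
    mixed_x: List[List[float]] = []
    mixed_y: List[List[float]] = []
    for i in range(n):
        j = n - 1 - i  # partner in the reversed (flipped) batch
        x = [lam * a + (1.0 - lam) * b for a, b in zip(samples[i], samples[j])]
        yi = one_hot(labels[i], num_classes)
        yj = one_hot(labels[j], num_classes)
        y = [lam * a + (1.0 - lam) * b for a, b in zip(yi, yj)]
        mixed_x.append(x)
        mixed_y.append(y)
    return mixed_x, mixed_y


x, y = mixup_batch([[1.0, 0.0], [0.0, 1.0]], [0, 1], num_classes=2, lam=0.75)
print(x, y)  # [[0.75, 0.25], [0.25, 0.75]] [[0.75, 0.25], [0.25, 0.75]]
```

Because the targets become probability vectors, the criterion has to accept soft labels when mixup is enabled.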


import math
import sys
from typing import Iterable, Optional

import torch
from timm.data import Mixup      # assumed source of the Mixup type
from timm.utils import ModelEma  # assumed source of the ModelEma type

# `utils`, `DistillationLoss` and `set_bn_state` are defined elsewhere in the
# repository (not shown on this page).


def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler,
                    clip_grad: float = 0,
                    clip_mode: str = 'norm',
                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
                    set_training_mode=True,
                    set_bn_eval=False,):
    # Enter training mode; optionally let set_bn_state adjust BatchNorm layers.
    model.train(set_training_mode)
    if set_bn_eval:
        set_bn_state(model)
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(
        window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 100

    for samples, targets in metric_logger.log_every(
            data_loader, print_freq, header):
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        # Optional mixup: blends samples and turns targets into soft labels.
        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        # Mixed precision is disabled: the original
        # `with torch.cuda.amp.autocast():` block is commented out.
        outputs = model(samples)
        # The distillation criterion also receives the raw inputs
        # (presumably so a teacher model can be run on them).
        loss = criterion(samples, outputs, targets)

        loss_value = loss.item()

        # Abort the run if the loss becomes NaN or infinite.
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        optimizer.zero_grad()

        # this attribute is added by timm on one optimizer (adahessian)
        is_second_order = hasattr(
            optimizer, 'is_second_order') and optimizer.is_second_order
        # loss_scaler performs backward, optional gradient clipping and the
        # optimizer step (there is no separate backward()/step() call here).
        loss_scaler(loss, optimizer, clip_grad=clip_grad, clip_mode=clip_mode,
                    parameters=model.parameters(), create_graph=is_second_order)

        # Wait for queued CUDA work to finish (assumes a CUDA device).
        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)

        metric_logger.update(loss=loss_value)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    # Return the epoch-wide average of every meter, keyed by meter name.
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
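`loss_scaler` receives `clip_grad` and `clip_mode='norm'`, but its implementation is not shown on this page. Norm-mode clipping usually follows the rule used by `torch.nn.utils.clip_grad_norm_`: compute one global L2 norm over all gradients and, if it exceeds `max_norm`, scale every gradient by `max_norm / (norm + 1e-6)`. The pure-Python sketch below applies that rule, assuming plain lists of floats stand in for parameter gradients; the function name is hypothetical.

```python
import math
from typing import List


def clip_grad_norm(grads: List[List[float]], max_norm: float) -> float:
    """Scale all gradients in place so their global L2 norm is at most max_norm.

    Returns the total norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:  # only shrink, never enlarge
        for grad in grads:
            for k in range(len(grad)):
                grad[k] *= clip_coef
    return total_norm


grads = [[3.0], [4.0]]          # global norm is 5.0
norm = clip_grad_norm(grads, max_norm=1.0)
print(norm, grads)              # gradients now have norm just under 1.0
```

Clipping by the global norm keeps the direction of the combined gradient and only shortens it, unlike per-value clipping.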
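When `model_ema` is provided, `model_ema.update(model)` runs after every optimizer step. timm's `ModelEma` blends each entry of the model's state dict into its EMA copy as `ema = decay * ema + (1 - decay) * current`. The sketch below applies that rule to a dict of scalars standing in for tensors; the function name is hypothetical, and the decay value actually used is configured by the caller and not visible here.

```python
from typing import Dict


def ema_update(ema_params: Dict[str, float], model_params: Dict[str, float],
               decay: float) -> None:
    """In-place exponential moving average: ema = decay * ema + (1 - decay) * w."""
    for name, value in model_params.items():
        ema_params[name] = decay * ema_params[name] + (1.0 - decay) * value


ema = {"w": 0.0}
for _ in range(3):
    ema_update(ema, {"w": 1.0}, decay=0.9)
print(ema["w"])  # approaches 1.0 as 1 - 0.9**steps
```

With a decay close to 1, the EMA weights move slowly and average out step-to-step noise in the trained weights.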

Callers 1

main (function) · 0.90
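The only caller, `main`, is not shown on this page. A hypothetical driver, assuming the caller invokes `train_one_epoch` once per epoch and logs the returned dictionary, might look like this; `train_fn` stands in for a closure over the real arguments, and `run_training` and the `train_` prefix are illustrative names only. The keys (`lr`, `loss`) are the meters updated inside the loop.

```python
import json
from typing import Callable, Dict, List


def run_training(train_fn: Callable[[int], Dict[str, float]],
                 start_epoch: int, epochs: int) -> List[Dict[str, float]]:
    """Call train_fn once per epoch and collect prefixed log records."""
    history: List[Dict[str, float]] = []
    for epoch in range(start_epoch, epochs):
        train_stats = train_fn(epoch)  # e.g. {'lr': ..., 'loss': ...}
        log_stats: Dict[str, float] = {f'train_{k}': v for k, v in train_stats.items()}
        log_stats['epoch'] = epoch
        history.append(log_stats)
        print(json.dumps(log_stats))
    return history


def fake_train(epoch: int) -> Dict[str, float]:
    # Stand-in for train_one_epoch; returns made-up epoch averages.
    return {'lr': 0.001 * (epoch + 1), 'loss': 2.0 / (epoch + 1)}


run_training(fake_train, start_epoch=0, epochs=2)
```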

Calls 10

add_meter (method) · 0.95
log_every (method) · 0.95
update (method) · 0.95
format (method) · 0.80
to (method) · 0.80
zero_grad (method) · 0.80
set_bn_state (function) · 0.70
print (function) · 0.70
train (method) · 0.45
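Most of the calls above (`add_meter`, `log_every`, `update`) go through `utils.MetricLogger`, and the function returns `meter.global_avg` for every meter. A minimal sketch of how such a meter might work, assuming `utils.SmoothedValue` keeps a fixed-size window of recent values plus a running total and count; the class below is a hypothetical stand-in, not the repository's code, and `merge_global_avg` only simulates what `synchronize_between_processes` would do if it summed counts and totals across ranks.

```python
from collections import deque
from typing import Deque, Iterable


class SmoothedValue:
    """Hypothetical meter: a windowed view plus a running total and count."""

    def __init__(self, window_size: int = 20) -> None:
        self.window: Deque[float] = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0

    def update(self, value: float, n: int = 1) -> None:
        self.window.append(value)
        self.count += n
        self.total += value * n

    @property
    def value(self) -> float:
        return self.window[-1]  # most recent value

    @property
    def avg(self) -> float:
        return sum(self.window) / len(self.window)  # mean over the window only

    @property
    def global_avg(self) -> float:
        return self.total / self.count  # mean over every update so far


def merge_global_avg(meters: Iterable[SmoothedValue]) -> float:
    """Simulate a cross-process sync: sum totals and counts, then divide."""
    ranks = list(meters)
    return sum(m.total for m in ranks) / sum(m.count for m in ranks)


lr = SmoothedValue(window_size=1)
lr.update(0.1)
lr.update(0.05)
print(lr.value, lr.global_avg)  # latest lr vs. mean lr over the epoch
```

Under this assumption, `window_size=1` makes the printed `lr` the latest learning rate, while the returned `lr` entry is the mean over all iterations of the epoch.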

Tested by

no test coverage detected