MCPcopy
hub / github.com/xiaolai-sqlai/mobilenetv3 / train_one_epoch

Function train_one_epoch

engine.py:17–135  ·  view source on GitHub ↗
(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None,
                    wandb_logger=None, start_steps=None, lr_schedule_values=None, wd_schedule_values=None,
                    num_training_steps_per_epoch=None, update_freq=None, use_amp=False)

Source from the content-addressed store, hash-verified

15import utils
16
17def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
18 data_loader: Iterable, optimizer: torch.optim.Optimizer,
19 device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
20 model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None,
21 wandb_logger=None, start_steps=None, lr_schedule_values=None, wd_schedule_values=None,
22 num_training_steps_per_epoch=None, update_freq=None, use_amp=False):
23 model.train(True)
24 metric_logger = utils.MetricLogger(delimiter=" ")
25 metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
26 metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
27 header = 'Epoch: [{}]'.format(epoch)
28 print_freq = 200
29
30 optimizer.zero_grad()
31
32 for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
33 step = data_iter_step // update_freq
34 if step >= num_training_steps_per_epoch:
35 continue
36 it = start_steps + step # global training iteration
37 # Update LR & WD for the first acc
38 if lr_schedule_values is not None or wd_schedule_values is not None and data_iter_step % update_freq == 0:
39 for i, param_group in enumerate(optimizer.param_groups):
40 if lr_schedule_values is not None:
41 param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
42 if wd_schedule_values is not None and param_group["weight_decay"] > 0:
43 param_group["weight_decay"] = wd_schedule_values[it]
44
45 samples = samples.to(device, non_blocking=True)
46 targets = targets.to(device, non_blocking=True)
47
48 if mixup_fn is not None:
49 samples, targets = mixup_fn(samples, targets)
50
51 if use_amp:
52 with torch.cuda.amp.autocast():
53 output = model(samples)
54 loss = criterion(output, targets)
55 else: # full precision
56 output = model(samples)
57 loss = criterion(output, targets)
58
59 loss_value = loss.item()
60
61 if not math.isfinite(loss_value): # this could trigger if using AMP
62 print("Loss is {}, stopping training".format(loss_value))
63 assert math.isfinite(loss_value)
64
65 if use_amp:
66 # this attribute is added by timm on one optimizer (adahessian)
67 is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
68 loss /= update_freq
69 grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
70 parameters=model.parameters(), create_graph=is_second_order,
71 update_grad=(data_iter_step + 1) % update_freq == 0)
72 if (data_iter_step + 1) % update_freq == 0:
73 optimizer.zero_grad()
74 if model_ema is not None:

Callers 1

main (Function) · 0.90

Calls 6

add_meter (Method) · 0.95
log_every (Method) · 0.95
update (Method) · 0.95
print (Function) · 0.85
max (Method) · 0.80

Tested by

no test coverage detected