Function train_one_epoch

vim/engine.py, lines 20–101 (github.com/hustvl/Vim)

def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, amp_autocast, max_norm: float = 0,
                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
                    set_training_mode=True, args = None):
    model.train(set_training_mode)
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    if args.cosub:
        criterion = torch.nn.BCEWithLogitsLoss()

    # debug
    # count = 0
    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
        # count += 1
        # if count > 20:
        #     break

        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        if args.cosub:
            samples = torch.cat((samples,samples),dim=0)

        if args.bce_loss:
            targets = targets.gt(0.0).type(targets.dtype)

        with amp_autocast():
            outputs = model(samples, if_random_cls_token_position=args.if_random_cls_token_position, if_random_token_rank=args.if_random_token_rank)
            # outputs = model(samples)
            if not args.cosub:
                loss = criterion(samples, outputs, targets)
            else:
                outputs = torch.split(outputs, outputs.shape[0]//2, dim=0)
                loss = 0.25 * criterion(outputs[0], targets)
                loss = loss + 0.25 * criterion(outputs[1], targets)
                loss = loss + 0.25 * criterion(outputs[0], outputs[1].detach().sigmoid())
                loss = loss + 0.25 * criterion(outputs[1], outputs[0].detach().sigmoid())

        if args.if_nan2num:
            with amp_autocast():
                loss = torch.nan_to_num(loss)

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            if args.if_continue_inf:
                optimizer.zero_grad()
                continue
            else:
                sys.exit(1)
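When `args.bce_loss` is set, `targets.gt(0.0).type(targets.dtype)` replaces each soft target row with a multi-hot row: every class with a positive weight becomes 1 and the rest become 0. The minimal pure-Python sketch below assumes the mixed targets come from a batch-mode mixup that blends each sample's one-hot label with the label of its partner in the batch reversed along dim 0, with no label smoothing (roughly how timm's `Mixup` builds targets). `mix_targets` and `binarize` are hypothetical helpers, not code from the repository.

```python
from typing import List


def one_hot(label: int, num_classes: int) -> List[float]:
    # One-hot row with 1.0 at the label index.
    return [1.0 if c == label else 0.0 for c in range(num_classes)]


def mix_targets(labels: List[int], lam: float, num_classes: int) -> List[List[float]]:
    # Hypothetical helper (assumption): pair sample i with sample n-1-i, i.e. the
    # batch flipped along dim 0, and blend their one-hot rows with lam and 1 - lam.
    n = len(labels)
    rows = []
    for i, label in enumerate(labels):
        a = one_hot(label, num_classes)
        b = one_hot(labels[n - 1 - i], num_classes)
        rows.append([lam * x + (1.0 - lam) * y for x, y in zip(a, b)])
    return rows


def binarize(rows: List[List[float]]) -> List[List[float]]:
    # Mirrors targets.gt(0.0).type(targets.dtype): any positive weight -> 1.0.
    return [[1.0 if v > 0.0 else 0.0 for v in row] for row in rows]


soft = mix_targets([0, 2, 1], lam=0.7, num_classes=3)
hard = binarize(soft)
print(soft)
print(hard)
```

With `lam=0.7`, rows whose partner has a different label become two-hot, while the middle row, paired with itself, stays single-label. If label smoothing gave every class a small positive weight, this thresholding would mark every class as positive.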

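In the `args.cosub` branch the batch is concatenated with itself, so each image passes through the model twice; `torch.split` recovers the two halves, and the loss is the sum of four `BCEWithLogitsLoss` terms weighted 0.25 each: each half against the targets, and each half against the other half's detached sigmoid probabilities. The two halves can differ only if the forward pass is stochastic (for example through dropout or drop-path). The pure-Python sketch below computes that combination for one sample's logits; `bce_with_logits`, `sigmoid`, and `cosub_loss` are hypothetical stand-ins that assume the default mean reduction of `BCEWithLogitsLoss`.

```python
import math
from typing import List


def bce_with_logits(logits: List[float], targets: List[float]) -> float:
    # Mean binary cross-entropy computed from raw logits, using the numerically
    # stable form max(x, 0) - x * t + log(1 + exp(-|x|)).
    total = 0.0
    for x, t in zip(logits, targets):
        total += max(x, 0.0) - x * t + math.log1p(math.exp(-abs(x)))
    return total / len(logits)


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def cosub_loss(out_a: List[float], out_b: List[float], targets: List[float]) -> float:
    # Two supervised terms plus two cross terms, each weighted 0.25.
    # In PyTorch the soft targets come from .detach().sigmoid(), so no gradient
    # flows through them; plain floats carry no gradients, so that is implicit here.
    soft_a = [sigmoid(x) for x in out_a]
    soft_b = [sigmoid(x) for x in out_b]
    return (0.25 * bce_with_logits(out_a, targets)
            + 0.25 * bce_with_logits(out_b, targets)
            + 0.25 * bce_with_logits(out_a, soft_b)
            + 0.25 * bce_with_logits(out_b, soft_a))


print(cosub_loss([2.0, -1.0], [1.5, -0.5], [1.0, 0.0]))
```

The cross terms push the two predictions toward each other, while the supervised terms keep both anchored to the labels; the combined loss is symmetric in the two halves.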
Callers 1

main (function) · 0.90

Calls 10

add_meter (method) · 0.95
log_every (method) · 0.95
update (method) · 0.95
print (function) · 0.85
train (method) · 0.45
to (method) · 0.45
cat (method) · 0.45
backward (method) · 0.45
step (method) · 0.45
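The function drives `utils.MetricLogger` through `add_meter`, `log_every`, and `update`. Its `lr` meter is built as `utils.SmoothedValue(window_size=1, fmt='{value:.6f}')`, so it reports only the most recent learning rate. The sketch below shows one way such a windowed meter could work; only `window_size`, `fmt`, and the `{value}` field come from the listing, while the class body, the default window of 20, the default format string, and the `median`/`avg`/`global_avg` fields are assumptions, not the repository's `utils.SmoothedValue`.

```python
from collections import deque


class SmoothedValue:
    """Hypothetical windowed meter: keeps the last `window_size` values for
    windowed statistics plus a running total for the global average."""

    def __init__(self, window_size: int = 20,
                 fmt: str = "{median:.4f} ({global_avg:.4f})") -> None:
        self.window = deque(maxlen=window_size)  # oldest values drop off automatically
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value: float, n: int = 1) -> None:
        # Record a value observed n times (e.g. a per-batch mean over n samples).
        self.window.append(value)
        self.total += value * n
        self.count += n

    @property
    def median(self) -> float:
        ordered = sorted(self.window)
        return ordered[(len(ordered) - 1) // 2]  # lower median for even windows

    @property
    def avg(self) -> float:
        return sum(self.window) / len(self.window)

    @property
    def global_avg(self) -> float:
        return self.total / self.count

    @property
    def value(self) -> float:
        return self.window[-1]  # most recent update

    def __str__(self) -> str:
        return self.fmt.format(median=self.median, avg=self.avg,
                               global_avg=self.global_avg, value=self.value)


lr_meter = SmoothedValue(window_size=1, fmt="{value:.6f}")
lr_meter.update(0.001)
lr_meter.update(0.0005)
print(lr_meter)
```

With `window_size=1` the windowed statistics collapse to the latest value, which suits a learning rate that is set rather than averaged; a loss meter would typically use a larger window.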

Tested by

no test coverage detected
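The loss guard at the end of the listing is simple enough to model in isolation, without a model or data. The sketch below mirrors its control flow with plain floats: `guard_loss` is a hypothetical helper that returns `"skip"` where the original calls `optimizer.zero_grad()` and `continue`, `"exit"` where it calls `sys.exit(1)`, and `"proceed"` where execution falls through to the rest of the loop (not shown in this excerpt). Its `nan_to_num` applies the same replacement rule as `torch.nan_to_num` with default arguments, but with Python float (float64) limits rather than the tensor dtype's.

```python
import math
import sys


def nan_to_num(x: float) -> float:
    # Default torch.nan_to_num rule: NaN -> 0.0, +inf -> largest finite value,
    # -inf -> most negative finite value (float64 limits in this sketch).
    if math.isnan(x):
        return 0.0
    if x == math.inf:
        return sys.float_info.max
    if x == -math.inf:
        return -sys.float_info.max
    return x


def guard_loss(loss_value: float, if_nan2num: bool, if_continue_inf: bool) -> str:
    # Hypothetical helper: report what the training loop would do with this loss.
    if if_nan2num:
        loss_value = nan_to_num(loss_value)
    if not math.isfinite(loss_value):
        return "skip" if if_continue_inf else "exit"
    return "proceed"


for loss in (0.7, float("nan"), float("inf")):
    print(loss, guard_loss(loss, if_nan2num=False, if_continue_inf=True))
```

One consequence: with `args.if_nan2num` enabled, a NaN loss becomes 0.0 and an infinite one becomes the largest finite value, so the `math.isfinite` check never fires. The original also prints "stopping training" even when `args.if_continue_inf` makes it skip the batch and keep going.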