MCPcopy
hub / github.com/microsoft/Cream / train

Function train

CDARTS/benchmark201/core/pretrain_function.py:7–79  ·  view source on GitHub ↗
(train_loader, model, optimizer, epoch, writer, logger, config)

Source from the content-addressed store, hash-verified

5from models.loss import CrossEntropyLabelSmooth
6
7def train(train_loader, model, optimizer, epoch, writer, logger, config):
8 device = torch.device("cuda")
9 if config.label_smooth > 0:
10 criterion = CrossEntropyLabelSmooth(config.n_classes, config.label_smooth).to(device)
11 else:
12 criterion = nn.CrossEntropyLoss().to(device)
13
14 top1 = utils.AverageMeter()
15 top5 = utils.AverageMeter()
16 losses = utils.AverageMeter()
17
18 step_num = len(train_loader)
19 cur_step = epoch*step_num
20 cur_lr = optimizer.param_groups[0]['lr']
21 if config.local_rank == 0:
22 logger.info("Train Epoch {} LR {}".format(epoch, cur_lr))
23 writer.add_scalar('train/lr', cur_lr, cur_step)
24
25 model.train()
26
27 for step, (X, y) in enumerate(train_loader):
28 X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)
29 N = X.size(0)
30
31 X, target_a, target_b, lam = data_utils.mixup_data(X, y, config.mixup_alpha, use_cuda=True)
32
33 optimizer.zero_grad()
34 logits, logits_aux = model(X, layer_idx=0, super_flag=True, pretrain_flag=True)
35 loss = data_utils.mixup_criterion(criterion, logits, target_a, target_b, lam)
36 if config.aux_weight > 0:
37 # loss_aux = criterion(logits_aux, y)
38 loss_aux = data_utils.mixup_criterion(criterion, logits_aux, target_a, target_b, lam)
39 loss = loss + config.aux_weight * loss_aux
40
41 if config.use_amp:
42 from apex import amp
43 with amp.scale_loss(loss, optimizer) as scaled_loss:
44 scaled_loss.backward()
45 else:
46 loss.backward()
47 # gradient clipping
48 nn.utils.clip_grad_norm_(model.module.parameters(), config.grad_clip)
49 optimizer.step()
50
51 prec1, prec5 = utils.accuracy(logits, y, topk=(1, 5))
52 if config.distributed:
53 reduced_loss = utils.reduce_tensor(loss.data, config.world_size)
54 prec1 = utils.reduce_tensor(prec1, config.world_size)
55 prec5 = utils.reduce_tensor(prec5, config.world_size)
56 else:
57 reduced_loss = loss.data
58
59 losses.update(reduced_loss.item(), N)
60 top1.update(prec1.item(), N)
61 top5.update(prec5.item(), N)
62
63 torch.cuda.synchronize()
64 if config.local_rank == 0 and (step % config.print_freq == 0 or step == step_num):

Callers

nothing calls this directly

Calls 9

update · Method · 0.95
to · Method · 0.80
format · Method · 0.80
zero_grad · Method · 0.80
train · Method · 0.45
size · Method · 0.45
backward · Method · 0.45
step · Method · 0.45

Tested by

no test coverage detected