MCPcopy
hub / github.com/microsoft/Cream / train

Function train

CDARTS/benchmark201/core/augment_function.py:7–80  ·  view source on GitHub ↗
(train_loader, model, optimizer, epoch, writer, logger, config)

Source from the content-addressed store, hash-verified

5from models.loss import CrossEntropyLabelSmooth
6
7def train(train_loader, model, optimizer, epoch, writer, logger, config):
8 device = torch.device("cuda")
9 if config.label_smooth > 0:
10 criterion = CrossEntropyLabelSmooth(config.n_classes, config.label_smooth).to(device)
11 else:
12 criterion = nn.CrossEntropyLoss().to(device)
13
14 top1 = utils.AverageMeter()
15 top5 = utils.AverageMeter()
16 losses = utils.AverageMeter()
17
18 step_num = len(train_loader)
19 cur_step = epoch*step_num
20 cur_lr = optimizer.param_groups[0]['lr']
21 if config.local_rank == 0:
22 logger.info("Train Epoch {} LR {}".format(epoch, cur_lr))
23 writer.add_scalar('train/lr', cur_lr, cur_step)
24
25 model.train()
26
27 for step, (X, y) in enumerate(train_loader):
28 X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)
29 N = X.size(0)
30
31 X, target_a, target_b, lam = data_utils.mixup_data(X, y, config.mixup_alpha, use_cuda=True)
32
33 optimizer.zero_grad()
34 logits, logits_aux = model(X)
35 # loss = criterion(logits, y)
36 loss = data_utils.mixup_criterion(criterion, logits, target_a, target_b, lam)
37 if config.aux_weight > 0:
38 # loss_aux = criterion(logits_aux, y)
39 loss_aux = data_utils.mixup_criterion(criterion, logits_aux, target_a, target_b, lam)
40 loss = loss + config.aux_weight * loss_aux
41
42 if config.use_amp:
43 from apex import amp
44 with amp.scale_loss(loss, optimizer) as scaled_loss:
45 scaled_loss.backward()
46 else:
47 loss.backward()
48 # gradient clipping
49 nn.utils.clip_grad_norm_(model.module.parameters(), config.grad_clip)
50 optimizer.step()
51
52 prec1, prec5 = utils.accuracy(logits, y, topk=(1, 5))
53 if config.distributed:
54 reduced_loss = utils.reduce_tensor(loss.data, config.world_size)
55 prec1 = utils.reduce_tensor(prec1, config.world_size)
56 prec5 = utils.reduce_tensor(prec5, config.world_size)
57 else:
58 reduced_loss = loss.data
59
60 losses.update(reduced_loss.item(), N)
61 top1.update(prec1.item(), N)
62 top5.update(prec5.item(), N)
63
64 torch.cuda.synchronize()

Callers

nothing calls this directly

Calls 9

update · Method · 0.95
to · Method · 0.80
format · Method · 0.80
zero_grad · Method · 0.80
train · Method · 0.45
size · Method · 0.45
backward · Method · 0.45
step · Method · 0.45

Tested by

no test coverage detected