one epoch training
(train_loader, model, classifier, criterion, optimizer, epoch, opt)
| 131 | |
| 132 | |
def train(train_loader, model, classifier, criterion, optimizer, epoch, opt):
    """One epoch training of ``classifier`` on top of ``model.encoder``.

    ``model`` is switched to eval mode and its encoder is run under
    ``torch.no_grad()``, so no autograd graph is built through it. Only
    ``classifier`` is put in train mode and receives gradients from
    ``loss.backward()``. Which parameters ``optimizer.step()`` actually
    updates depends on how ``optimizer`` was built (presumably over the
    classifier's parameters only -- confirm at the caller).

    Args:
        train_loader: iterable of ``(images, labels)`` batches that also
            supports ``len()``. Both tensors are moved to the current CUDA
            device, so a GPU is required.
        model: network exposing an ``encoder`` attribute, which is used as
            the feature extractor.
        classifier: module applied to the encoder output to produce the
            predictions passed to ``criterion`` and ``accuracy``.
        criterion: loss called as ``criterion(output, labels)``. It must
            return a single-element tensor, since ``.item()`` is called on it.
        optimizer: optimizer that is zeroed and stepped once per batch.
        epoch: current epoch number. It is passed to
            ``warmup_learning_rate`` and shown in the progress line.
        opt: options object. ``opt.print_freq`` sets how many iterations
            pass between progress lines, and ``warmup_learning_rate`` also
            receives it.

    Returns:
        tuple: ``(losses.avg, top1.avg)`` from the ``AverageMeter``s, which
        are updated every batch with the batch size as the count. Accuracy
        units are whatever ``accuracy`` produces (presumably percent --
        confirm).
    """
    model.eval()
    classifier.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # Loop-invariant; used both by the warm-up schedule and the log line.
    num_batches = len(train_loader)

    end = time.time()
    for idx, (images, labels) in enumerate(train_loader):
        data_time.update(time.time() - end)

        images = images.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)
        bsz = labels.shape[0]

        # warm-up learning rate (called every iteration; presumably adjusts
        # the optimizer's LR during the first epochs -- see the helper)
        warmup_learning_rate(opt, epoch, idx, num_batches, optimizer)

        # compute loss
        # Only the encoder runs under no_grad. Its output already has
        # requires_grad=False, so no extra .detach() is needed before
        # feeding the (trainable) classifier.
        with torch.no_grad():
            features = model.encoder(images)
        output = classifier(features)
        loss = criterion(output, labels)

        # update metric
        # Only top-1 is tracked and returned, so request k=1 alone. The
        # previous topk=(1, 5) computed an unused top-5 value and breaks
        # on heads with fewer than 5 outputs (assuming `accuracy` uses
        # torch.topk, which raises when k exceeds the dimension size).
        losses.update(loss.item(), bsz)
        acc1 = accuracy(output, labels, topk=(1,))[0]
        top1.update(acc1[0], bsz)

        # SGD
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # print info
        if (idx + 1) % opt.print_freq == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'loss {loss.val:.3f} ({loss.avg:.3f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                   epoch, idx + 1, num_batches, batch_time=batch_time,
                   data_time=data_time, loss=losses, top1=top1))
            sys.stdout.flush()

    return losses.avg, top1.avg
| 186 | |
| 187 | |
| 188 | def validate(val_loader, model, classifier, criterion, opt): |
no test coverage detected