Train the model for one epoch.
(train_loader, model, criterion, optimizer, epoch, opt)
| 195 | |
| 196 | |
def train(train_loader, model, criterion, optimizer, epoch, opt):
    """Run a single training epoch of contrastive learning.

    Each batch from ``train_loader`` yields ``(images, labels)``; ``images[0]``
    and ``images[1]`` are concatenated along dim 0 and sent through ``model``
    in one forward pass. NOTE(review): these are presumably two augmented
    views of the same samples; confirm against the dataset transform. The
    resulting features are split back into the two halves and regrouped so
    that dim 1 indexes the view before being passed to ``criterion``.

    Args:
        train_loader: iterable of ``(images, labels)`` batches supporting ``len``.
        model: network applied to the concatenated views.
        criterion: loss called as ``criterion(features, labels)`` when
            ``opt.method == 'SupCon'`` and ``criterion(features)`` when
            ``opt.method == 'SimCLR'``.
        optimizer: optimizer whose ``zero_grad``/``step`` are called per batch.
        epoch: current epoch index, used for LR warm-up and progress output.
        opt: options object; ``method`` and ``print_freq`` are read here, and
            the whole object is forwarded to ``warmup_learning_rate``.

    Returns:
        ``avg`` of the loss meter, which is updated with weight ``bsz`` per
        batch (presumably the sample-weighted mean loss; depends on
        ``AverageMeter``).

    Raises:
        ValueError: if ``opt.method`` is neither ``'SupCon'`` nor ``'SimCLR'``.
            This is raised on the first batch, after the forward pass.
    """
    model.train()

    batch_timer = AverageMeter()
    load_timer = AverageMeter()
    loss_meter = AverageMeter()

    tick = time.time()
    for step, (images, labels) in enumerate(train_loader):
        # Time spent waiting on the loader since the previous iteration ended.
        load_timer.update(time.time() - tick)

        # Put both views into one batch so the model runs a single forward pass.
        images = torch.cat((images[0], images[1]), dim=0)
        if torch.cuda.is_available():
            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)
        bsz = labels.size(0)

        # Adjust the learning rate during warm-up, based on the in-epoch batch index.
        warmup_learning_rate(opt, epoch, step, len(train_loader), optimizer)

        # Undo the concatenation: first bsz rows are view 0, next bsz are view 1;
        # stacking on dim 1 puts the two views side by side per sample.
        view_a, view_b = model(images).split([bsz, bsz], dim=0)
        features = torch.stack((view_a, view_b), dim=1)

        method = opt.method
        if method == 'SupCon':
            loss = criterion(features, labels)
        elif method == 'SimCLR':
            loss = criterion(features)
        else:
            raise ValueError('contrastive method not supported: {}'.
                             format(method))

        loss_meter.update(loss.item(), bsz)

        # Plain gradient step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Wall time for the whole iteration (load + forward + backward + step).
        batch_timer.update(time.time() - tick)
        tick = time.time()

        if (step + 1) % opt.print_freq == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'loss {loss.val:.3f} ({loss.avg:.3f})'.format(
                      epoch, step + 1, len(train_loader),
                      batch_time=batch_timer, data_time=load_timer,
                      loss=loss_meter))
            sys.stdout.flush()

    return loss_meter.avg
| 253 | |
| 254 |
no test coverage detected