(train_loader, model, criterion, optimizer, epoch, args, lr_scheduler, is_main)
| 316 | |
| 317 | |
def train(train_loader, model, criterion, optimizer, epoch, args, lr_scheduler, is_main):
    """Run one pass over ``train_loader``, updating ``model`` after every batch.

    For each ``(images, target)`` batch the function runs the forward pass,
    computes the loss, records loss and top-1/top-5 accuracy in running
    meters, back-propagates, and steps the optimizer. If a scheduler is
    given it is stepped once per batch (not once per epoch).

    Args:
        train_loader: Sized iterable yielding ``(images, target)`` batches.
        model: Network to train; switched to train mode before the loop.
        criterion: Callable ``criterion(output, target)`` returning the loss.
        optimizer: Optimizer whose ``zero_grad()``/``step()`` drive the update.
        epoch: Epoch number, used only for the progress-line prefix.
        args: Namespace providing ``gpu`` (device index or ``None``) and
            ``print_freq`` (batches between progress lines).
        lr_scheduler: Optional scheduler stepped after each batch, or ``None``.
        is_main: When true, this process prints progress and learning rate.

    Returns:
        None. Results are visible through the updated model/optimizer state
        and the printed progress.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top5, ],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        # Images are moved only when an explicit device index was given;
        # targets are moved whenever CUDA is available at all.
        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)
        if torch.cuda.is_available():
            target = target.cuda(args.gpu, non_blocking=True)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time (scheduler step and logging are excluded)
        batch_time.update(time.time() - end)
        end = time.time()

        if lr_scheduler is not None:
            lr_scheduler.step()

        if is_main and i % args.print_freq == 0:
            progress.display(i)
        if is_main and i % 1000 == 0 and lr_scheduler is not None:
            # Report the rate the optimizer will actually apply on the next
            # step. Calling scheduler.get_lr() outside of step() is
            # discouraged by PyTorch (it warns, and for schedulers defined by
            # a recursive update it may not equal the rate in use), so read
            # the value straight from the optimizer instead.
            print('cur lr: ', optimizer.param_groups[0]['lr'])
| 369 | |
| 370 | |
| 371 |
no test coverage detected