Repository: github.com/DingXiaoH/RepVGG

Function train

quantization/quant_qat_train.py:318–368
(train_loader, model, criterion, optimizer, epoch, args, lr_scheduler, is_main)


```python
import time

import torch

# AverageMeter, ProgressMeter and accuracy are helpers defined elsewhere in
# the repository; they must be importable in this module.


def train(train_loader, model, criterion, optimizer, epoch, args, lr_scheduler, is_main):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top5],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)
        if torch.cuda.is_available():
            # with args.gpu=None this moves target to the current CUDA device
            target = target.cuda(args.gpu, non_blocking=True)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss, weighted by batch size
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # the scheduler is stepped once per iteration, not once per epoch
        if lr_scheduler is not None:
            lr_scheduler.step()

        if is_main and i % args.print_freq == 0:
            progress.display(i)
        if is_main and i % 1000 == 0 and lr_scheduler is not None:
            # The original calls get_lr(); outside of step() that can return a
            # value other than the one applied, so read get_last_lr() instead.
            print('cur lr: ', lr_scheduler.get_last_lr()[0])
```
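AverageMeter and ProgressMeter are defined elsewhere in the repository, and their source is not shown on this page. The loop relies on two behaviors: update(val, n) keeps a running average weighted by n (the batch size here, so a shorter final batch counts for less), and display(i) prints one line containing every meter. The sketch below assumes helpers modeled on the classes of the same names in the PyTorch ImageNet example; the repository's versions may differ in detail, and format_line is a hypothetical helper added only to make the output testable.

```python
class AverageMeter:
    """Hypothetical sketch: latest value plus a sample-weighted running average."""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        # val is a per-batch mean, so weight it by the n samples it covers
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # renders as '<name> <latest> (<running average>)' using self.fmt
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


class ProgressMeter:
    """Hypothetical sketch: '<prefix>[  i/N]' followed by every meter, tab-separated."""

    def __init__(self, num_batches, meters, prefix=""):
        width = len(str(num_batches))
        self.batch_fmtstr = '[{:' + str(width) + 'd}/' + str(num_batches) + ']'
        self.meters = meters
        self.prefix = prefix

    def format_line(self, batch):
        # hypothetical helper: build the line that display() prints
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        return '\t'.join(entries)

    def display(self, batch):
        print(self.format_line(batch))


if __name__ == '__main__':
    losses = AverageMeter('Loss', ':.4e')
    losses.update(2.0, 32)  # full batch of 32 samples
    losses.update(1.0, 16)  # shorter final batch of 16 samples
    print(losses.avg)       # (2.0 * 32 + 1.0 * 16) / 48
    ProgressMeter(100, [losses], prefix="Epoch: [0]").display(7)
```

An unweighted mean of the two batch losses would give 1.5; weighting by sample count gives 80/48, the true per-sample average.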

Callers (1)

main_worker (function) · 0.85
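Since train() calls lr_scheduler.step() after every batch, any schedule the caller builds has to be sized in iterations, not epochs. The sketch below assumes a CosineAnnealingLR and toy sizes; the scheduler main_worker actually constructs is not shown on this page. It also reads the rate with get_last_lr(), which returns the value the scheduler last applied to the optimizer.

```python
import torch
from torch import nn
from torch.optim.lr_scheduler import CosineAnnealingLR

# Hypothetical sizes standing in for the number of epochs and len(train_loader).
epochs, iters_per_epoch = 3, 4

model = nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# One scheduler step per batch, so T_max counts iterations over the whole run.
scheduler = CosineAnnealingLR(optimizer, T_max=epochs * iters_per_epoch, eta_min=0.0)

lrs = []
for epoch in range(epochs):
    for i in range(iters_per_epoch):
        optimizer.step()  # placeholder for the forward/backward/step done in train()
        scheduler.step()  # same position as in train(): after the optimizer step
        lrs.append(scheduler.get_last_lr()[0])

# Halfway through the run the rate is half the base rate; at the end it is eta_min.
print(lrs[len(lrs) // 2 - 1], lrs[-1])
```

With T_max=epochs instead, the cosine would bottom out after only that many batches and, because CosineAnnealingLR keeps following the cosine past T_max, start rising again.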

Calls (6)

update (method) · 0.95
display (method) · 0.95
AverageMeter (class) · 0.85
ProgressMeter (class) · 0.85
accuracy (function) · 0.85
get_lr (method) · 0.80
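accuracy is also defined elsewhere in the repository. The loop indexes acc1[0] because a helper of this shape returns one single-element tensor per requested k. The sketch below is modeled on the accuracy function in the PyTorch ImageNet example (an assumption; the repository's version is not shown here). The toy batch has only three classes, so it asks for topk=(1, 2) instead of (1, 5).

```python
import torch


def accuracy(output, target, topk=(1,)):
    """Hypothetical sketch: percentage of samples whose label is in the top-k logits."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        # indices of the maxk largest logits per sample, transposed to [maxk, batch]
        _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            # a sample counts once if its label appears in any of the first k rows
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


output = torch.tensor([[0.1, 0.7, 0.2],   # class 1 ranked first
                       [0.6, 0.1, 0.3]])  # class 0 first, class 2 second
target = torch.tensor([1, 2])
acc1, acc2 = accuracy(output, target, topk=(1, 2))
print(acc1[0].item(), acc2[0].item())
```

The second sample is wrong at top-1 but right at top-2, so top-1 accuracy is 50% and top-2 accuracy is 100%.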

Tested by

no test coverage detected