MCPcopy
hub / github.com/HobbitLong/SupContrast / train

Function train

main_linear.py:133–185  ·  view source on GitHub ↗

one epoch training

(train_loader, model, classifier, criterion, optimizer, epoch, opt)

Source from the content-addressed store, hash-verified

131
132
def train(train_loader, model, classifier, criterion, optimizer, epoch, opt):
    """Train the linear classifier for one epoch on top of a frozen encoder.

    ``model`` is put in eval mode and ``model.encoder`` is evaluated under
    ``torch.no_grad()``, so only ``classifier`` is trained (via ``optimizer``).

    Args:
        train_loader: iterable of ``(images, labels)`` batches; its ``len()``
            is used for learning-rate warm-up and progress logging.
        model: network exposing an ``encoder`` that maps images to features.
        classifier: trainable head applied to the encoder features.
        criterion: loss, called as ``criterion(output, labels)``.
        optimizer: optimizer stepped once per batch (presumably over the
            classifier parameters -- confirm at the call site).
        epoch: current epoch index, passed to warm-up and shown in logs.
        opt: options object; ``opt.print_freq`` sets the logging interval and
            the whole object is forwarded to ``warmup_learning_rate``.

    Returns:
        tuple: ``(losses.avg, top1.avg)`` -- the epoch's average loss and
        top-1 accuracy as tracked by ``AverageMeter`` (updated with the batch
        size as weight).
    """
    model.eval()
    classifier.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # Loop-invariant: number of batches per epoch.
    num_batches = len(train_loader)

    end = time.time()
    for idx, (images, labels) in enumerate(train_loader):
        data_time.update(time.time() - end)

        images = images.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)
        bsz = labels.shape[0]

        # warm-up learning rate
        warmup_learning_rate(opt, epoch, idx, num_batches, optimizer)

        # compute loss; the encoder is frozen, so no graph is built through it
        with torch.no_grad():
            features = model.encoder(images)
        output = classifier(features.detach())
        loss = criterion(output, labels)

        # update metric
        losses.update(loss.item(), bsz)
        # Only top-1 is reported. The former extra top-5 request was unused
        # and (assuming accuracy() ranks the top max(topk) classes) would fail
        # for classifiers with fewer than 5 outputs.
        acc1 = accuracy(output, labels, topk=(1,))[0]
        top1.update(acc1[0], bsz)

        # SGD
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # print info
        if (idx + 1) % opt.print_freq == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'loss {loss.val:.3f} ({loss.avg:.3f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                   epoch, idx + 1, num_batches, batch_time=batch_time,
                   data_time=data_time, loss=losses, top1=top1))
            sys.stdout.flush()

    return losses.avg, top1.avg
186
187
188def validate(val_loader, model, classifier, criterion, opt):

Callers 1

mainFunction · 0.70

Calls 4

updateMethod · 0.95
AverageMeterClass · 0.90
warmup_learning_rateFunction · 0.90
accuracyFunction · 0.90

Tested by

no test coverage detected