Repository: github.com/HobbitLong/SupContrast

Function train

main_supcon.py:197–252

Runs one epoch of contrastive training (SupCon or SimCLR) and returns the epoch's average training loss.

(train_loader, model, criterion, optimizer, epoch, opt)
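train() unpacks each batch as (images, labels) and reads images[0] and images[1], so train_loader is expected to yield two augmented views of every image. The repository appears to build these pairs with a TwoCropTransform wrapper in util.py. The stdlib-only sketch below follows that idea with plain lists. jitter and collate_pairs are hypothetical stand-ins for a torchvision augmentation pipeline and PyTorch's default batch collation, not code from the repository.

```python
import random
from typing import Callable, List, Sequence


class TwoCropTransform:
    """Apply the same (random) transform twice to get two views of one input."""

    def __init__(self, transform: Callable):
        self.transform = transform

    def __call__(self, x):
        return [self.transform(x), self.transform(x)]


def jitter(x: Sequence[float]) -> List[float]:
    # Hypothetical augmentation: add small uniform noise to every element.
    return [v + random.uniform(-0.1, 0.1) for v in x]


def collate_pairs(samples):
    """Hypothetical stand-in for a DataLoader's default collation.

    Turns [([view1, view2], label), ...] into ([batch_of_view1, batch_of_view2], labels),
    the structure train() unpacks as images[0], images[1] and labels.
    """
    images = [[views[0] for views, _ in samples], [views[1] for views, _ in samples]]
    labels = [label for _, label in samples]
    return images, labels


two_crop = TwoCropTransform(jitter)
dataset = [([1.0, 2.0], 0), ([3.0, 4.0], 1), ([5.0, 6.0], 0)]
images, labels = collate_pairs([(two_crop(x), y) for x, y in dataset])
print(len(images), len(images[0]), labels)  # 2 3 [0, 1, 0]
```

With real tensors, PyTorch's default collation performs the same regrouping, so images[0] is a batched tensor of first views and images[1] a batched tensor of second views.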


```python
# Module-level imports from main_supcon.py that train() relies on (condensed)
import sys
import time

import torch

from util import AverageMeter, warmup_learning_rate


def train(train_loader, model, criterion, optimizer, epoch, opt):
    """one epoch training"""
    model.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    end = time.time()
    for idx, (images, labels) in enumerate(train_loader):
        data_time.update(time.time() - end)

        # images is a pair of augmented views [view1, view2];
        # concatenate them along the batch dimension -> [2 * bsz, C, H, W]
        images = torch.cat([images[0], images[1]], dim=0)
        if torch.cuda.is_available():
            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)
        bsz = labels.shape[0]

        # warm-up learning rate
        warmup_learning_rate(opt, epoch, idx, len(train_loader), optimizer)

        # compute loss
        features = model(images)
        # regroup [2 * bsz, feat_dim] into [bsz, 2, feat_dim]:
        # features[i, 0] and features[i, 1] are the two views of sample i
        f1, f2 = torch.split(features, [bsz, bsz], dim=0)
        features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)
        if opt.method == 'SupCon':
            # supervised contrastive: same-label samples are positives
            loss = criterion(features, labels)
        elif opt.method == 'SimCLR':
            # self-supervised: only the other view of the same image is a positive
            loss = criterion(features)
        else:
            raise ValueError('contrastive method not supported: {}'.format(opt.method))

        # update metric (weighted by batch size)
        losses.update(loss.item(), bsz)

        # SGD
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # print info
        if (idx + 1) % opt.print_freq == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'loss {loss.val:.3f} ({loss.avg:.3f})'.format(
                   epoch, idx + 1, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses))
            sys.stdout.flush()

    return losses.avg
```
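The split/unsqueeze/cat sequence turns the encoder output for the concatenated batch, shape [2*bsz, feat_dim], into features of shape [bsz, 2, feat_dim], so that features[i, 0] and features[i, 1] are the two views of sample i. The pure-Python sketch below mimics that regrouping with nested lists, then builds the positive-pair mask each branch implies. With labels (SupCon), every other view of a same-label sample counts as a positive. Without labels (SimCLR), only the other view of the same image does. The mask logic reflects our reading of SupConLoss in losses.py (views flattened view-major, self-pairs excluded) and is an illustration, not the repository's implementation. stack_views and positive_mask are hypothetical names.

```python
from typing import List, Optional

Vector = List[float]


def stack_views(features: List[Vector], bsz: int) -> List[List[Vector]]:
    """Mimic torch.split + unsqueeze(1) + cat(dim=1): [2*bsz, d] -> [bsz, 2, d]."""
    f1, f2 = features[:bsz], features[bsz:2 * bsz]
    return [[f1[i], f2[i]] for i in range(bsz)]


def positive_mask(bsz: int, n_views: int,
                  labels: Optional[List[int]] = None) -> List[List[int]]:
    """Positive-pair mask over the bsz * n_views anchors, ordered view-major.

    Anchor a belongs to sample a % bsz. Two distinct anchors are positives if
    they share a label (SupCon) or, when labels is None, come from the same
    sample (SimCLR). Self-pairs are always excluded.
    """
    n = bsz * n_views
    mask = []
    for a in range(n):
        row = []
        for b in range(n):
            if a == b:
                row.append(0)
            elif labels is None:
                row.append(int(a % bsz == b % bsz))
            else:
                row.append(int(labels[a % bsz] == labels[b % bsz]))
        mask.append(row)
    return mask


# Encoder output for bsz=3: rows 0-2 are view 1, rows 3-5 are view 2.
bsz = 3
encoded = [[float(r), float(r) + 0.5] for r in range(2 * bsz)]
features = stack_views(encoded, bsz)
print(features[0])  # [[0.0, 0.5], [3.0, 3.5]]: both views of sample 0

labels = [0, 1, 0]
supcon = positive_mask(bsz, 2, labels)
simclr = positive_mask(bsz, 2)
print(sum(supcon[0]), sum(simclr[0]))  # 3 1
```

For anchor 0 (sample 0, label 0), SupCon finds three positives (its own second view plus both views of sample 2), while SimCLR finds only one.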

Callers (1)

main (function) · 0.70
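Per the call graph, train() is invoked from main, once per epoch. A minimal sketch of that driver pattern follows, assuming 1-based epoch numbers. run_epochs and fake_train are hypothetical: fake_train only mirrors train()'s signature so that functools.partial can bind every argument except epoch, and the real main() in main_supcon.py may do more work around each call.

```python
from functools import partial
from typing import Callable, List


def run_epochs(train_one_epoch: Callable[[int], float], epochs: int) -> List[float]:
    """Hypothetical driver: call one-epoch training for epochs 1..epochs.

    train_one_epoch stands in for train() with every argument except epoch
    already bound; it returns that epoch's average loss (losses.avg).
    """
    history = []
    for epoch in range(1, epochs + 1):  # assumed 1-based epoch numbering
        avg_loss = train_one_epoch(epoch)
        history.append(avg_loss)
        print('epoch {}, avg loss {:.3f}'.format(epoch, avg_loss))
    return history


def fake_train(train_loader, model, criterion, optimizer, epoch, opt):
    # Stub with train()'s signature: pretend the average loss shrinks each epoch.
    return opt['base_loss'] / epoch


step = partial(fake_train, None, None, None, None, opt={'base_loss': 5.0})
history = run_epochs(step, epochs=3)
```

With the real function, the same binding would read partial(train, train_loader, model, criterion, optimizer, opt=opt).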

Calls (3)

update (method) · 0.95
AverageMeter (class) · 0.90
warmup_learning_rate (function) · 0.90
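Two of the callees listed above carry the per-batch bookkeeping: AverageMeter, whose update method feeds the val and avg fields printed every print_freq batches, and warmup_learning_rate. The sketch below is a minimal stdlib re-creation of both, based on our reading of util.py. The option names warm, warm_epochs, warmup_from and warmup_to, and the SimpleNamespace stand-in for a torch optimizer's param_groups, are assumptions to check against the repository.

```python
from types import SimpleNamespace


class AverageMeter:
    """Track the latest value and a count-weighted running average."""

    def __init__(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # n is the number of samples val was averaged over (bsz in train()),
        # so a smaller final batch contributes proportionally less.
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def warmup_learning_rate(opt, epoch, batch_id, total_batches, optimizer):
    """Linearly ramp the LR from opt.warmup_from toward opt.warmup_to over the
    first opt.warm_epochs epochs (epochs counted from 1); no-op afterwards."""
    if opt.warm and epoch <= opt.warm_epochs:
        p = (batch_id + (epoch - 1) * total_batches) / (opt.warm_epochs * total_batches)
        lr = opt.warmup_from + p * (opt.warmup_to - opt.warmup_from)
        for group in optimizer.param_groups:
            group['lr'] = lr


losses = AverageMeter()
losses.update(2.0, n=4)  # full batch of 4 samples
losses.update(1.0, n=2)  # smaller final batch
print(losses.val, losses.avg)  # 1.0 1.6666666666666667

opt = SimpleNamespace(warm=True, warm_epochs=2, warmup_from=0.01, warmup_to=0.5)
optimizer = SimpleNamespace(param_groups=[{'lr': 0.0}])
# Halfway through warm-up: epoch 2, first of 10 batches -> p = 0.5
warmup_learning_rate(opt, epoch=2, batch_id=0, total_batches=10, optimizer=optimizer)
print(optimizer.param_groups[0]['lr'])
```

Weighting by n means losses.avg is a per-sample average over the epoch rather than a plain mean of batch losses.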

Tested by

no test coverage detected