MCPcopy Index your code
hub / github.com/HobbitLong/PyContrast / train

Method train

pycontrast/learning/linear_trainer.py:133–191  ·  view source on GitHub ↗
(self, epoch, train_loader, model, classifier,
              criterion, optimizer)

Source from the content-addressed store, hash-verified

131 del state
132
def train(self, epoch, train_loader, model, classifier,
          criterion, optimizer):
    """Train the linear classifier for one epoch on frozen backbone features.

    ``model`` is switched to eval mode and its forward pass runs under
    ``torch.no_grad()``, so the loss gradients reach only ``classifier``.
    A progress line is printed every ``args.print_freq`` iterations and a
    timing summary at the end of the epoch; both are printed only when
    ``args.local_rank == 0`` so distributed runs log once.

    Args:
        epoch: Epoch index; used only in the printed log lines.
        train_loader: Iterable of ``(images, target)`` batches; ``len()`` of
            it is shown in the progress line.
        model: Backbone called as ``model(x=images, mode=2)`` to produce the
            features fed to the classifier. NOTE(review): the meaning of
            ``mode=2`` is defined by the model class, not here — confirm.
        classifier: Module trained on the extracted features.
        criterion: Loss called as ``criterion(output, target)``.
        optimizer: Stepped once per batch; presumably holds the classifier's
            parameters — verify against the caller.

    Returns:
        tuple: ``(top1.avg, top5.avg, losses.avg)`` — the ``.avg`` of the
        top-1, top-5 and loss meters, each updated with the batch size as
        its weight argument. Callers that ignore the return value are
        unaffected.
    """
    time1 = time.time()
    args = self.args

    model.eval()
    classifier.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    end = time.time()
    for idx, (images, target) in enumerate(train_loader):
        data_time.update(time.time() - end)

        images = images.float().cuda(args.gpu, non_blocking=True)
        target = target.cuda(args.gpu, non_blocking=True)
        bsz = images.size(0)

        # forward: the backbone is frozen. Outputs computed under no_grad
        # never require grad, so no extra detach() is needed.
        with torch.no_grad():
            feat = model(x=images, mode=2)

        output = classifier(feat)
        loss = criterion(output, target)

        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        # float() makes the meters (and the returned averages) hold plain
        # Python numbers instead of whatever object accuracy() returns.
        losses.update(loss.item(), bsz)
        top1.update(float(acc1[0]), bsz)
        top5.update(float(acc5[0]), bsz)

        # backward: only classifier parameters receive gradients
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)
        end = time.time()

        # print info
        if args.local_rank == 0 and idx % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                   epoch, idx, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses, top1=top1, top5=top5))
            sys.stdout.flush()

    time2 = time.time()
    # Guarded like the per-iteration log so only one process prints it.
    if args.local_rank == 0:
        print('train epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

    return top1.avg, top5.avg, losses.avg

Callers 1

main_worker · Function · 0.95

Calls 4

update · Method · 0.95
AverageMeter · Class · 0.85
accuracy · Function · 0.85
cuda · Method · 0.80

Tested by

no test coverage detected