MCPcopy Index your code
hub / github.com/HobbitLong/PyContrast / LinearTrainer

Class LinearTrainer

pycontrast/learning/linear_trainer.py:14–242  ·  view source on GitHub ↗

Trainer for linear evaluation.

Source from the content-addressed store, hash-verified

12
13
14class LinearTrainer(BaseTrainer):
15 """trainer for Linear evaluation"""
def __init__(self, args):
    """Create the linear-evaluation trainer.

    Construction is delegated entirely to ``BaseTrainer.__init__``, which
    receives ``args`` unchanged.
    """
    super().__init__(args)
18
def logging(self, epoch, logs, lr=None, train=True):
    """Send one epoch's metrics to the tensorboard logger.

    Only the process whose ``self.args.rank`` is 0 writes anything.

    Args:
        epoch: step value passed to every ``log_value`` call.
        logs: indexable whose entries 0, 1 and 2 are logged as
            '<prefix>acc', '<prefix>acc5' and '<prefix>loss'.
        lr: learning rate; written as 'learning_rate' only when ``train``
            is true and ``lr`` is not None.
        train: selects the key prefix, 'train_' when true, else 'test_'.
    """
    # Non-zero ranks skip logging entirely.
    if self.args.rank != 0:
        return

    prefix = 'train_' if train else 'test_'
    for idx, suffix in enumerate(('acc', 'acc5', 'loss')):
        self.logger.log_value(prefix + suffix, logs[idx], epoch)

    if train and lr is not None:
        self.logger.log_value('learning_rate', lr, epoch)
36
def wrap_up(self, model, classifier):
    """Move encoder and classifier with ``.cuda()`` and wrap each in DDP.

    Args:
        model: pretrained encoder, should be frozen; it is switched to
            eval mode before being wrapped.
        classifier: linear classifier; its train/eval mode is not changed.

    Returns:
        tuple: ``(DDP(model), DDP(classifier))``, each constructed with
        ``device_ids=[self.args.gpu]``.
    """
    gpu = self.args.gpu
    encoder = model.cuda()
    head = classifier.cuda()
    # eval() is applied to the bare encoder before wrapping; the
    # classifier is left in whatever mode it was already in.
    encoder.eval()
    # The tuple is evaluated left to right, so the encoder is wrapped first.
    return (
        DDP(encoder, device_ids=[gpu]),
        DDP(head, device_ids=[gpu]),
    )
52
53 def load_encoder_weights(self, model):
54 """load pre-trained weights for encoder
55
56 Args:
57 model: pretrained encoder, should be frozen
58 """
59 args = self.args
60 if args.ckpt:
61 ckpt = torch.load(args.ckpt, map_location='cpu')
62 state_dict = ckpt['model']
63 if args.modal == 'RGB':
64 # Unimodal (RGB) case
65 encoder_state_dict = OrderedDict()
66 for k, v in state_dict.items():
67 k = k.replace('module.', '')
68 if 'encoder' in k:
69 k = k.replace('encoder.', '')
70 encoder_state_dict[k] = v
71 model.encoder.load_state_dict(encoder_state_dict)

Callers 1

main_worker (function) · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected