Trainer for linear evaluation.
| 12 | |
| 13 | |
| 14 | class LinearTrainer(BaseTrainer): |
| 15 | """trainer for Linear evaluation""" |
def __init__(self, args):
    """Create a linear-evaluation trainer.

    Args:
        args: run configuration; passed through unchanged to the parent
            class (``BaseTrainer``) initializer.
    """
    super().__init__(args)
| 18 | |
def logging(self, epoch, logs, lr=None, train=True):
    """Write one epoch's metrics to the tensorboard logger.

    Only the process with ``self.args.rank == 0`` writes anything; every
    other rank returns immediately.

    Args:
        epoch: training epoch, used as the step for every logged value.
        logs: indexable metrics; ``logs[0]``, ``logs[1]`` and ``logs[2]``
            are logged as ``<prefix>acc``, ``<prefix>acc5`` and
            ``<prefix>loss`` respectively.
        lr: learning rate; logged as ``learning_rate`` only when ``train``
            is True and ``lr`` is not None.
        train: True for training metrics (prefix ``train_``), False for
            test metrics (prefix ``test_``).
    """
    if self.args.rank != 0:
        return

    prefix = 'train_' if train else 'test_'
    # Metric names in the same order as their positions in ``logs``.
    for position, metric in enumerate(('acc', 'acc5', 'loss')):
        self.logger.log_value(prefix + metric, logs[position], epoch)

    # The learning rate is only meaningful for training epochs.
    if train and lr is not None:
        self.logger.log_value('learning_rate', lr, epoch)
| 36 | |
def wrap_up(self, model, classifier):
    """Move both networks to CUDA and wrap each in DistributedDataParallel.

    Args:
        model: pretrained encoder. It is switched to eval mode here; this
            method does not change its parameters' ``requires_grad`` flags,
            so any freezing must happen elsewhere.
        classifier: linear classifier; its train/eval mode is left as is.

    Returns:
        A ``(model, classifier)`` tuple, each wrapped in ``DDP`` with
        ``device_ids=[self.args.gpu]``.
    """
    gpu = self.args.gpu  # presumably the current CUDA device used by .cuda() — verify in caller

    encoder = model.cuda()
    encoder.eval()
    linear_head = classifier.cuda()

    wrapped_encoder = DDP(encoder, device_ids=[gpu])
    wrapped_head = DDP(linear_head, device_ids=[gpu])
    return wrapped_encoder, wrapped_head
| 52 | |
| 53 | def load_encoder_weights(self, model): |
| 54 | """load pre-trained weights for encoder |
| 55 | |
| 56 | Args: |
| 57 | model: pretrained encoder, should be frozen |
| 58 | """ |
| 59 | args = self.args |
| 60 | if args.ckpt: |
| 61 | ckpt = torch.load(args.ckpt, map_location='cpu') |
| 62 | state_dict = ckpt['model'] |
| 63 | if args.modal == 'RGB': |
| 64 | # Unimodal (RGB) case |
| 65 | encoder_state_dict = OrderedDict() |
| 66 | for k, v in state_dict.items(): |
| 67 | k = k.replace('module.', '') |
| 68 | if 'encoder' in k: |
| 69 | k = k.replace('encoder.', '') |
| 70 | encoder_state_dict[k] = v |
| 71 | model.encoder.load_state_dict(encoder_state_dict) |