Trainer for contrastive pretraining
| 17 | |
| 18 | |
| 19 | class ContrastTrainer(BaseTrainer): |
| 20 | """trainer for contrastive pretraining""" |
def __init__(self, args):
    """Create the trainer; construction is fully delegated to the base class.

    Args:
        args: run configuration, forwarded unchanged to ``BaseTrainer``.
    """
    parent = super(ContrastTrainer, self)
    parent.__init__(args)
| 23 | |
def logging(self, epoch, logs, lr):
    """Log one epoch's training scalars via ``self.logger`` (tensorboard).

    Only the process whose ``self.args.rank`` is 0 writes anything.

    Args:
        epoch: training epoch, used as the step for every scalar.
        logs: indexable with at least four entries, logged in order under
            the tags 'loss', 'acc', 'jig_loss' and 'jig_acc'.
        lr: learning rate, logged under 'learning_rate'.
    """
    # Non-zero ranks stay silent (presumably so distributed workers do not
    # write duplicate scalars).
    if self.args.rank != 0:
        return

    scalar_tags = ('loss', 'acc', 'jig_loss', 'jig_acc')
    for position, tag in enumerate(scalar_tags):
        self.logger.log_value(tag, logs[position], epoch)
    self.logger.log_value('learning_rate', lr, epoch)
| 39 | |
def wrap_up(self, model, model_ema, optimizer):
    """Move models to the GPU, optionally apply apex amp, then wrap in DDP.

    Args:
        model: model being trained; returned wrapped in ``DDP``.
        model_ema: momentum encoder. It is only moved/converted when it is a
            ``torch.nn.Module``; otherwise it is returned unchanged.
        optimizer: optimizer for ``model``.

    Returns:
        tuple: ``(model, model_ema, optimizer)`` after wrapping.
    """
    args = self.args
    has_ema = isinstance(model_ema, torch.nn.Module)

    model.cuda(args.gpu)
    if has_ema:
        model_ema.cuda(args.gpu)

    # to amp model if needed
    if args.amp:
        # apex documents that amp.initialize should be called only once per
        # process; it accepts a list of models, so the momentum encoder is
        # registered in the same call instead of a second initialize().
        if has_ema:
            (model, model_ema), optimizer = amp.initialize(
                [model, model_ema], optimizer, opt_level=args.opt_level
            )
        else:
            model, optimizer = amp.initialize(
                model, optimizer, opt_level=args.opt_level
            )

    # to distributed data parallel (after amp.initialize, as apex requires)
    model = DDP(model, device_ids=[args.gpu])

    if has_ema:
        # momentum_update is defined outside this view; with momentum 0 it
        # presumably copies model's weights into model_ema -- confirm.
        self.momentum_update(model.module, model_ema, 0)

    return model, model_ema, optimizer
| 70 | |
| 71 | def broadcast_memory(self, contrast): |
| 72 | """Synchronize memory buffers |
| 73 | |
| 74 | Args: |
| 75 | contrast: memory. |
| 76 | """ |