MCPcopy Index your code
hub / github.com/HobbitLong/PyContrast / ContrastTrainer

Class ContrastTrainer

pycontrast/learning/contrast_trainer.py:19–491  ·  view source on GitHub ↗

Trainer for contrastive pretraining.

Source retrieved from the content-addressed store (hash-verified).

17
18
19class ContrastTrainer(BaseTrainer):
20 """trainer for contrastive pretraining"""
21 def __init__(self, args):
22 super(ContrastTrainer, self).__init__(args)
23
24 def logging(self, epoch, logs, lr):
25 """ logging to tensorboard
26
27 Args:
28 epoch: training epoch
29 logs: loss and accuracy
30 lr: learning rate
31 """
32 args = self.args
33 if args.rank == 0:
34 self.logger.log_value('loss', logs[0], epoch)
35 self.logger.log_value('acc', logs[1], epoch)
36 self.logger.log_value('jig_loss', logs[2], epoch)
37 self.logger.log_value('jig_acc', logs[3], epoch)
38 self.logger.log_value('learning_rate', lr, epoch)
39
40 def wrap_up(self, model, model_ema, optimizer):
41 """Wrap up models with apex and DDP
42
43 Args:
44 model: model
45 model_ema: momentum encoder
46 optimizer: optimizer
47 """
48 args = self.args
49
50 model.cuda(args.gpu)
51 if isinstance(model_ema, torch.nn.Module):
52 model_ema.cuda(args.gpu)
53
54 # to amp model if needed
55 if args.amp:
56 model, optimizer = amp.initialize(
57 model, optimizer, opt_level=args.opt_level
58 )
59 if isinstance(model_ema, torch.nn.Module):
60 model_ema = amp.initialize(
61 model_ema, opt_level=args.opt_level
62 )
63 # to distributed data parallel
64 model = DDP(model, device_ids=[args.gpu])
65
66 if isinstance(model_ema, torch.nn.Module):
67 self.momentum_update(model.module, model_ema, 0)
68
69 return model, model_ema, optimizer
70
71 def broadcast_memory(self, contrast):
72 """Synchronize memory buffers
73
74 Args:
75 contrast: memory.
76 """

Callers 1

main_worker (function) · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected