github.com/HobbitLong/PyContrast

Function main_worker

pycontrast/main_contrast.py:31–78
(gpu, ngpus_per_node, args)
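The signature follows the usual one-process-per-GPU pattern: `gpu` is the local process index on a node and `ngpus_per_node` is the number of processes started on that node. The sketch below shows the per-process bookkeeping that an `init_ddp_environment`-style step commonly performs, modelled on the PyTorch ImageNet example: a global rank derived from the node rank, plus a per-process share of the batch size and loader workers. The names `ProcessConfig` and `per_process_config` are hypothetical, and PyContrast's own method may do this differently.

```python
from dataclasses import dataclass


@dataclass
class ProcessConfig:
    # Hypothetical container for one worker's distributed settings.
    global_rank: int
    world_size: int
    batch_size: int
    workers: int


def per_process_config(gpu, ngpus_per_node, node_rank, num_nodes,
                       total_batch_size, total_workers):
    """Split node-level settings across the processes on one node.

    Assumes one process per GPU, as in the PyTorch ImageNet example.
    """
    # Each node contributes ngpus_per_node processes, so the global rank
    # is the node's offset plus the local process index.
    global_rank = node_rank * ngpus_per_node + gpu
    world_size = num_nodes * ngpus_per_node
    # The batch size is given per node and split evenly across its GPUs.
    batch_size = total_batch_size // ngpus_per_node
    # Ceiling division, so the node's loader workers are not truncated away.
    workers = (total_workers + ngpus_per_node - 1) // ngpus_per_node
    return ProcessConfig(global_rank, world_size, batch_size, workers)
```

For example, the second GPU (`gpu=1`) on the second of two 4-GPU nodes gets global rank 5 out of 8, with a batch of 64 when the node-level batch size is 256.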


```python
import torch
import torch.nn as nn
# ContrastTrainer, build_model, build_contrast_loader and build_mem are
# PyContrast's own helpers, imported elsewhere in main_contrast.py.


def main_worker(gpu, ngpus_per_node, args):

    # initialize trainer and ddp environment
    trainer = ContrastTrainer(args)
    trainer.init_ddp_environment(gpu, ngpus_per_node)

    # build model
    model, model_ema = build_model(args)

    # build dataset
    train_dataset, train_loader, train_sampler = \
        build_contrast_loader(args, ngpus_per_node)

    # build memory
    contrast = build_mem(args, len(train_dataset))
    contrast.cuda()

    # build criterion and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # wrap up models
    model, model_ema, optimizer = trainer.wrap_up(model, model_ema, optimizer)

    # optional step: synchronize memory across processes
    trainer.broadcast_memory(contrast)

    # check and resume a model
    start_epoch = trainer.resume_model(model, model_ema, contrast, optimizer)

    # init tensorboard logger
    trainer.init_tensorboard_logger()

    for epoch in range(start_epoch, args.epochs + 1):
        train_sampler.set_epoch(epoch)
        trainer.adjust_learning_rate(optimizer, epoch)

        outs = trainer.train(epoch, train_loader, model, model_ema,
                             contrast, criterion, optimizer)

        # log to tensorboard
        trainer.logging(epoch, outs, optimizer.param_groups[0]['lr'])

        # save model
        trainer.save(model, model_ema, contrast, optimizer, epoch)
```
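`train_sampler.set_epoch(epoch)` is called at the start of every epoch. Assuming `build_contrast_loader` returns a `torch.utils.data.DistributedSampler`, that call matters because the sampler seeds its shuffle with `seed + epoch`; without it, every epoch reuses the same ordering. The pure-Python sketch below mimics that behaviour: shuffle with a seed derived from the epoch, pad to a multiple of the replica count, then give rank `r` every `num_replicas`-th index starting at `r`. It is a simplified stand-in written for illustration, not torch's implementation, and `shard_indices` is a hypothetical name.

```python
import math
import random


def shard_indices(dataset_len, num_replicas, rank, epoch, seed=0, shuffle=True):
    """Simplified sketch of how a distributed sampler picks one rank's indices."""
    indices = list(range(dataset_len))
    if shuffle:
        # The order depends only on seed + epoch, so all ranks compute the
        # same permutation for a given epoch, and a new epoch gives a new one.
        random.Random(seed + epoch).shuffle(indices)
    # Pad by wrapping around so every rank gets the same number of samples.
    total_size = math.ceil(dataset_len / num_replicas) * num_replicas
    while len(indices) < total_size:
        indices += indices[: total_size - len(indices)]
    # Interleaved split: rank r takes positions r, r + num_replicas, ...
    return indices[rank:total_size:num_replicas]
```

Because every rank derives the same permutation from `seed + epoch`, the shards stay disjoint (apart from padding duplicates) without any communication between processes.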
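The epoch loop runs over `range(start_epoch, args.epochs + 1)`, so the final epoch `args.epochs` is included, and `resume_model` supplies the epoch to continue from. A common way to make that work is to store the finished epoch in each checkpoint and resume at the next one. The sketch below shows that pattern with `pickle` and plain dictionaries standing in for model and optimizer state. The function names, the file format and the "start at epoch 1 when no checkpoint exists" rule are assumptions for illustration, not PyContrast's code; a PyTorch version would typically use `torch.save` and `torch.load` on `state_dict()` objects instead.

```python
import os
import pickle


def save_checkpoint(path, model_state, optimizer_state, epoch):
    # Record the epoch that just finished alongside the states.
    with open(path, "wb") as f:
        pickle.dump({"model": model_state, "optimizer": optimizer_state,
                     "epoch": epoch}, f)


def resume_epoch(path, model_state, optimizer_state):
    """Restore states in place and return the first epoch still to run."""
    if not os.path.isfile(path):
        return 1  # fresh run: epochs are assumed to be counted from 1
    with open(path, "rb") as f:
        ckpt = pickle.load(f)
    model_state.update(ckpt["model"])
    optimizer_state.update(ckpt["optimizer"])
    return ckpt["epoch"] + 1


def run(path, total_epochs, log):
    model_state, optimizer_state = {"w": 0}, {"step": 0}
    start = resume_epoch(path, model_state, optimizer_state)
    # Same bounds as main_worker: the last epoch is included.
    for epoch in range(start, total_epochs + 1):
        model_state["w"] += 1          # stand-in for one epoch of training
        optimizer_state["step"] += 1
        log.append(epoch)
        save_checkpoint(path, model_state, optimizer_state, epoch)
    return model_state, optimizer_state
```

Running `run` with 3 epochs and then again with 5 against the same path trains epochs 1 to 3 first and then only epochs 4 and 5, with the restored states carried forward.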

Callers

nothing calls this directly
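No direct caller is expected for a worker with this signature, because it is normally handed to a launcher rather than called. `torch.multiprocessing.spawn(fn, args=..., nprocs=...)` calls `fn(i, *args)` in each of `nprocs` processes, with `i` the process index, which is where a `gpu` argument like this one comes from. The sketch below keeps that calling convention but runs the workers one after another in a single process, so the flow can be checked without GPUs. `run_like_spawn` and `fake_worker` are hypothetical, and the real launch line appears only as a comment.

```python
from types import SimpleNamespace


def run_like_spawn(fn, nprocs, args=()):
    """Call fn(i, *args) for i in range(nprocs), one after another.

    Mirrors the calling convention of torch.multiprocessing.spawn; the real
    function starts nprocs separate processes and does not collect return
    values, whereas this loop returns them so they can be inspected.
    """
    return [fn(i, *args) for i in range(nprocs)]


def fake_worker(gpu, ngpus_per_node, args):
    # Stand-in for main_worker: report which device this process would use.
    return f"process {gpu} of {ngpus_per_node} on cuda:{gpu}, epochs={args.epochs}"


# A real multi-GPU launch would look roughly like:
#   torch.multiprocessing.spawn(main_worker, nprocs=ngpus_per_node,
#                               args=(ngpus_per_node, args))
opts = SimpleNamespace(epochs=200)
ngpus = 2
messages = run_like_spawn(fake_worker, nprocs=ngpus, args=(ngpus, opts))
```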

Calls 14 (12 listed below)

wrap_up · Method · 0.95
broadcast_memory · Method · 0.95
resume_model · Method · 0.95
train · Method · 0.95
logging · Method · 0.95
save · Method · 0.95
ContrastTrainer · Class · 0.90
build_model · Function · 0.90
build_contrast_loader · Function · 0.90
build_mem · Function · 0.90
init_ddp_environment · Method · 0.80
cuda · Method · 0.80
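Of the calls above, `build_mem` creates the memory (sized by `len(train_dataset)`) that `trainer.train` receives together with the plain `nn.CrossEntropyLoss` criterion. In InfoNCE-style training as popularised by MoCo, each row of logits holds one positive similarity followed by similarities to negatives from the memory, all divided by a temperature T, and the target class is the positive's index. Assuming that layout with the positive at index 0 (an assumption about what `contrast` produces, not something this excerpt confirms), the per-sample loss is loss = -log( exp(s_0 / T) / sum_j exp(s_j / T) ). The pure-Python sketch below computes it; `info_nce_loss` and `mean_info_nce` are hypothetical names.

```python
import math


def info_nce_loss(similarities, temperature):
    """Cross-entropy of one logit row whose target class is index 0.

    similarities[0] is the positive pair; the rest are negatives.
    """
    logits = [s / temperature for s in similarities]
    # Log-sum-exp with the max subtracted for numerical stability.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    # Equals -log(softmax(logits)[0]).
    return log_sum_exp - logits[0]


def mean_info_nce(batch, temperature):
    # nn.CrossEntropyLoss averages over the batch by default (reduction='mean').
    return sum(info_nce_loss(row, temperature) for row in batch) / len(batch)
```

With PyTorch the same batch average is `nn.CrossEntropyLoss()(logits, labels)`, where `logits` already includes the temperature and `labels` is a zero-filled `torch.long` tensor. When all K + 1 logits are equal, the loss is log(K + 1), the chance-level value.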

Tested by

no test coverage detected
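Since no tests cover this function, one low-cost option is to pull the epoch loop out into a small function and check its call order with `unittest.mock`, with no GPU, data or model required. The sketch below assumes a hypothetical `run_epochs` helper that mirrors the loop in `main_worker` (set the sampler epoch, adjust the learning rate, train, log, save) and a hypothetical `run_with_mocks` driver; neither is part of PyContrast.

```python
from unittest import mock


def run_epochs(trainer, train_sampler, optimizer, start_epoch, epochs,
               train_loader, model, model_ema, contrast, criterion):
    # Hypothetical extraction of the epoch loop from main_worker.
    for epoch in range(start_epoch, epochs + 1):
        train_sampler.set_epoch(epoch)
        trainer.adjust_learning_rate(optimizer, epoch)
        outs = trainer.train(epoch, train_loader, model, model_ema,
                             contrast, criterion, optimizer)
        trainer.logging(epoch, outs, optimizer.param_groups[0]['lr'])
        trainer.save(model, model_ema, contrast, optimizer, epoch)


def run_with_mocks():
    # One parent mock records calls on both children in a single ordered list.
    parent = mock.Mock()
    parent.trainer.train.return_value = 'outs'
    optimizer = mock.Mock(param_groups=[{'lr': 0.03}])
    run_epochs(parent.trainer, parent.sampler, optimizer,
               start_epoch=2, epochs=3, train_loader='loader', model='m',
               model_ema='m_ema', contrast='mem', criterion='ce')
    return parent
```

Reading `parent.mock_calls` then gives the interleaved order of sampler and trainer calls, so a test can assert that `set_epoch` runs before `adjust_learning_rate` and `train` in every epoch and that the final epoch is included.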