Repository: github.com/HobbitLong/PyContrast

Function main_worker

pycontrast/main_linear.py, lines 30–81
Signature: main_worker(gpu, ngpus_per_node, args)


```python
import torch
import torch.nn as nn

# LinearTrainer, build_model, build_linear and build_linear_loader come from
# PyContrast's own modules; they are imported or defined elsewhere in
# main_linear.py (not shown).


def main_worker(gpu, ngpus_per_node, args):

    # initialize trainer and ddp environment
    trainer = LinearTrainer(args)
    trainer.init_ddp_environment(gpu, ngpus_per_node)

    # build encoder and classifier
    model, _ = build_model(args)
    classifier = build_linear(args)

    # build dataset
    train_loader, val_loader, train_sampler = \
        build_linear_loader(args, ngpus_per_node)

    # build criterion and optimizer (only the classifier's parameters are optimized)
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(classifier.parameters(),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # load pre-trained ckpt for encoder
    model = trainer.load_encoder_weights(model)

    # wrap up models
    model, classifier = trainer.wrap_up(model, classifier)

    # check and resume a classifier
    start_epoch = trainer.resume_model(classifier, optimizer)

    # init tensorboard logger
    trainer.init_tensorboard_logger()

    # training routine
    for epoch in range(start_epoch, args.epochs + 1):
        train_sampler.set_epoch(epoch)
        trainer.adjust_learning_rate(optimizer, epoch)

        outs = trainer.train(epoch, train_loader, model, classifier,
                             criterion, optimizer)

        # log to tensorboard
        trainer.logging(epoch, outs, optimizer.param_groups[0]['lr'], train=True)

        # evaluation and logging
        if args.rank % ngpus_per_node == 0:
            outs = trainer.validate(epoch, val_loader, model,
                                    classifier, criterion)
            trainer.logging(epoch, outs, train=False)

        # saving model
        trainer.save(classifier, optimizer, epoch)
```
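main_worker performs linear evaluation: a pre-trained encoder (model) produces features for a linear classifier, and only classifier.parameters() are passed to SGD. The sketch below illustrates one such training step on CPU with toy modules. The helper name linear_probe_step, the module shapes, and the random data are hypothetical. Freezing the encoder with requires_grad_(False) and torch.no_grad() is an assumption about how LinearTrainer.train treats the encoder, not code taken from PyContrast.

```python
import torch
import torch.nn as nn


def linear_probe_step(encoder, classifier, criterion, optimizer, images, labels):
    """One linear-probe training step (hypothetical helper).

    The encoder is kept frozen in eval mode; only the classifier is updated,
    matching main_worker handing classifier.parameters() to the optimizer.
    """
    encoder.eval()
    classifier.train()
    with torch.no_grad():  # no autograd graph through the frozen encoder
        features = encoder(images)
    logits = classifier(features)
    loss = criterion(logits, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


# Toy usage on random CPU tensors (shapes are arbitrary placeholders).
torch.manual_seed(0)
encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 16), nn.ReLU())
encoder.requires_grad_(False)  # freeze the stand-in "pre-trained" backbone
classifier = nn.Linear(16, 10)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(classifier.parameters(), lr=0.1, momentum=0.9)
images, labels = torch.randn(32, 3, 8, 8), torch.randint(0, 10, (32,))
for _ in range(3):
    print(linear_probe_step(encoder, classifier, criterion, optimizer, images, labels))
```

Because the optimizer never sees the encoder's parameters, the encoder stays fixed even without requires_grad_(False). Freezing it explicitly also avoids computing and storing gradients that would be discarded.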

Callers

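nothing calls this directly; the (gpu, ngpus_per_node, args) signature fits the torch.multiprocessing.spawn convention, which passes the process index as the first argument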
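In the PyTorch ImageNet example, each spawned worker computes its global rank as args.rank * ngpus_per_node + gpu, where args.rank starts as the node rank. Assuming init_ddp_environment does the same (not verified here), the guard args.rank % ngpus_per_node == 0 in main_worker holds only for GPU 0 of each node, so validation runs once per node. The helper names below are hypothetical.

```python
def global_rank(node_rank: int, ngpus_per_node: int, gpu: int) -> int:
    """Global rank of one worker (assumed formula, as in the PyTorch ImageNet example)."""
    return node_rank * ngpus_per_node + gpu


def runs_validation(rank: int, ngpus_per_node: int) -> bool:
    """Mirrors main_worker's `args.rank % ngpus_per_node == 0` check."""
    return rank % ngpus_per_node == 0


# Example: 2 nodes with 4 GPUs each, i.e. global ranks 0..7.
ngpus = 4
validators = [
    global_rank(node, ngpus, gpu)
    for node in range(2)
    for gpu in range(ngpus)
    if runs_validation(global_rank(node, ngpus, gpu), ngpus)
]
print(validators)  # [0, 4]
```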

Calls (15 reported; 12 listed)

load_encoder_weights · Method · 0.95
wrap_up · Method · 0.95
resume_model · Method · 0.95
train · Method · 0.95
logging · Method · 0.95
validate · Method · 0.95
save · Method · 0.95
LinearTrainer · Class · 0.90
build_model · Function · 0.90
build_linear · Function · 0.90
build_linear_loader · Function · 0.90
init_ddp_environment · Method · 0.80
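resume_model returns the epoch from which main_worker's loop starts. Its internals are not shown here, so the following is a hypothetical stand-in. It saves and restores the classifier and optimizer state dicts with torch.save and torch.load and returns the epoch after the saved one. When no checkpoint exists it returns 1, an assumption based on the loop's args.epochs + 1 upper bound, which suggests 1-based epochs. The function names and the checkpoint keys 'classifier', 'optimizer' and 'epoch' are assumptions.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_checkpoint(path, classifier, optimizer, epoch):
    """Hypothetical counterpart of trainer.save: store what is needed to resume."""
    torch.save({"classifier": classifier.state_dict(),
                "optimizer": optimizer.state_dict(),
                "epoch": epoch}, path)


def resume_checkpoint(path, classifier, optimizer):
    """Hypothetical counterpart of trainer.resume_model.

    Returns 1 if no checkpoint exists (assumed 1-based epochs),
    otherwise restores both state dicts and returns saved epoch + 1.
    """
    if not os.path.isfile(path):
        return 1
    ckpt = torch.load(path, map_location="cpu")
    classifier.load_state_dict(ckpt["classifier"])
    optimizer.load_state_dict(ckpt["optimizer"])
    return ckpt["epoch"] + 1


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "linear_ckpt.pth")
    clf = nn.Linear(16, 10)
    opt = torch.optim.SGD(clf.parameters(), lr=0.1, momentum=0.9)
    first_start = resume_checkpoint(path, clf, opt)  # no checkpoint yet
    save_checkpoint(path, clf, opt, epoch=5)

    # A freshly initialized classifier picks up the saved weights on resume.
    fresh = nn.Linear(16, 10)
    fresh_opt = torch.optim.SGD(fresh.parameters(), lr=0.1, momentum=0.9)
    next_start = resume_checkpoint(path, fresh, fresh_opt)
    weights_match = torch.equal(fresh.weight, clf.weight)
```

Restoring the optimizer state along with the weights keeps SGD's momentum buffers consistent across a restart, which is why main_worker passes both classifier and optimizer to resume_model.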

Tested by

no test coverage detected