(gpu, ngpus_per_node, args)
| 28 | |
| 29 | |
def main_worker(gpu, ngpus_per_node, args):
    """Train and evaluate a linear classifier on top of a pre-trained encoder.

    Intended to run once per process (one per GPU) in a distributed setup.

    Args:
        gpu: Device index for this process; forwarded to
            ``LinearTrainer.init_ddp_environment``.
        ngpus_per_node: Number of GPUs per node. It is forwarded to the
            trainer and the data-loader builder. It also selects which
            processes run validation: those whose ``args.rank`` is a multiple
            of it (presumably the first process on each node; verify against
            ``init_ddp_environment``).
        args: Parsed configuration. This function reads ``learning_rate``,
            ``momentum``, ``weight_decay``, ``epochs`` and ``rank``.

    Only the classifier's parameters are handed to the SGD optimizer. The
    encoder's parameters are never passed to it.
    """
    trainer = LinearTrainer(args)
    trainer.init_ddp_environment(gpu, ngpus_per_node)

    # Encoder (second return value of build_model is unused here) + linear head.
    encoder, _ = build_model(args)
    linear_head = build_linear(args)

    train_loader, val_loader, train_sampler = build_linear_loader(
        args, ngpus_per_node)

    loss_fn = nn.CrossEntropyLoss().cuda()
    sgd_kwargs = dict(
        lr=args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )
    optimizer = torch.optim.SGD(linear_head.parameters(), **sgd_kwargs)

    # Restore the pre-trained encoder, then let the trainer wrap both modules
    # (presumably for DDP / device placement -- TODO confirm in wrap_up).
    encoder = trainer.load_encoder_weights(encoder)
    encoder, linear_head = trainer.wrap_up(encoder, linear_head)

    # resume_model may restore classifier/optimizer state from a checkpoint;
    # its return value is the first epoch to run.
    first_epoch = trainer.resume_model(linear_head, optimizer)

    trainer.init_tensorboard_logger()

    # The upper bound is inclusive: epochs run up to and including args.epochs.
    for epoch in range(first_epoch, args.epochs + 1):
        train_sampler.set_epoch(epoch)
        trainer.adjust_learning_rate(optimizer, epoch)

        train_stats = trainer.train(
            epoch, train_loader, encoder, linear_head, loss_fn, optimizer)
        current_lr = optimizer.param_groups[0]['lr']
        # Log training statistics to tensorboard.
        trainer.logging(epoch, train_stats, current_lr, train=True)

        # Validation and its logging happen on a subset of processes only.
        if args.rank % ngpus_per_node == 0:
            val_stats = trainer.validate(
                epoch, val_loader, encoder, linear_head, loss_fn)
            trainer.logging(epoch, val_stats, train=False)

        # NOTE(review): save is invoked on every process; whether it writes
        # only on one rank is decided inside LinearTrainer.save -- confirm.
        trainer.save(linear_head, optimizer, epoch)
| 82 | |
| 83 | |
| 84 | if __name__ == '__main__': |
nothing calls this directly
no test coverage detected