()
| 10 | |
| 11 | import torch |
def main():
    """Entry point: set up data, model and loss, then run training.

    Steps performed, in order:
      1. Parse command-line arguments via ``get_arguments()``.
      2. Select GPU 0 unless ``CUDA_VISIBLE_DEVICES`` is already set.
      3. Seed everything through ``utils.reproducibility`` for repeatable runs.
      4. Create ``args.save`` and record the arguments there.
      5. Build the training/validation loaders and the model/optimizer.
      6. Build a class-weighted ``DiceLoss`` and, if ``args.cuda`` is set,
         move the loss weights and the model to the GPU.
      7. Hand everything to ``Trainer`` and start training.
    """
    args = get_arguments()
    # Default to GPU 0, but do not clobber a CUDA_VISIBLE_DEVICES value the
    # caller already exported. This must happen before CUDA is initialized.
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

    ## FOR REPRODUCIBILITY OF RESULTS
    seed = 1777777
    utils.reproducibility(args, seed)

    utils.make_dirs(args.save)
    utils.save_arguments(args, args.save)

    # full_volume and affine are returned by the loader but are not used here.
    training_generator, val_generator, full_volume, affine = medical_loaders.generate_datasets(
        args, path='.././datasets')
    model, optimizer = medzoo.create_model(args)

    # Per-class loss weights: the first class is down-weighted to 0.1.
    # NOTE(review): the list has exactly 4 entries; confirm args.classes == 4
    # for every dataset this script is run with.
    class_weights = torch.tensor([0.1, 1, 1, 1])
    if args.cuda:
        # Move the weights only when CUDA is requested, so CPU-only runs work
        # and the weights live on the same device as the model.
        class_weights = class_weights.cuda()
    criterion = DiceLoss(classes=args.classes, weight=class_weights)

    if args.cuda:
        model = model.cuda()
        print("Model transferred in GPU.....")

    trainer = Trainer(args, model, criterion, optimizer,
                      train_data_loader=training_generator,
                      valid_data_loader=val_generator,
                      lr_scheduler=None)
    print("START TRAINING...")
    trainer.training()
| 36 | |
| 37 | |
| 38 | def get_arguments(): |
no test coverage detected