MCPcopy
hub / github.com/black0017/MedicalZooPytorch / main

Function main

tests/train_with_trainer_class.py:12–35  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

10
11import torch
def main():
    """Train a model end to end using the ``Trainer`` class.

    Steps, in order: parse the run arguments, seed the run, create the
    output directory and store the arguments there, build the training and
    validation generators, create the model and optimizer, build a
    class-weighted ``DiceLoss`` and hand everything to ``Trainer.training``.

    The GPU is only used when ``args.cuda`` is truthy; both the loss weights
    and the model follow that flag so the script can also start on CPU.
    """
    args = get_arguments()
    # Expose only device index 0 to CUDA for this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    ## FOR REPRODUCIBILITY OF RESULTS
    seed = 1777777
    utils.reproducibility(args, seed)

    utils.make_dirs(args.save)
    utils.save_arguments(args, args.save)

    # The full volume and affine are returned by the loader but not needed
    # for training, hence the underscore-prefixed names.
    training_generator, val_generator, _full_volume, _affine = medical_loaders.generate_datasets(
        args, path='.././datasets')
    model, optimizer = medzoo.create_model(args)

    # Per-class weights for the Dice loss. Only DiceLoss is used: an earlier
    # CrossEntropyLoss criterion was immediately overwritten and is dropped.
    # NOTE(review): four weights are hard-coded; presumably args.classes == 4
    # -- confirm against the dataset configuration.
    class_weights = torch.tensor([0.1, 1, 1, 1])
    if args.cuda:
        # Move the weights to the GPU only when CUDA is requested; an
        # unconditional .cuda() fails on CPU-only runs.
        class_weights = class_weights.cuda()
    criterion = DiceLoss(classes=args.classes, weight=class_weights)

    if args.cuda:
        model = model.cuda()
        print("Model transferred in GPU.....")

    trainer = Trainer(args, model, criterion, optimizer, train_data_loader=training_generator,
                      valid_data_loader=val_generator, lr_scheduler=None)
    print("START TRAINING...")
    trainer.training()
36
37
38def get_arguments():

Callers 1

Calls 5

trainingMethod · 0.95
create_lossFunction · 0.90
DiceLossClass · 0.90
TrainerClass · 0.90
get_argumentsFunction · 0.70

Tested by

no test coverage detected