(input_tuple, inModalities=-1, inChannels=-1, cuda=False, args=None)
| 40 | |
| 41 | |
# Positions (0-based, within the image part of ``input_tuple``) of the images
# that are concatenated into the network input, keyed by
# ``(modalities, channels)``. Per the original author's notes, for 4-modality
# samples index 0 is t1, index 1 is t1 post-contrast and index 2 is t2.
_CHANNEL_SELECTION = {
    (4, 4): (0, 1, 2, 3),
    (4, 3): (0, 2, 3),  # t1 post-contrast (index 1) is omitted
    (4, 2): (0, 2),     # t1 and t2 only
    (4, 1): (0,),       # t1 only
    (3, 3): (0, 1, 2),
    (3, 2): (0, 1),
    (3, 1): (0,),
    (2, 2): (0, 1),
    (2, 1): (0,),
}


def prepare_input(input_tuple, inModalities=-1, inChannels=-1, cuda=False, args=None):
    """Build the network input tensor and target from one dataset sample.

    ``input_tuple`` must hold ``modalities`` image tensors followed by the
    target. A subset of the images (chosen by ``channels``) is concatenated
    along dim 1 (presumably the channel axis -- confirm against the model).
    When only one image is selected it is returned as-is, without copying.

    Args:
        input_tuple: Sequence ``(img_1, ..., img_<modalities>, target)``.
        inModalities: Number of image tensors in the sample. Ignored when
            ``args`` is given.
        inChannels: Number of images to feed to the network. Ignored when
            ``args`` is given, and ignored entirely when there is a single
            modality.
        cuda: If true, move both returned tensors to the GPU with ``.cuda()``.
            Ignored when ``args`` is given.
        args: Optional namespace whose ``inModalities``, ``inChannels`` and
            ``cuda`` attributes override the keyword arguments above.

    Returns:
        Tuple ``(input_tensor, target)``.

    Raises:
        ValueError: If the ``(modalities, channels)`` combination is not
            supported, or ``input_tuple`` does not contain exactly
            ``modalities`` images plus a target.
    """
    if args is not None:
        modalities = args.inModalities
        channels = args.inChannels
        in_cuda = args.cuda
    else:
        modalities = inModalities
        channels = inChannels
        in_cuda = cuda

    if modalities == 1:
        # A single-modality sample always feeds its only image.
        selection = (0,)
    else:
        selection = _CHANNEL_SELECTION.get((modalities, channels))
        if selection is None:
            raise ValueError(
                f"unsupported combination: modalities={modalities!r}, "
                f"channels={channels!r}"
            )

    *images, target = input_tuple
    if len(images) != modalities:
        raise ValueError(
            f"expected {modalities} images plus a target, "
            f"got a sample of length {len(images) + 1}"
        )

    if len(selection) == 1:
        # Avoid the copy torch.cat would make for a single tensor.
        input_tensor = images[selection[0]]
    else:
        input_tensor = torch.cat([images[i] for i in selection], dim=1)

    if in_cuda:
        input_tensor, target = input_tensor.cuda(), target.cuda()

    return input_tensor, target
| 90 | |
| 91 | |
| 92 | def adjust_opt(optAlg, optimizer, epoch): |
no outgoing calls
no test coverage detected