| 13 | from utils.utils import plot_img_and_mask |
| 14 | |
def predict_img(net,
                full_img,
                device,
                scale_factor=1,
                out_threshold=0.5):
    """Run ``net`` on a single image and return a mask at the image's full size.

    Args:
        net: Segmentation model. Its ``n_classes`` attribute selects between
            the binary path (sigmoid + threshold) and the multi-class path
            (softmax + argmax + one-hot).
        full_img: Input image. ``full_img.size`` is read as ``(width, height)``
            (presumably a PIL image -- confirm against callers).
        device: Device the preprocessed input tensor is moved to.
        scale_factor: Forwarded unchanged to ``BasicDataset.preprocess``.
        out_threshold: Probability above which a pixel counts as foreground.
            Used only when ``net.n_classes == 1``.

    Returns:
        numpy.ndarray: if ``net.n_classes == 1``, a boolean array of shape
        ``(height, width)``. Otherwise, an int64 one-hot array of shape
        ``(net.n_classes, height, width)``.
    """
    net.eval()
    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor, is_mask=False))
    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    # ``full_img.size`` is (width, height); interpolate wants (height, width).
    out_size = (full_img.size[1], full_img.size[0])

    with torch.no_grad():
        # Assumes the model returns (batch, n_classes, H, W) logits -- TODO confirm.
        output = net(img)

        if net.n_classes > 1:
            probs = F.softmax(output, dim=1)
        else:
            probs = torch.sigmoid(output)

        # Resize the probability maps directly in float. The former
        # ToPILImage -> Resize -> ToTensor round-trip quantised them to 8 bits
        # and raised for more than 4 channels (i.e. n_classes > 4).
        probs = F.interpolate(probs, size=out_size, mode='bilinear', align_corners=False)

        # Drop the batch dim by indexing rather than squeeze(), which would
        # also collapse a height or width of 1.
        full_mask = probs[0].cpu()

    if net.n_classes == 1:
        return (full_mask[0] > out_threshold).numpy()
    return F.one_hot(full_mask.argmax(dim=0), net.n_classes).permute(2, 0, 1).numpy()
| 45 | |
| 46 | |
| 47 | def get_args(): |