MCPcopy
hub / github.com/milesial/Pytorch-UNet / predict_img

Function predict_img

predict.py:15–44  ·  view source on GitHub ↗
(net,
                full_img,
                device,
                scale_factor=1,
                out_threshold=0.5)

Source from the content-addressed store, hash-verified

13from utils.utils import plot_img_and_mask
14
def predict_img(net,
                full_img,
                device,
                scale_factor=1,
                out_threshold=0.5):
    """Predict a segmentation mask for a single image at its original resolution.

    Args:
        net: segmentation model exposing an ``n_classes`` attribute. It is
            switched to eval mode as a side effect of this call.
        full_img: input image; ``full_img.size`` is indexed as
            ``(width, height)`` (PIL convention) to size the output mask.
        device: torch device the preprocessed input is moved to.
        scale_factor: scale passed through to ``BasicDataset.preprocess``.
        out_threshold: probability threshold used when ``n_classes == 1``.

    Returns:
        numpy.ndarray: for ``n_classes == 1`` a boolean ``(H, W)`` mask;
        otherwise a one-hot integer array of shape ``(n_classes, H, W)``,
        where ``(H, W)`` is the original image size.
    """
    net.eval()
    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor, is_mask=False))
    img = img.unsqueeze(0)  # add batch dimension
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        # NOTE(review): net output assumed to be (1, n_classes, h, w) — the
        # channel-dim softmax and 4-D interpolate below rely on it.
        output = net(img)

        if net.n_classes > 1:
            probs = F.softmax(output, dim=1)
        else:
            probs = torch.sigmoid(output)

        # Resize the probability map back to the original image size in torch.
        # A ToPILImage -> Resize -> ToTensor round trip would quantise the
        # probabilities to 8 bits (shifting values near the threshold) and
        # fails for more than 4 channels, i.e. for n_classes > 4.
        out_size = (full_img.size[1], full_img.size[0])  # (height, width)
        probs = F.interpolate(probs, size=out_size, mode='bilinear',
                              align_corners=False)
        full_mask = probs[0].cpu()  # (C, H, W); no squeeze, so 1-pixel dims survive

    if net.n_classes == 1:
        return (full_mask[0] > out_threshold).numpy()
    # One-hot encode the per-pixel argmax and move the class axis first.
    return F.one_hot(full_mask.argmax(dim=0), net.n_classes).permute(2, 0, 1).numpy()
45
46
47def get_args():

Callers 1

predict.py · File · 0.85

Calls 1

preprocess · Method · 0.80

Tested by

no test coverage detected