Postprocesses the output of Yolact on testing mode into a format that makes sense, accounting for all the possible configuration settings. Args: - det_output: The list of dicts that Detect outputs. - w: The real width of the image. - h: The real height of the image.
(det_output, w, h, batch_idx=0, interpolation_mode='bilinear',
visualize_lincomb=False, crop_masks=True, score_threshold=0)
| 13 | from .box_utils import crop, sanitize_coordinates |
| 14 | |
| 15 | def postprocess(det_output, w, h, batch_idx=0, interpolation_mode='bilinear', |
| 16 | visualize_lincomb=False, crop_masks=True, score_threshold=0): |
| 17 | """ |
| 18 | Postprocesses the output of Yolact on testing mode into a format that makes sense, |
| 19 | accounting for all the possible configuration settings. |
| 20 | |
| 21 | Args: |
| 22 | - det_output: The list of dicts that Detect outputs. |
| 23 | - w: The real width of the image. |
| 24 | - h: The real height of the image. |
| 25 | - batch_idx: If you have multiple images for this batch, the image's index in the batch. |
| 26 | - interpolation_mode: Can be 'nearest' | 'area' | 'bilinear' (see torch.nn.functional.interpolate) |
| 27 | |
| 28 | Returns 4 torch Tensors (in the following order): |
| 29 | - classes [num_det]: The class idx for each detection. |
| 30 | - scores [num_det]: The confidence score for each detection. |
| 31 | - boxes [num_det, 4]: The bounding box for each detection in absolute point form. |
| 32 | - masks [num_det, h, w]: Full image masks for each detection. |
| 33 | """ |
| 34 | |
| 35 | dets = det_output[batch_idx] |
| 36 | net = dets['net'] |
| 37 | dets = dets['detection'] |
| 38 | |
| 39 | if dets is None: |
| 40 | return [torch.Tensor()] * 4 # Warning, this is 4 copies of the same thing |
| 41 | |
| 42 | if score_threshold > 0: |
| 43 | keep = dets['score'] > score_threshold |
| 44 | |
| 45 | for k in dets: |
| 46 | if k != 'proto': |
| 47 | dets[k] = dets[k][keep] |
| 48 | |
| 49 | if dets['score'].size(0) == 0: |
| 50 | return [torch.Tensor()] * 4 |
| 51 | |
| 52 | # Actually extract everything from dets now |
| 53 | classes = dets['class'] |
| 54 | boxes = dets['box'] |
| 55 | scores = dets['score'] |
| 56 | masks = dets['mask'] |
| 57 | |
| 58 | if cfg.mask_type == mask_type.lincomb and cfg.eval_mask_branch: |
| 59 | # At this points masks is only the coefficients |
| 60 | proto_data = dets['proto'] |
| 61 | |
| 62 | # Test flag, do not upvote |
| 63 | if cfg.mask_proto_debug: |
| 64 | np.save('scripts/proto.npy', proto_data.cpu().numpy()) |
| 65 | |
| 66 | if visualize_lincomb: |
| 67 | display_lincomb(proto_data, masks) |
| 68 | |
| 69 | masks = proto_data @ masks.t() |
| 70 | masks = cfg.mask_proto_mask_activation(masks) |
| 71 | |
| 72 | # Crop masks before upsampling because you know why |
no test coverage detected