Calculates the image embeddings for the provided image, allowing masks to be predicted with the 'predict' method. Arguments: image (np.ndarray or PIL Image): The input image to embed, in RGB format with pixel values in [0, 255]. A np.ndarray should be in HWC layout; for a PIL Image, the size is read from `image.size`, which PIL reports as (W, H).
(
self,
image: Union[np.ndarray, Image],
)
| 84 | |
@torch.no_grad()
def set_image(
    self,
    image: Union[np.ndarray, Image],
) -> None:
    """
    Calculates the image embeddings for the provided image, allowing
    masks to be predicted with the 'predict' method.

    Any previously set image state is cleared first via ``reset_predictor``.

    Arguments:
      image (np.ndarray or PIL Image): The input image to embed, in RGB
        format with pixel values in [0, 255]. A np.ndarray is assumed to be
        laid out as HxWxC. For a PIL Image, the original size is read from
        ``image.size``, which PIL reports as (W, H). This method takes no
        color-format argument, so the image is expected to already be RGB.

    Raises:
      NotImplementedError: if ``image`` is neither a np.ndarray nor a PIL
        Image.
      AssertionError: if the transformed input is not a 1x3xHxW tensor.
    """
    self.reset_predictor()
    # Record the original (H, W) of the input. It is stored as a one-element
    # list of (H, W); presumably this mirrors the per-image list used by the
    # batch API -- confirm against set_image_batch.
    if isinstance(image, np.ndarray):
        logging.info("For numpy array image, we assume (HxWxC) format")
        self._orig_hw = [image.shape[:2]]
    elif isinstance(image, Image):
        # PIL reports size as (width, height); store it as (height, width).
        w, h = image.size
        self._orig_hw = [(h, w)]
    else:
        raise NotImplementedError("Image format not supported")

    # Transform the image to the form expected by the model and add a
    # leading batch dimension.
    input_image = self._transforms(image)
    input_image = input_image[None, ...]
    # Validate the shape before transferring to the device, so malformed
    # input fails without an unnecessary copy.
    assert (
        len(input_image.shape) == 4 and input_image.shape[1] == 3
    ), f"input_image must be of size 1x3xHxW, got {input_image.shape}"
    input_image = input_image.to(self.device)

    logging.info("Computing image embeddings for the provided image...")
    backbone_out = self.model.forward_image(input_image)
    _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
    # Add no_mem_embed, which is added to the lowest-res feature map during
    # training on videos.
    if self.model.directly_add_no_mem_embed:
        vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed

    # Reshape each flattened feature level to (1, C, H, W) using the matching
    # entry of self._bb_feat_sizes. The permute(1, 2, 0) + view(1, -1, H, W)
    # implies each level is laid out as (H*W, B, C) with B == 1 -- confirm
    # against _prepare_backbone_features. Both lists are reversed before
    # zipping so that, if their lengths ever differ, levels are paired from
    # the end (zip truncates the longer list). The result is then restored
    # to the original order.
    feats = [
        feat.permute(1, 2, 0).view(1, -1, *feat_size)
        for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
    ][::-1]
    # The last level becomes the main image embedding; the remaining levels
    # are kept as high-resolution features.
    self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
    self._is_image_set = True
    logging.info("Image embeddings computed.")
| 130 | |
| 131 | @torch.no_grad() |
| 132 | def set_image_batch( |
no test coverage detected