Args: masks (numpy.ndarray or torch.Tensor): shape N x 1 x H x W Outputs: new_mask (numpy.ndarray): shape H x W
(self, masks)
| 1308 | return updated_image_path |
| 1309 | |
def merge_masks(self, masks):
    '''
    Merge a stack of per-object masks into one mask by taking the
    element-wise maximum over the first (N) axis.

    Args:
        masks (numpy.ndarray or torch.Tensor): shape N x 1 x H x W.
            A numpy input is converted to an integer tensor first.
    Outputs:
        new_mask (numpy.ndarray): shape H x W, dtype uint8
    Raises:
        TypeError: if masks is neither a numpy.ndarray nor a torch.Tensor
    '''
    if isinstance(masks, torch.Tensor):
        stacked = masks
    elif isinstance(masks, np.ndarray):
        stacked = torch.tensor(masks, dtype=int)
    else:
        raise TypeError("the type of the input masks must be numpy.ndarray or torch.tensor")

    # Drop the singleton channel axis: N x 1 x H x W -> N x H x W.
    stacked = stacked.squeeze(dim=1)
    # max over dim 0 returns (values, indices); only the values are needed.
    merged, _ = stacked.max(dim=0)
    # detach() allows tensors that require grad to be converted to numpy.
    new_mask = merged.detach().cpu().numpy()
    # astype returns a new array; the result must be kept, otherwise the
    # conversion to uint8 is silently lost.
    return new_mask.astype(np.uint8)
| 1328 | |
| 1329 | def get_mask(self, image_path, text_prompt): |
| 1330 |