| 259 | |
| 260 | |
| 261 | def fast_show_mask_gpu( |
| 262 | annotation, |
| 263 | ax, |
| 264 | random_color=False, |
| 265 | bbox=None, |
| 266 | points=None, |
| 267 | point_label=None, |
| 268 | retinamask=True, |
| 269 | target_height=960, |
| 270 | target_width=960, |
| 271 | ): |
| 272 | msak_sum = annotation.shape[0] |
| 273 | height = annotation.shape[1] |
| 274 | weight = annotation.shape[2] |
| 275 | areas = torch.sum(annotation, dim=(1, 2)) |
| 276 | sorted_indices = torch.argsort(areas, descending=False) |
| 277 | annotation = annotation[sorted_indices] |
| 278 | # Find the index of the first non-zero value at each position |
| 279 | index = (annotation != 0).to(torch.long).argmax(dim=0) |
| 280 | if random_color == True: |
| 281 | color = torch.rand((msak_sum, 1, 1, 3)).to(annotation.device) |
| 282 | else: |
| 283 | color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor( |
| 284 | [30 / 255, 144 / 255, 255 / 255] |
| 285 | ).to(annotation.device) |
| 286 | transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * 0.6 |
| 287 | visual = torch.cat([color, transparency], dim=-1) |
| 288 | mask_image = torch.unsqueeze(annotation, -1) * visual |
| 289 | # Gather by index: index says which batch entry to pick at each position, turning mask_image into a single-batch image |
| 290 | show = torch.zeros((height, weight, 4)).to(annotation.device) |
| 291 | h_indices, w_indices = torch.meshgrid( |
| 292 | torch.arange(height), torch.arange(weight), indexing="ij" |
| 293 | ) |
| 294 | indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None)) |
| 295 | # Update the values of show using vectorized indexing |
| 296 | show[h_indices, w_indices, :] = mask_image[indices] |
| 297 | show_cpu = show.cpu().numpy() |
| 298 | if bbox is not None: |
| 299 | x1, y1, x2, y2 = bbox |
| 300 | ax.add_patch( |
| 301 | plt.Rectangle( |
| 302 | (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1 |
| 303 | ) |
| 304 | ) |
| 305 | # draw point |
| 306 | if points is not None: |
| 307 | plt.scatter( |
| 308 | [point[0] for i, point in enumerate(points) if point_label[i] == 1], |
| 309 | [point[1] for i, point in enumerate(points) if point_label[i] == 1], |
| 310 | s=20, |
| 311 | c="y", |
| 312 | ) |
| 313 | plt.scatter( |
| 314 | [point[0] for i, point in enumerate(points) if point_label[i] == 0], |
| 315 | [point[1] for i, point in enumerate(points) if point_label[i] == 0], |
| 316 | s=20, |
| 317 | c="m", |
| 318 | ) |