hub / github.com/CASIA-LMC-Lab/FastSAM / fast_show_mask_gpu

Function fast_show_mask_gpu

utils/tools.py:261–323
(
    annotation,
    ax,
    random_color=False,
    bbox=None,
    points=None,
    point_label=None,
    retinamask=True,
    target_height=960,
    target_width=960,
)

Source from the content-addressed store, hash-verified

def fast_show_mask_gpu(
    annotation,
    ax,
    random_color=False,
    bbox=None,
    points=None,
    point_label=None,
    retinamask=True,
    target_height=960,
    target_width=960,
):
    # Relies on module-level imports in utils/tools.py: torch and matplotlib.pyplot as plt.
    msak_sum = annotation.shape[0]  # number of masks
    height = annotation.shape[1]
    weight = annotation.shape[2]  # mask width (name kept as in the source)
    # Sort masks by area, smallest first, so the smallest mask wins where masks overlap
    areas = torch.sum(annotation, dim=(1, 2))
    sorted_indices = torch.argsort(areas, descending=False)
    annotation = annotation[sorted_indices]
    # Find the index of the first non-zero value at each position
    index = (annotation != 0).to(torch.long).argmax(dim=0)
    if random_color:
        color = torch.rand((msak_sum, 1, 1, 3)).to(annotation.device)
    else:
        color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor(
            [30 / 255, 144 / 255, 255 / 255]
        ).to(annotation.device)
    transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * 0.6
    visual = torch.cat([color, transparency], dim=-1)
    mask_image = torch.unsqueeze(annotation, -1) * visual
    # Gather by index: index says which mask (batch entry) to pick at each position,
    # collapsing mask_image into a single image
    show = torch.zeros((height, weight, 4)).to(annotation.device)
    h_indices, w_indices = torch.meshgrid(
        torch.arange(height), torch.arange(weight), indexing="ij"
    )
    indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
    # Update show using vectorized indexing
    show[h_indices, w_indices, :] = mask_image[indices]
    show_cpu = show.cpu().numpy()
    if bbox is not None:
        x1, y1, x2, y2 = bbox
        ax.add_patch(
            plt.Rectangle(
                (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
            )
        )
    # draw point: label 1 in yellow, label 0 in magenta
    if points is not None:
        plt.scatter(
            [point[0] for i, point in enumerate(points) if point_label[i] == 1],
            [point[1] for i, point in enumerate(points) if point_label[i] == 1],
            s=20,
            c="y",
        )
        plt.scatter(
            [point[0] for i, point in enumerate(points) if point_label[i] == 0],
            [point[1] for i, point in enumerate(points) if point_label[i] == 0],
            s=20,
            c="m",
        )
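
The core of the listing above is the per-pixel mask selection: masks are sorted by area, argmax finds the first non-zero mask at each pixel, and that mask's RGBA value is gathered into a single image. The following is a minimal sketch of that idea, not the repository's code. It uses NumPy instead of torch so it runs on the CPU without a GPU build, and the name composite_masks and its colors parameter (per-mask colors reordered together with the masks) are assumptions made for illustration.

```python
import numpy as np


def composite_masks(masks, colors=None, alpha=0.6):
    """Flatten binary masks (N, H, W) into one RGBA image (H, W, 4).

    Hypothetical helper: where masks overlap, the smallest-area mask wins.
    """
    masks = np.asarray(masks, dtype=np.float32)
    n, height, width = masks.shape
    if colors is None:
        # Same default blue as the listing: (30, 144, 255) scaled to [0, 1].
        colors = np.tile(np.array([30, 144, 255], dtype=np.float32) / 255, (n, 1))
    colors = np.asarray(colors, dtype=np.float32)

    # Sort ascending by area so smaller masks come first along axis 0.
    order = np.argsort(masks.sum(axis=(1, 2)), kind="stable")
    masks, colors = masks[order], colors[order]

    # argmax returns the first maximum, so on a boolean array it yields the
    # first non-zero mask at each pixel (0 when no mask covers the pixel).
    index = (masks != 0).argmax(axis=0)

    # One RGBA row per mask, broadcast over H and W: (N, H, W, 4).
    visual = np.concatenate(
        [colors, np.full((n, 1), alpha, dtype=np.float32)], axis=1
    )
    mask_image = masks[..., None] * visual[:, None, None, :]

    # For every pixel, pick the RGBA value of the selected mask.
    h_idx, w_idx = np.meshgrid(
        np.arange(height), np.arange(width), indexing="ij"
    )
    return mask_image[index, h_idx, w_idx]


# A large red mask covering everything and a small green mask on one pixel.
big = np.ones((2, 2))
small = np.array([[1, 0], [0, 0]])
show = composite_masks([big, small], colors=[(1, 0, 0), (0, 1, 0)])
# show[0, 0] is green (the smaller mask wins); the other pixels are red.
```

Pixels that no mask covers get index 0, but the mask value there is zero, so the product leaves them fully transparent. That is why neither the sketch nor the listing needs a separate "no mask" branch.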

Callers 1

fast_process · Function · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected