SAM3推理。Args: image_path: 图片路径；prompts: 提示词列表；score_threshold: 置信度阈值（得分低于该值的结果被跳过）；min_area: 最小面积阈值（bbox 面积小于该值的结果被跳过）。Returns: 元素列表（List[Dict[str, Any]]）
(self, image_path: str, prompts: List[str],
score_threshold: float = 0.5,
min_area: int = 100)
| 248 | print("[SAM3Model] 模型加载完成!") |
| 249 | |
| 250 | def predict(self, image_path: str, prompts: List[str], |
| 251 | score_threshold: float = 0.5, |
| 252 | min_area: int = 100) -> List[Dict[str, Any]]: |
| 253 | """ |
| 254 | SAM3推理 |
| 255 | |
| 256 | Args: |
| 257 | image_path: 图片路径 |
| 258 | prompts: 提示词列表 |
| 259 | score_threshold: 置信度阈值 |
| 260 | min_area: 最小面积阈值 |
| 261 | |
| 262 | Returns: |
| 263 | 元素列表 |
| 264 | """ |
| 265 | if not self._is_loaded: |
| 266 | self.load() |
| 267 | |
| 268 | state, pil_image = self._get_image_state(image_path) |
| 269 | |
| 270 | results = [] |
| 271 | for prompt in prompts: |
| 272 | self._processor.reset_all_prompts(state) |
| 273 | result_state = self._processor.set_text_prompt(prompt=prompt, state=state) |
| 274 | |
| 275 | masks = result_state.get("masks", []) |
| 276 | boxes = result_state.get("boxes", []) |
| 277 | scores = result_state.get("scores", []) |
| 278 | |
| 279 | num_masks = masks.shape[0] if (isinstance(masks, torch.Tensor) and masks.dim() > 0) else len(masks) |
| 280 | |
| 281 | for i in range(num_masks): |
| 282 | score = scores[i] |
| 283 | score_val = score.item() if hasattr(score, 'item') else float(score) |
| 284 | |
| 285 | if score_val < score_threshold: |
| 286 | continue |
| 287 | |
| 288 | # 提取bbox |
| 289 | box = boxes[i] |
| 290 | bbox = box.cpu().numpy().tolist() if isinstance(box, torch.Tensor) else box |
| 291 | bbox = [int(coord) for coord in bbox] |
| 292 | |
| 293 | # 检查面积 |
| 294 | area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) |
| 295 | if area < min_area: |
| 296 | continue |
| 297 | |
| 298 | # 提取mask |
| 299 | mask = masks[i] |
| 300 | binary_mask = mask.cpu().numpy() if isinstance(mask, torch.Tensor) else np.array(mask) |
| 301 | if binary_mask.ndim > 2: |
| 302 | binary_mask = binary_mask.squeeze() |
| 303 | binary_mask = (binary_mask > 0.5).astype(np.uint8) * 255 |
| 304 | |
| 305 | # 提取polygon |
| 306 | polygon = self._extract_polygon(binary_mask, min_area) |
| 307 |
no test coverage detected