hub / github.com/BIT-DataLab/Edit-Banana / predict

Method predict

modules/sam3_info_extractor.py:250–318

SAM3 inference. Args: image_path: path to the image; prompts: list of text prompts; score_threshold: confidence threshold; min_area: minimum area threshold. Returns: a list of elements.

(self, image_path: str, prompts: List[str], 
                score_threshold: float = 0.5,
                min_area: int = 100)
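
The two numeric parameters act as independent filters on each candidate returned for a prompt. The following is a minimal sketch of that behaviour, not code from the repository: `passes_filters` is a hypothetical helper, and it assumes boxes are `[x1, y1, x2, y2]` in pixels, which is what the area computation in the listing suggests.

```python
from typing import Sequence


def passes_filters(score: float, bbox: Sequence[float],
                   score_threshold: float = 0.5,
                   min_area: int = 100) -> bool:
    """Hypothetical helper mirroring the two checks made in predict()."""
    # Candidates scoring strictly below the threshold are dropped.
    if score < score_threshold:
        return False
    # Coordinates are truncated to int before the area is computed,
    # as in the listing; assumes [x1, y1, x2, y2] ordering.
    x1, y1, x2, y2 = (int(c) for c in bbox)
    area = (x2 - x1) * (y2 - y1)
    # Boxes with an area strictly below min_area are dropped.
    return area >= min_area
```

Both comparisons are strict, so a candidate whose score equals `score_threshold`, or whose box area equals `min_area`, is kept.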

Source from the content-addressed store, hash-verified

        print("[SAM3Model] Model loaded!")

    def predict(self, image_path: str, prompts: List[str],
                score_threshold: float = 0.5,
                min_area: int = 100) -> List[Dict[str, Any]]:
        """
        SAM3 inference.

        Args:
            image_path: Path to the image.
            prompts: List of text prompts.
            score_threshold: Confidence threshold.
            min_area: Minimum area threshold.

        Returns:
            A list of elements.
        """
        # Load the model lazily on first use.
        if not self._is_loaded:
            self.load()

        state, pil_image = self._get_image_state(image_path)

        results = []
        for prompt in prompts:
            # Clear prompts left over from the previous iteration.
            self._processor.reset_all_prompts(state)
            result_state = self._processor.set_text_prompt(prompt=prompt, state=state)

            masks = result_state.get("masks", [])
            boxes = result_state.get("boxes", [])
            scores = result_state.get("scores", [])

            # masks may be a tensor or a plain sequence.
            num_masks = masks.shape[0] if (isinstance(masks, torch.Tensor) and masks.dim() > 0) else len(masks)

            for i in range(num_masks):
                score = scores[i]
                score_val = score.item() if hasattr(score, 'item') else float(score)

                # Skip low-confidence candidates.
                if score_val < score_threshold:
                    continue

                # Extract the bounding box.
                box = boxes[i]
                bbox = box.cpu().numpy().tolist() if isinstance(box, torch.Tensor) else box
                bbox = [int(coord) for coord in bbox]

                # Check the area.
                area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
                if area < min_area:
                    continue

                # Extract the mask and binarize it to 0/255.
                mask = masks[i]
                binary_mask = mask.cpu().numpy() if isinstance(mask, torch.Tensor) else np.array(mask)
                if binary_mask.ndim > 2:
                    binary_mask = binary_mask.squeeze()
                binary_mask = (binary_mask > 0.5).astype(np.uint8) * 255

                # Extract the polygon.
                polygon = self._extract_polygon(binary_mask, min_area)

Callers 3

process · Method · 0.45
extract_by_group · Method · 0.45

Calls 3

load · Method · 0.95
_get_image_state · Method · 0.95
_extract_polygon · Method · 0.95

Tested by

no test coverage detected