SAM3模型封装
| 212 | |
| 213 | # ======================== SAM3模型封装 ======================== |
| 214 | class SAM3Model(ModelWrapper): |
| 215 | """SAM3模型封装""" |
| 216 | |
def __init__(self, checkpoint_path: str, bpe_path: str,
             device: "str | None" = None, *, max_cache_size: int = 3):
    """Configure a lazily loaded SAM3 image model.

    No weights are loaded here. ``load()`` builds the model, and
    ``predict`` calls it on demand when the model is not yet loaded.

    Args:
        checkpoint_path: Path to the SAM3 checkpoint, forwarded to the
            model builder.
        bpe_path: Path to the BPE file, forwarded to the model builder.
        device: Torch device string. ``None`` (or any falsy value) picks
            ``"cuda"`` when ``torch.cuda.is_available()``, else ``"cpu"``.
        max_cache_size: Keyword-only. Capacity of the per-image state
            cache. Defaults to 3, the value previously hard-coded.

    Raises:
        ValueError: If ``max_cache_size`` is less than 1.
    """
    # The base class is expected to set ``_is_loaded`` / ``_model``, which
    # load() reads but this constructor never assigns.
    super().__init__()
    if max_cache_size < 1:
        raise ValueError(
            f"max_cache_size must be >= 1, got {max_cache_size!r}")

    self.checkpoint_path = checkpoint_path
    self.bpe_path = bpe_path
    self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    # Filled in by load().
    self._processor = None

    # Image state cache. An OrderedDict presumably keeps insertion/LRU
    # order for eviction in _get_image_state (not visible here) -- confirm.
    self._state_cache = OrderedDict()
    self._max_cache_size = max_cache_size
    # Presumably guards _state_cache against concurrent access -- confirm.
    self._cache_lock = threading.Lock()
| 228 | |
def load(self):
    """Build the SAM3 image model and its processor on first use.

    Once ``_is_loaded`` is truthy, further calls do nothing. The ``sam3``
    package is imported lazily, so it is only required when the model is
    actually loaded.
    """
    if not self._is_loaded:
        print(f"[SAM3Model] 加载模型中... (设备: {self.device})")

        from sam3.model_builder import build_sam3_image_model
        from sam3.model.sam3_image_processor import Sam3Processor

        # Weights come from the local checkpoint, never from the HF hub.
        builder_kwargs = dict(
            bpe_path=self.bpe_path,
            checkpoint_path=self.checkpoint_path,
            load_from_HF=False,
            device=self.device,
        )
        self._model = build_sam3_image_model(**builder_kwargs)
        self._processor = Sam3Processor(self._model)
        # Mark as loaded only after both objects were built successfully.
        self._is_loaded = True

        print("[SAM3Model] 模型加载完成!")
| 249 | |
| 250 | def predict(self, image_path: str, prompts: List[str], |
| 251 | score_threshold: float = 0.5, |
| 252 | min_area: int = 100) -> List[Dict[str, Any]]: |
| 253 | """ |
| 254 | SAM3推理 |
| 255 | |
| 256 | Args: |
| 257 | image_path: 图片路径 |
| 258 | prompts: 提示词列表 |
| 259 | score_threshold: 置信度阈值 |
| 260 | min_area: 最小面积阈值 |
| 261 | |
| 262 | Returns: |
| 263 | 元素列表 |
| 264 | """ |
| 265 | if not self._is_loaded: |
| 266 | self.load() |
| 267 | |
| 268 | state, pil_image = self._get_image_state(image_path) |
| 269 | |
| 270 | results = [] |
| 271 | for prompt in prompts: |