MCPcopy Index your code
hub / github.com/BIT-DataLab/Edit-Banana / SAM3Model

Class SAM3Model

modules/sam3_info_extractor.py:214–365  ·  view source on GitHub ↗

SAM3模型封装

Source from the content-addressed store, hash-verified

212
213# ======================== SAM3模型封装 ========================
214class SAM3Model(ModelWrapper):
215 """SAM3模型封装"""
216
def __init__(self, checkpoint_path: str, bpe_path: str, device: str = None):
    """Record the model file locations and prepare an empty image-state cache.

    Nothing is loaded here; the model and processor are created by ``load()``.

    Args:
        checkpoint_path: Path to the SAM3 checkpoint file.
        bpe_path: Path to the BPE file passed to the SAM3 model builder.
        device: Device string to run on. When falsy, "cuda" is chosen if
            ``torch.cuda.is_available()`` reports True, otherwise "cpu".
    """
    super().__init__()

    # Resolve the fallback device only when the caller did not pick one.
    if not device:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    self.checkpoint_path = checkpoint_path
    self.bpe_path = bpe_path
    self.device = device
    self._processor = None  # created in load()

    # Image state cache, with its size setting and the lock that guards it.
    # NOTE(review): eviction is not visible in this method — presumably
    # _max_cache_size bounds the number of cached states; confirm where
    # the cache is filled.
    self._state_cache = OrderedDict()
    self._max_cache_size = 3
    self._cache_lock = threading.Lock()
228
def load(self):
    """Build the SAM3 model and its processor; do nothing if already loaded.

    The ``sam3`` package is imported inside this method rather than at
    module import time. On success ``self._model`` and ``self._processor``
    are set and ``self._is_loaded`` becomes True.
    """
    if not self._is_loaded:
        print(f"[SAM3Model] 加载模型中... (设备: {self.device})")

        from sam3.model_builder import build_sam3_image_model
        from sam3.model.sam3_image_processor import Sam3Processor

        # Weights come from the local checkpoint, not from the HF hub.
        builder_kwargs = {
            "bpe_path": self.bpe_path,
            "checkpoint_path": self.checkpoint_path,
            "load_from_HF": False,
            "device": self.device,
        }
        sam3_model = build_sam3_image_model(**builder_kwargs)
        self._model = sam3_model
        self._processor = Sam3Processor(sam3_model)

        # Mark as loaded only after both objects were created successfully.
        self._is_loaded = True

        print("[SAM3Model] 模型加载完成!")
249
250 def predict(self, image_path: str, prompts: List[str],
251 score_threshold: float = 0.5,
252 min_area: int = 100) -> List[Dict[str, Any]]:
253 """
254 SAM3推理
255
256 Args:
257 image_path: 图片路径
258 prompts: 提示词列表
259 score_threshold: 置信度阈值
260 min_area: 最小面积阈值
261
262 Returns:
263 元素列表
264 """
265 if not self._is_loaded:
266 self.load()
267
268 state, pil_image = self._get_image_state(image_path)
269
270 results = []
271 for prompt in prompts:

Callers 1

load_modelMethod · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected