Extract diagram elements via SAM3; prompt groups and thresholds from config.
| 367 | |
| 368 | # ======================== SAM3 Info Extractor ======================== |
| 369 | class Sam3InfoExtractor(BaseProcessor): |
| 370 | """Extract diagram elements via SAM3; prompt groups and thresholds from config.""" |
| 371 | |
| 372 | def __init__(self, config=None, checkpoint_path: str = None, bpe_path: str = None): |
| 373 | super().__init__(config) |
| 374 | |
| 375 | # Load prompt groups, the text filter and deduplication settings from the config file (no longer hard-coded) |
| 376 | self.prompt_groups = ConfigLoader.get_prompt_groups() |
| 377 | self.text_filter = ConfigLoader.get_text_filter() |
| 378 | self.dedup_config = ConfigLoader.get_deduplication_config() |
| 379 | |
| 380 | # SAM3 model paths: explicit constructor args win, otherwise fall back to the sam3 config section (any falsy arg, including "", falls back because of `or`). NOTE(review): checkpoint_path/bpe_path default to None, so their hints should be Optional[str]. The model itself is created lazily in load_model(). |
| 381 | sam3_config = ConfigLoader.get_sam3_config() |
| 382 | self._checkpoint_path = checkpoint_path or sam3_config.get('checkpoint_path', '') |
| 383 | self._bpe_path = bpe_path or sam3_config.get('bpe_path', '') |
| 384 | |
| 385 | self._sam3_model: Optional[SAM3Model] = None |
| 386 | self._current_image_path: Optional[str] = None |
| 387 | |
| 388 | def reload_config(self): |
| 389 | """Force-reload the config from disk and refresh prompt groups, text filter and dedup settings. NOTE(review): the SAM3 checkpoint/bpe paths and any already-constructed model are not refreshed here.""" |
| 390 | ConfigLoader.load_config(force_reload=True) |
| 391 | self.prompt_groups = ConfigLoader.get_prompt_groups() |
| 392 | self.text_filter = ConfigLoader.get_text_filter() |
| 393 | self.dedup_config = ConfigLoader.get_deduplication_config() |
| 394 | self._log("Config reloaded") |
| 395 | |
| 396 | def load_model(self): |
| 397 | """Lazily construct the SAM3 model on first call (device taken from the sam3 config section) and call its load() if it is not loaded yet; repeated calls reuse the same instance.""" |
| 398 | if self._sam3_model is None: |
| 399 | sam3_config = ConfigLoader.get_sam3_config() |
| 400 | device = sam3_config.get("device") # e.g. "cpu" or "cuda", None = auto |
| 401 | self._sam3_model = SAM3Model( |
| 402 | checkpoint_path=self._checkpoint_path, |
| 403 | bpe_path=self._bpe_path, |
| 404 | device=device |
| 405 | ) |
| 406 | if not self._sam3_model.is_loaded: |
| 407 | self._sam3_model.load() |
| 408 | |
| 409 | def process(self, context: ProcessingContext) -> ProcessingResult: |
| 410 | """ |
| 411 | 处理入口 - 分组提取图片中的所有元素 |
| 412 | |
| 413 | Args: |
| 414 | context: 处理上下文,需要包含 image_path |
| 415 | |
| 416 | Returns: |
| 417 | ProcessingResult: 包含所有提取的ElementInfo |
| 418 | """ |
| 419 | self._log(f"开始处理: {context.image_path}") |
| 420 | |
| 421 | # 保存当前图像路径(供去重分析使用) |
| 422 | self._current_image_path = context.image_path |
| 423 | |
| 424 | self.load_model() |
| 425 | |
| 426 | pil_image = Image.open(context.image_path) |
no outgoing calls
no test coverage detected