Base class for all detection models in SAHI. Subclasses must implement ``load_model``, ``perform_inference``, and ``_create_object_prediction_list_from_original_predictions`` to integrate a new detection framework. The base class handles device management, dependency checking, category remapping, and the public prediction API.
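The subclass contract is easier to see in a toy example. The sketch below is self-contained and does not import SAHI. `SketchDetectionModel`, `Detection`, `ConstantModel`, and `predict` are hypothetical stand-ins, not SAHI's API. The three hook names match the ones listed above, but their signatures are simplified here; the real methods may take additional arguments.

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any


@dataclass
class Detection:
    """Hypothetical stand-in for SAHI's ObjectPrediction."""

    bbox: tuple[float, float, float, float]  # (x_min, y_min, x_max, y_max)
    score: float
    category_id: int
    category_name: str


class SketchDetectionModel(ABC):
    """Hypothetical, simplified mirror of the DetectionModel subclass contract."""

    def __init__(self, confidence_threshold: float = 0.3, load_at_init: bool = True) -> None:
        self.model: Any = None
        self.confidence_threshold = confidence_threshold
        self._original_predictions: Any = None
        self._object_prediction_list: list[Detection] = []
        if load_at_init:
            self.load_model()

    @abstractmethod
    def load_model(self) -> None:
        """Load framework weights into self.model."""

    @abstractmethod
    def perform_inference(self, image: Any) -> None:
        """Run the framework model and store raw output in self._original_predictions."""

    @abstractmethod
    def _create_object_prediction_list_from_original_predictions(self) -> None:
        """Convert raw framework output into self._object_prediction_list."""

    def predict(self, image: Any) -> list[Detection]:
        # Hypothetical public entry point: run inference, then convert the raw output.
        self.perform_inference(image)
        self._create_object_prediction_list_from_original_predictions()
        return self._object_prediction_list


class ConstantModel(SketchDetectionModel):
    """Toy 'framework' that always returns the same raw (box, score, class) tuples."""

    def load_model(self) -> None:
        # A real subclass would load weights from model_path here.
        self.model = lambda image: [((0, 0, 10, 10), 0.9, 0), ((5, 5, 8, 8), 0.1, 1)]

    def perform_inference(self, image: Any) -> None:
        self._original_predictions = self.model(image)

    def _create_object_prediction_list_from_original_predictions(self) -> None:
        names = {0: "car", 1: "pedestrian"}
        # Keep only raw predictions that meet the confidence threshold.
        self._object_prediction_list = [
            Detection(bbox=box, score=score, category_id=cid, category_name=names[cid])
            for box, score, cid in self._original_predictions
            if score >= self.confidence_threshold
        ]


detections = ConstantModel().predict(image=None)
print([(d.category_name, d.score) for d in detections])  # → [('car', 0.9)]
```

Marking the hooks with `abstractmethod` means the base class itself cannot be instantiated, so a subclass that forgets one of the three methods fails at construction time rather than at first prediction. Whether SAHI enforces this the same way is not shown in this excerpt.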
```python
from __future__ import annotations

from typing import Any

import numpy as np

from sahi.prediction import ObjectPrediction


class DetectionModel:
    """Base class for all detection models in SAHI.

    Subclasses must implement ``load_model``, ``perform_inference``, and
    ``_create_object_prediction_list_from_original_predictions`` to integrate
    a new detection framework. The base class handles device management,
    dependency checking, category remapping, and the public prediction API.
    """

    required_packages: list[str] | None = None

    def __init__(
        self,
        model_path: str | None = None,
        model: Any | None = None,
        config_path: str | None = None,
        device: str | None = None,
        mask_threshold: float = 0.5,
        confidence_threshold: float = 0.3,
        category_mapping: dict | None = None,
        category_remapping: dict | None = None,
        load_at_init: bool = True,
        image_size: int | None = None,
    ) -> None:
        """Initialize an object detection or instance segmentation model.

        Args:
            model_path: str | None
                Path to the model weights file.
            model: Any | None
                A pre-loaded detection model instance.
            config_path: str | None
                Path to the model config file, for frameworks that need one
                (e.g. MMDetection).
            device: str | None
                Torch device, such as "cpu", "mps", "cuda", "cuda:0", "cuda:1".
            mask_threshold: float
                Threshold applied to mask pixel values; should be between 0 and 1.
            confidence_threshold: float
                Predictions with score < confidence_threshold are discarded.
            category_mapping: dict[str, str] | None
                Mapping from category id to category name, e.g. {"1": "pedestrian"}.
            category_remapping: dict[str, int] | None
                Remaps category ids based on category names after inference,
                e.g. {"car": 3}.
            load_at_init: bool
                If True, loads the model during initialization.
            image_size: int | None
                Inference input size.
        """
        self.model_path = model_path
        self.config_path = config_path
        self.model: Any = None
        self.mask_threshold = mask_threshold
        self.confidence_threshold = confidence_threshold
        self.category_mapping = category_mapping
        self.category_remapping = category_remapping
        self.image_size = image_size
        self._original_predictions: Any = None
        self._object_prediction_list_per_image: list[list[ObjectPrediction]] | None = None
        self._batch_images: list[np.ndarray] | None = None
        # (remainder of __init__ omitted in this excerpt)
```
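The constructor only stores `confidence_threshold`, `mask_threshold`, and `category_remapping`; the excerpt does not show where they are applied. The following is a minimal, self-contained sketch of what the docstring says each one means. `Pred` and `postprocess` are hypothetical names, not SAHI functions. The order of the steps, the strict `>` comparison for mask pixels, and keying the remapping on category name (taken from the docstring's `{"car": 3}` example) are assumptions; check the actual implementation before relying on them.

```python
from __future__ import annotations

from dataclasses import dataclass, replace


@dataclass
class Pred:
    """Hypothetical minimal prediction record (not SAHI's ObjectPrediction)."""

    score: float
    category_id: int
    category_name: str
    mask: list[list[float]] | None = None  # soft mask, values in [0, 1]


def postprocess(
    preds: list[Pred],
    confidence_threshold: float = 0.3,
    mask_threshold: float = 0.5,
    category_remapping: dict[str, int] | None = None,
) -> list[Pred]:
    out: list[Pred] = []
    for p in preds:
        # Discard predictions with score < confidence_threshold.
        if p.score < confidence_threshold:
            continue
        # Binarize the soft mask (strict '>' is an assumption of this sketch).
        mask = p.mask
        if mask is not None:
            mask = [[1 if v > mask_threshold else 0 for v in row] for row in mask]
        # Remap the id by category name; names not in the mapping keep their id.
        cid = p.category_id
        if category_remapping and p.category_name in category_remapping:
            cid = category_remapping[p.category_name]
        # dataclasses.replace returns a copy, so the input list is not mutated.
        out.append(replace(p, category_id=cid, mask=mask))
    return out


preds = [
    Pred(score=0.8, category_id=0, category_name="car", mask=[[0.2, 0.7], [0.5, 0.9]]),
    Pred(score=0.2, category_id=1, category_name="pedestrian"),
]
result = postprocess(preds, category_remapping={"car": 3})
print([(p.category_name, p.category_id, p.mask) for p in result])
# → [('car', 3, [[0, 1], [0, 1]])]
```

With the default thresholds, the pedestrian (score 0.2) is dropped, the car keeps its score but takes id 3 from the remapping, and the mask pixel exactly at 0.5 falls to background under the strict comparison.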