Preprocess the image and, optionally, its mask.

preprocess(
    self,
    image: PIL.Image.Image,
    mask: PIL.Image.Image | None = None,
    height: int | None = None,
    width: int | None = None,
    padding_mask_crop: int | None = None,
)
| 879 | ) |
| 880 | |
def preprocess(
    self,
    image: PIL.Image.Image,
    mask: PIL.Image.Image | None = None,
    height: int | None = None,
    width: int | None = None,
    padding_mask_crop: int | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, dict]:
    """
    Preprocess the image and, if given, its mask.

    Args:
        image: Image to preprocess with ``self._image_processor``.
        mask: Optional mask to preprocess with ``self._mask_processor``. When omitted, this call
            behaves exactly like ``self._image_processor.preprocess(image, height=height, width=width)``.
        height: Target height forwarded to both processors (and to ``get_crop_region`` when cropping).
        width: Target width forwarded to both processors (and to ``get_crop_region`` when cropping).
        padding_mask_crop: Passed as ``pad`` to ``self._image_processor.get_crop_region`` to compute
            crop coordinates from ``mask``. When set, image and mask are both preprocessed with those
            crop coordinates and ``resize_mode="fill"``; when None, no crop is applied and
            ``resize_mode="default"`` is used. Requires ``mask``.

    Returns:
        - If ``mask`` is None: the single value returned by ``self._image_processor.preprocess``
          (not a tuple).
        - Otherwise: a 3-tuple ``(processed_image, processed_mask, postprocessing_kwargs)``.
          ``postprocessing_kwargs`` always has the keys ``"crops_coords"``, ``"original_image"``
          and ``"original_mask"``. The original image and mask are stored only when a crop was
          applied; otherwise all three values are None. Presumably this dict is meant to be
          forwarded to ``postprocess`` (TODO: confirm against callers).

    Raises:
        ValueError: If ``padding_mask_crop`` is given without a ``mask``.
    """
    if mask is None and padding_mask_crop is not None:
        raise ValueError("mask must be provided if padding_mask_crop is provided")

    # Without a mask, behave exactly like the regular image processor.
    if mask is None:
        return self._image_processor.preprocess(image, height=height, width=width)

    if padding_mask_crop is None:
        crops_coords = None
        resize_mode = "default"
    else:
        # NOTE(review): width/height are forwarded as given (possibly None); whether
        # get_crop_region accepts None is not visible from this method -- confirm.
        crops_coords = self._image_processor.get_crop_region(mask, width, height, pad=padding_mask_crop)
        resize_mode = "fill"

    # Image and mask go through the same resize/crop settings.
    processed_image = self._image_processor.preprocess(
        image,
        height=height,
        width=width,
        crops_coords=crops_coords,
        resize_mode=resize_mode,
    )
    processed_mask = self._mask_processor.preprocess(
        mask,
        height=height,
        width=width,
        crops_coords=crops_coords,
        resize_mode=resize_mode,
    )

    # Keep the untouched originals only when a crop was applied; otherwise every entry is None.
    cropped = crops_coords is not None
    postprocessing_kwargs = {
        "crops_coords": crops_coords,
        "original_image": image if cropped else None,
        "original_mask": mask if cropped else None,
    }

    return processed_image, processed_mask, postprocessing_kwargs
| 936 | |
| 937 | def postprocess( |
| 938 | self, |