(self, image, channel)
# The value matches the name of the load_image_mask method defined just below.
# NOTE(review): presumably the host framework reads FUNCTION to pick the method to call — confirm.
| 1815 | FUNCTION = "load_image_mask" |
| 1816 | |
def load_image_mask(self, image, channel):
    """Load ``image`` via the parent loader and return one channel as a mask.

    Only the first character of ``channel`` is examined, case-insensitively:

    * ``'A'`` returns the mask produced by the parent's ``load_image``
      unchanged.
    * ``'R'``, ``'G'`` and ``'B'`` select index 0, 1 and 2 along the last
      axis of the loaded image tensor; any other letter falls back to
      index 0.

    If the selected index does not exist on the last axis, an all-zero
    tensor is returned instead. It has the image's shape minus the last
    axis, and the same dtype and device as the image.

    Args:
        image: Passed straight through to the parent's ``load_image``.
        channel: Channel selector; must be non-empty because its first
            element is indexed.

    Returns:
        A 1-tuple holding the mask tensor.
    """
    pixels, alpha_mask = super().load_image(image)
    selector = channel[0].upper()

    # Alpha is served by the mask the parent loader already computed.
    if selector == 'A':
        return (alpha_mask,)

    # Anything that is not G or B (including R) resolves to index 0.
    plane = 0
    if selector == 'G':
        plane = 1
    elif selector == 'B':
        plane = 2

    # Requested index is missing from the last axis: hand back zeros.
    if plane >= pixels.shape[-1]:
        blank = torch.zeros(
            pixels.shape[:-1],
            dtype=pixels.dtype,
            device=pixels.device,
        )
        return (blank,)

    # clone() so the result does not share storage with the loaded image.
    return (pixels[..., plane].clone(),)
| 1835 | |
| 1836 | @classmethod |
| 1837 | def IS_CHANGED(s, image, channel): |
nothing calls this directly
no test coverage detected