| 23 | |
| 24 | |
class ToTensor:
    """Convert a PIL image (or array-like) into a channel-first torch tensor.

    Unlike ``torchvision.transforms.ToTensor``, pixel values are NOT rescaled
    to ``[0, 1]``. The image is read as ``uint8`` and then only cast to
    ``dtype``.

    Args:
        dtype: torch dtype of the returned tensor (default ``torch.float32``).
    """

    def __init__(self, dtype=torch.float32):
        self.dtype = dtype

    def __call__(self, pil_img):
        """Return ``pil_img`` as a ``(C, H, W)`` tensor of ``self.dtype``.

        A 2-D (single-channel) input gets a trailing channel axis first, so
        the result always has a leading channel dimension.
        """
        # NOTE(review): values are forced through uint8, so inputs with a
        # wider range (e.g. 16-bit or float images) are not preserved — confirm
        # callers only pass 8-bit images.
        np_img = np.array(pil_img, dtype=np.uint8)
        if np_img.ndim < 3:
            np_img = np.expand_dims(np_img, axis=-1)
        # HWC -> CHW. np.moveaxis replaces np.rollaxis, which NumPy keeps only
        # for backward compatibility; (2, 0) reproduces rollaxis(a, 2) exactly.
        np_img = np.moveaxis(np_img, 2, 0)
        return torch.from_numpy(np_img).to(dtype=self.dtype)

    def __repr__(self):
        return f"{type(self).__name__}(dtype={self.dtype})"
| 36 | |
| 37 | |
| 38 | _pil_interpolation_to_str = { |