(self, image)
| 314 | |
@torch.no_grad()
def encode(self, image):
    """Feed one image through the wrapped model and return its raw output.

    The image is converted with ``np.array``; a copy processed by
    ``pad_rgb`` supplies the colour channels, while the channel(s) from
    index 3 onward of the original array (divided by 255) are placed in
    front of them along the channel axis.

    NOTE(review): the original local names suggest the input is an
    H x W x 4 RGBA uint8 image and that ``pad_rgb`` yields values already
    in the 0..1 range -- neither is checked here; confirm against callers.

    Args:
        image: Anything ``np.array`` accepts; presumably an RGBA image.

    Returns:
        Whatever ``self.model.model`` returns for the assembled input
        (named ``offset`` in the original code). Gradients are disabled
        by the ``torch.no_grad`` decorator.
    """
    source_arrays = [np.array(image)]

    # Ask the memory manager to load the model before the input is built
    # (same ordering as before: after np.array, ahead of pad_rgb).
    memory_management.load_model_gpu(self.model)

    padded_arrays = list(map(pad_rgb, source_arrays))

    def as_channels_first(arrays):
        # Stack into a batch, convert to float32, move the last axis to dim 1.
        batch = np.stack(arrays, axis=0)
        return torch.from_numpy(batch).float().movedim(-1, 1)

    padded_rgb = as_channels_first(padded_arrays)
    scaled_rgba = as_channels_first(source_arrays) / 255.0

    # Keep everything from channel index 3 on (the alpha plane for RGBA).
    alpha = scaled_rgba[:, 3:]

    # Channel order fed to the model: alpha first, then the padded colour planes.
    model_input = torch.cat([alpha, padded_rgb], dim=1)
    model_input = model_input.to(device=self.load_device, dtype=self.dtype)

    return self.model.model(model_input)
no test coverage detected