| 56 | return model |
| 57 | |
def inference(model: BasicVSRPlusPlusGan | BasicVSR, video: list[Image], device) -> list[Image]:
    """Run ``model`` over a whole clip and return the restored frames.

    The frames are converted to tensors, stacked into a single batch of one
    clip, passed through the model with gradient tracking disabled, and
    converted back to ``uint8`` images.

    Args:
        model: Restoration network; it is called as ``model(inputs=batch)``.
        video: Input frames. Each frame must expose ``.shape``. The list must
            not be empty (``video[0]`` is read, so an empty list raises
            ``IndexError``).
        device: Target device for the input batch. A non-empty string (or
            ``str`` subclass) is converted with ``torch.device``; any other
            value is handed to ``Tensor.to`` unchanged.

    Returns:
        The output of ``image_utils.tensor2img`` for the restored frames;
        it is asserted to have as many frames as ``video`` and the same
        per-frame shape.

    Raises:
        AssertionError: If the output frame count or frame shape differs
            from the input (not checked when Python runs with ``-O``).
    """
    input_frame_count = len(video)
    input_frame_shape = video[0].shape

    # isinstance (rather than an exact type comparison) also accepts str
    # subclasses such as StrEnum members.
    if isinstance(device, str) and device:
        device = torch.device(device)

    with torch.no_grad():
        frame_tensors = image_utils.img2tensor(video, bgr2rgb=False, float32=True)
        batch = torch.stack(frame_tensors, dim=0)
        batch = torch.unsqueeze(batch, dim=0)  # TCHW -> BTCHW
        result = model(inputs=batch.to(device))
        result = torch.squeeze(result, dim=0)  # BTCHW -> TCHW
        restored = list(torch.unbind(result, 0))
        # NOTE(review): assumes tensor2img returns a list even when given a
        # single-frame list - confirm, otherwise the check below fails for
        # one-frame clips.
        output = image_utils.tensor2img(restored, rgb2bgr=False, out_type=np.uint8, min_max=(0, 1))

    output_frame_count = len(output)
    output_frame_shape = output[0].shape
    assert input_frame_count == output_frame_count and input_frame_shape == output_frame_shape, (
        f"model output does not match input: "
        f"frames {output_frame_count} vs {input_frame_count}, "
        f"shape {output_frame_shape} vs {input_frame_shape}"
    )
    return output