Repository: github.com/ladaapp/lada

Function inference

lada/models/basicvsrpp/inference.py:58–73
(model: BasicVSRPlusPlusGan | BasicVSR, video: list[Image], device) -> list[Image]


def inference(model: BasicVSRPlusPlusGan | BasicVSR, video: list[Image], device) -> list[Image]:
    # Relies on module-level imports not shown in this excerpt: torch, numpy as np, image_utils.
    input_frame_count = len(video)
    input_frame_shape = video[0].shape
    # Accept either a torch.device or a device string such as "cuda:0".
    if device and isinstance(device, str):
        device = torch.device(device)
    with torch.no_grad():
        # Each HWC frame becomes a CHW tensor; stacking the frames gives TCHW.
        inputs = torch.stack(image_utils.img2tensor(video, bgr2rgb=False, float32=True), dim=0)
        inputs = torch.unsqueeze(inputs, dim=0)  # TCHW -> BTCHW (batch of one clip)
        result = model(inputs=inputs.to(device))
        result = torch.squeeze(result, dim=0)  # BTCHW -> TCHW
        result = list(torch.unbind(result, 0))  # one CHW tensor per frame
        # Clamp to [0, 1] and convert each frame back to a uint8 HWC image.
        output = image_utils.tensor2img(result, rgb2bgr=False, out_type=np.uint8, min_max=(0, 1))
    output_frame_count = len(output)
    output_frame_shape = output[0].shape
    # The model must return exactly one frame per input frame, at the input resolution.
    assert input_frame_count == output_frame_count and input_frame_shape == output_frame_shape
    return output
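The shape bookkeeping above (HWC frames to TCHW, add a batch axis, run the model, then undo both steps) can be illustrated without PyTorch or lada. The following is a minimal NumPy-only sketch under stated assumptions: `identity_model`, `to_uint8_frame` and `inference_sketch` are hypothetical names, the input frames are assumed to already be float HWC arrays in [0, 1] (whatever scaling lada's `image_utils.img2tensor` applies is not shown on this page), and `to_uint8_frame` mimics a clamp, rescale, ×255 and round conversion modeled on BasicSR's `tensor2img` with `min_max=(0, 1)`. Lada's own `image_utils.tensor2img` may differ in detail.

```python
import numpy as np


# Hypothetical stand-in for BasicVSR/BasicVSR++: it takes a (B, T, C, H, W)
# batch and returns one frame per input frame, so identity is enough to
# exercise the shape bookkeeping.
def identity_model(inputs: np.ndarray) -> np.ndarray:
    return inputs.copy()


def to_uint8_frame(chw: np.ndarray, min_max=(0.0, 1.0)) -> np.ndarray:
    """Assumed tensor2img-style conversion of one CHW float frame to HWC uint8."""
    lo, hi = min_max
    x = np.clip(chw, lo, hi)          # clamp to the expected value range
    x = (x - lo) / (hi - lo)          # rescale to [0, 1]
    hwc = np.transpose(x, (1, 2, 0))  # CHW -> HWC
    return (hwc * 255.0).round().astype(np.uint8)


def inference_sketch(model, video: list[np.ndarray]) -> list[np.ndarray]:
    input_frame_count = len(video)
    input_frame_shape = video[0].shape
    # Each HWC frame becomes CHW; stacking the frames gives TCHW.
    frames = np.stack([np.transpose(f, (2, 0, 1)).astype(np.float32) for f in video], axis=0)
    batch = frames[np.newaxis, ...]      # TCHW -> BTCHW with B = 1
    result = model(batch)
    result = np.squeeze(result, axis=0)  # BTCHW -> TCHW (only axis 0 is removed)
    # Iterating over the T axis yields one CHW frame at a time (like torch.unbind).
    output = [to_uint8_frame(frame) for frame in result]
    assert len(output) == input_frame_count and output[0].shape == input_frame_shape
    return output


video = [np.full((4, 6, 3), 0.5, dtype=np.float32) for _ in range(3)]
restored = inference_sketch(identity_model, video)
print(len(restored), restored[0].shape, restored[0].dtype)
```

Squeezing only axis 0 mirrors `torch.squeeze(result, dim=0)`: a single-frame clip keeps its T axis, so the per-frame split still yields exactly one frame and the final length/shape assertion holds.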

Callers

No direct callers detected.

Calls (2)

stack (method) · 0.80
device (method) · 0.45

Tested by

No test coverage detected.