| 115 | |
| 116 | |
| 117 | class Predictor(object): |
def __init__(
    self,
    model,
    exp,
    trt_file=None,
    decoder=None,
    device=torch.device("cpu"),
    fp16=False
):
    """Set up the predictor from a model and an experiment config.

    Args:
        model: PyTorch model, called as ``model(images)`` at inference
            time. It is replaced by the TensorRT module when ``trt_file``
            is given.
        exp: Experiment config. ``num_classes``, ``test_conf``,
            ``nmsthre`` and ``test_size`` are read from it.
        trt_file: Optional path to a saved ``torch2trt.TRTModule`` state
            dict. When given, the loaded TensorRT module replaces ``model``.
        decoder: Optional callable that ``inference`` applies to the raw
            model outputs before post-processing.
        device: Device on which input tensors are created.
        fp16: If True, input tensors are cast to half precision.
    """
    self.model = model
    self.decoder = decoder
    self.num_classes = exp.num_classes
    self.confthre = exp.test_conf
    self.nmsthre = exp.nmsthre
    self.test_size = exp.test_size
    self.device = device
    self.fp16 = fp16
    if trt_file is not None:
        # Optional dependency, imported only on the TensorRT path.
        from torch2trt import TRTModule

        model_trt = TRTModule()
        model_trt.load_state_dict(torch.load(trt_file))

        # One dummy forward pass through the PyTorch model before it is
        # swapped out. NOTE(review): presumably this initialises state that
        # `decoder` relies on later -- confirm against the model's head.
        x = torch.ones((1, 3, exp.test_size[0], exp.test_size[1]), device=device)
        if fp16:
            # inference() feeds half-precision inputs when fp16 is set, so
            # the model is expected to be half too; a float32 input here
            # would raise a dtype mismatch.
            x = x.half()
        # Warm-up only: no autograd graph is needed.
        with torch.no_grad():
            self.model(x)
        self.model = model_trt
    # Per-channel mean/std passed to `preproc` in inference().
    self.rgb_means = (0.485, 0.456, 0.406)
    self.std = (0.229, 0.224, 0.225)
| 146 | |
| 147 | def inference(self, img, timer): |
| 148 | img_info = {"id": 0} |
| 149 | if isinstance(img, str): |
| 150 | img_info["file_name"] = osp.basename(img) |
| 151 | img = cv2.imread(img) |
| 152 | else: |
| 153 | img_info["file_name"] = None |
| 154 | |
| 155 | height, width = img.shape[:2] |
| 156 | img_info["height"] = height |
| 157 | img_info["width"] = width |
| 158 | img_info["raw_img"] = img |
| 159 | |
| 160 | img, ratio = preproc(img, self.test_size, self.rgb_means, self.std) |
| 161 | img_info["ratio"] = ratio |
| 162 | img = torch.from_numpy(img).unsqueeze(0).float().to(self.device) |
| 163 | if self.fp16: |
| 164 | img = img.half() # to FP16 |
| 165 | |
| 166 | with torch.no_grad(): |
| 167 | timer.tic() |
| 168 | outputs = self.model(img) |
| 169 | if self.decoder is not None: |
| 170 | outputs = self.decoder(outputs, dtype=outputs.type()) |
| 171 | outputs = postprocess( |
| 172 | outputs, self.num_classes, self.confthre, self.nmsthre |
| 173 | ) |
| 174 | #logger.info("Infer time: {:.4f}s".format(time.time() - t0)) |