对传入的图像进行预测,支持图像地址,opencv 读取图片,偏慢 :param img_path: 图像地址 :param is_numpy: :return:
(self, img_path: str, is_output_polygon=False, short_size: int = 1024)
| 68 | self.transform = get_transforms(self.transform) |
| 69 | |
def predict(self, img_path: str, is_output_polygon=False, short_size: int = 1024):
    """
    Run detection on a single image file loaded with OpenCV (relatively slow).

    :param img_path: path of the image file on disk
    :param is_output_polygon: forwarded to ``self.post_process``; it also
        selects how empty boxes are filtered below (per-box sequence when
        true, single array indexed by box when false)
    :param short_size: forwarded to ``resize_image`` (presumably the target
        length of the shorter image side -- confirm against ``resize_image``)
    :return: tuple ``(pred_map, box_list, score_list, elapsed)`` where
        ``pred_map`` is ``preds[0, 0, :, :]`` converted to a numpy array,
        ``box_list`` / ``score_list`` are the post-processed results of the
        first (only) batch item with empty boxes removed, and ``elapsed`` is
        the wall-clock time in seconds spent in the model forward pass plus
        post-processing
    :raises AssertionError: if ``img_path`` does not exist
    :raises ValueError: if OpenCV cannot decode the file
    """
    # Kept as an assert (same message) so the exception type seen by existing
    # callers does not change.
    assert os.path.exists(img_path), "file is not exists"

    # imread flag: 0 loads as grayscale, 1 loads as 3-channel colour.
    read_flag = 0 if self.img_mode == "GRAY" else 1
    img = cv2.imread(img_path, read_flag)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising;
        # fail here with a clear message rather than deeper in cvtColor/shape.
        raise ValueError(f"failed to read image: {img_path}")

    if self.img_mode == "RGB":
        # OpenCV decodes colour images in BGR channel order.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Remember the original size; it is handed to post-processing via `batch`.
    h, w = img.shape[:2]
    img = resize_image(img, short_size)

    # Apply the configured transforms, then add a leading batch dimension
    # (the original comment describes the result as (1, img_channel, h, w)).
    tensor = self.transform(img)
    tensor = tensor.unsqueeze_(0)

    batch = {"shape": [(h, w)]}
    with paddle.no_grad():
        start = time.time()
        preds = self.model(tensor)
        box_list, score_list = self.post_process(
            batch, preds, is_output_polygon=is_output_polygon
        )
        # Only one image was fed in, so take the first batch entry.
        box_list, score_list = box_list[0], score_list[0]

        if len(box_list) > 0:
            if is_output_polygon:
                # Keep only polygons whose coordinates sum to a positive value.
                keep = [poly.sum() > 0 for poly in box_list]
                box_list = [b for b, k in zip(box_list, keep) if k]
                score_list = [s for s, k in zip(score_list, keep) if k]
            else:
                # Drop boxes that are all zeros (per-box coordinate sum <= 0).
                keep = box_list.reshape(box_list.shape[0], -1).sum(axis=1) > 0
                box_list, score_list = box_list[keep], score_list[keep]
        else:
            box_list, score_list = [], []

        elapsed = time.time() - start

    return preds[0, 0, :, :].detach().cpu().numpy(), box_list, score_list, elapsed
| 109 | |
| 110 | |
| 111 | def save_depoly(net, input, save_path): |
nothing calls this directly
no test coverage detected