Args: repeats (int): number of times to run prediction. Returns: result (dict): includes 'boxes': np.ndarray of shape [N, 6], where N is the number of boxes and each matrix row is [class, score, x_min, y_min, x_max, y_max]. MaskRCNN's result also includes 'masks': np.ndarray of shape [N, im_h, im_w].
(self, repeats=1, run_benchmark=False)
| 184 | return filter_res |
| 185 | |
def predict(self, repeats=1, run_benchmark=False):
    '''
    Run the inference predictor and collect its detection outputs.

    Args:
        repeats (int): how many times ``self.predictor.run()`` is invoked.
            Outputs are re-read after every run, so only the results of the
            final run are returned.
        run_benchmark (bool): when True, only run the predictor ``repeats``
            times (synchronizing the device after each run) without reading
            any outputs; the returned dict then holds the placeholders
            described below.

    Returns:
        result (dict): with keys
            'boxes': np.ndarray of shape [N, 6], N being the number of boxes,
                each row being [class, score, x_min, y_min, x_max, y_max];
                None if no output was read.
            'masks': np.ndarray of shape [N, im_h, im_w] for mask models such
                as MaskRCNN (when ``self.pred_config.mask`` is set); otherwise
                None.
            'boxes_num': np.ndarray taken from the second model output when
                one exists, else np.array([len(boxes)]); np.array([0]) if no
                output was read.
    '''
    # Placeholders returned when no output gets copied back from the predictor.
    boxes, masks, boxes_num = None, None, np.array([0])

    if run_benchmark:
        for _ in range(repeats):
            self.predictor.run()
            # presumably waits for the device so each timed run is complete
            if self.device != 'GPU':
                paddle.device.synchronize(device=self.device.lower())
            else:
                paddle.device.cuda.synchronize()
        return dict(boxes=boxes, masks=masks, boxes_num=boxes_num)

    for _ in range(repeats):
        self.predictor.run()
        names = self.predictor.get_output_names()
        boxes = self.predictor.get_output_handle(names[0]).copy_to_cpu()
        if len(names) > 1:
            boxes_num = self.predictor.get_output_handle(names[1]).copy_to_cpu()
        else:
            # Some exported models do not expose a 'bbox_num' output, so the
            # count is derived from the boxes array itself.
            boxes_num = np.array([len(boxes)])
        if self.pred_config.mask:
            masks = self.predictor.get_output_handle(names[2]).copy_to_cpu()

    return dict(boxes=boxes, masks=masks, boxes_num=boxes_num)
| 227 | |
| 228 | def merge_batch_result(self, batch_result): |
| 229 | if len(batch_result) == 1: |
no test coverage detected