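COCO average precision (AP) evaluation. Runs inference over the test dataset and evaluates the results with the COCO API. NOTE: This function puts the model into evaluation mode (training mode is set to False); save its state beforehand if needed. Args: model: the model to evaluate. Returns: ap50_95 (float): COCO AP at IoU=0.50:0.95; ap50 (float): COCO AP at IoU=0.50; summary (str): summary information of the evaluation.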
(
self,
model,
distributed=False,
half=False,
trt_file=None,
decoder=None,
test_size=None,
)
| 50 | self.testdev = testdev |
| 51 | |
| 52 | def evaluate( |
| 53 | self, |
| 54 | model, |
| 55 | distributed=False, |
| 56 | half=False, |
| 57 | trt_file=None, |
| 58 | decoder=None, |
| 59 | test_size=None, |
| 60 | ): |
| 61 | """ |
| 62 | COCO average precision (AP) Evaluation. Iterate inference on the test dataset |
| 63 | and the results are evaluated by COCO API. |
| 64 | |
| 65 | NOTE: This function will change training mode to False, please save states if needed. |
| 66 | |
| 67 | Args: |
| 68 | model : model to evaluate. |
| 69 | |
| 70 | Returns: |
| 71 | ap50_95 (float) : COCO AP of IoU=50:95 |
| 72 | ap50 (float) : COCO AP of IoU=50 |
| 73 | summary (sr): summary info of evaluation. |
| 74 | """ |
| 75 | # TODO half to amp_test |
| 76 | tensor_type = torch.cuda.HalfTensor if half else torch.cuda.FloatTensor |
| 77 | model = model.eval() |
| 78 | if half: |
| 79 | model = model.half() |
| 80 | ids = [] |
| 81 | data_list = [] |
| 82 | progress_bar = tqdm if is_main_process() else iter |
| 83 | |
| 84 | inference_time = 0 |
| 85 | nms_time = 0 |
| 86 | n_samples = len(self.dataloader) - 1 |
| 87 | |
| 88 | if trt_file is not None: |
| 89 | from torch2trt import TRTModule |
| 90 | |
| 91 | model_trt = TRTModule() |
| 92 | model_trt.load_state_dict(torch.load(trt_file)) |
| 93 | |
| 94 | x = torch.ones(1, 3, test_size[0], test_size[1]).cuda() |
| 95 | model(x) |
| 96 | model = model_trt |
| 97 | |
| 98 | for cur_iter, (imgs, _, info_imgs, ids) in enumerate( |
| 99 | progress_bar(self.dataloader) |
| 100 | ): |
| 101 | with torch.no_grad(): |
| 102 | imgs = imgs.type(tensor_type) |
| 103 | |
| 104 | # skip the the last iters since batchsize might be not enough for batch inference |
| 105 | is_time_record = cur_iter < len(self.dataloader) - 1 |
| 106 | if is_time_record: |
| 107 | start = time.time() |
| 108 | |
| 109 | outputs = model(imgs) |
no test coverage detected