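COCO average precision (AP) evaluation. Runs inference over the test dataset and evaluates the results with the COCO API. NOTE: This function switches the model to evaluation mode (training = False); save its state beforehand if needed. Args: model: the model to evaluate. Returns: ap50_95 (float): COCO AP at IoU=0.50:0.95; ap50 (float): COCO AP at IoU=0.50; summary (str): summary of the evaluation.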
(
self,
model,
distributed=False,
half=False,
trt_file=None,
decoder=None,
test_size=None,
result_folder=None
)
| 77 | self.args = args |
| 78 | |
| 79 | def evaluate( |
| 80 | self, |
| 81 | model, |
| 82 | distributed=False, |
| 83 | half=False, |
| 84 | trt_file=None, |
| 85 | decoder=None, |
| 86 | test_size=None, |
| 87 | result_folder=None |
| 88 | ): |
| 89 | """ |
| 90 | COCO average precision (AP) Evaluation. Iterate inference on the test dataset |
| 91 | and the results are evaluated by COCO API. |
| 92 | |
| 93 | NOTE: This function will change training mode to False, please save states if needed. |
| 94 | |
| 95 | Args: |
| 96 | model : model to evaluate. |
| 97 | |
| 98 | Returns: |
| 99 | ap50_95 (float) : COCO AP of IoU=50:95 |
| 100 | ap50 (float) : COCO AP of IoU=50 |
| 101 | summary (sr): summary info of evaluation. |
| 102 | """ |
| 103 | # TODO half to amp_test |
| 104 | tensor_type = torch.cuda.HalfTensor if half else torch.cuda.FloatTensor |
| 105 | model = model.eval() |
| 106 | if half: |
| 107 | model = model.half() |
| 108 | ids = [] |
| 109 | data_list = [] |
| 110 | results = [] |
| 111 | video_names = defaultdict() |
| 112 | progress_bar = tqdm if is_main_process() else iter |
| 113 | |
| 114 | inference_time = 0 |
| 115 | track_time = 0 |
| 116 | n_samples = len(self.dataloader) - 1 |
| 117 | |
| 118 | if trt_file is not None: |
| 119 | from torch2trt import TRTModule |
| 120 | |
| 121 | model_trt = TRTModule() |
| 122 | model_trt.load_state_dict(torch.load(trt_file)) |
| 123 | |
| 124 | x = torch.ones(1, 3, test_size[0], test_size[1]).cuda() |
| 125 | model(x) |
| 126 | model = model_trt |
| 127 | |
| 128 | tracker = BYTETracker(self.args) |
| 129 | ori_thresh = self.args.track_thresh |
| 130 | for cur_iter, (imgs, _, info_imgs, ids) in enumerate( |
| 131 | progress_bar(self.dataloader) |
| 132 | ): |
| 133 | with torch.no_grad(): |
| 134 | # init tracker |
| 135 | frame_id = info_imgs[2].item() |
| 136 | video_id = info_imgs[3].item() |
no test coverage detected