github.com/FoundationVision/ByteTrack

Method evaluate

yolox/evaluators/coco_evaluator.py, lines 52–134

COCO average precision (AP) evaluation: runs inference over the test dataset and scores the results with the COCO API. NOTE: this function switches the model out of training mode (training=False), so save its state first if needed. Args: model, the model to evaluate. Returns: ap50_95 (COCO AP over IoU=0.50:0.95), ap50 (COCO AP at IoU=0.50), and a summary string.
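In the COCO protocol, the "IoU=50:95" figure is the mean of AP over ten IoU thresholds (0.50, 0.55, ..., 0.95), while AP50 is AP at the single threshold 0.50. The stdlib-only sketch below illustrates just that averaging step; the helper names `ap50_95` and `ap50` and the per-threshold values in the usage block are hypothetical. In the evaluator, the real per-threshold numbers come from the COCO API, not from code like this.

```python
from statistics import fmean

# COCO's ten IoU thresholds: 0.50, 0.55, ..., 0.95 (rounded to avoid float drift)
IOU_THRESHOLDS = [round(0.50 + 0.05 * i, 2) for i in range(10)]


def ap50_95(ap_per_threshold: dict[float, float]) -> float:
    """Hypothetical helper: mean AP over the ten COCO IoU thresholds."""
    missing = [t for t in IOU_THRESHOLDS if t not in ap_per_threshold]
    if missing:
        raise ValueError(f"missing AP for IoU thresholds: {missing}")
    return fmean(ap_per_threshold[t] for t in IOU_THRESHOLDS)


def ap50(ap_per_threshold: dict[float, float]) -> float:
    """Hypothetical helper: AP at the single IoU threshold 0.50."""
    return ap_per_threshold[0.50]


if __name__ == "__main__":
    # Illustrative values only: AP falling linearly as the IoU threshold tightens.
    example = {t: 0.9 - (t - 0.50) for t in IOU_THRESHOLDS}
    print(ap50(example), ap50_95(example))
```

Because AP usually drops as the IoU threshold tightens, AP50:95 is typically well below AP50; it rewards localization quality, not just detection.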

(
        self,
        model,
        distributed=False,
        half=False,
        trt_file=None,
        decoder=None,
        test_size=None,
    )
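Because `evaluate` calls `model.eval()` and, per its docstring NOTE, leaves the model out of training mode, a caller that resumes training afterwards has to restore the mode itself. A minimal sketch, assuming only the `training` attribute and `train(mode)` method that `torch.nn.Module` provides; the helper `preserve_training_mode` is hypothetical and not part of ByteTrack or YOLOX.

```python
from contextlib import contextmanager
from typing import Iterator, Protocol


class Trainable(Protocol):
    """Anything exposing the train/eval switch of torch.nn.Module."""

    training: bool

    def train(self, mode: bool = True) -> "Trainable": ...


@contextmanager
def preserve_training_mode(model: Trainable) -> Iterator[Trainable]:
    """Hypothetical helper: restore model.training when the with-block exits."""
    was_training = model.training
    try:
        yield model
    finally:
        # Runs even if evaluation raises, so the caller's mode is always restored.
        # torch.nn.Module.train(mode) applies the flag to every submodule.
        model.train(was_training)


# Usage (hypothetical evaluator object):
#     with preserve_training_mode(model):
#         ap50_95, ap50, summary = evaluator.evaluate(model)
```

This restores only the train/eval flag. `Module.half()` casts parameters in place, so with `half=True` the caller's model stays in FP16 unless it is cast back or reloaded. Submodules that were deliberately kept in eval mode (for example frozen BatchNorm) are switched back to training by `train(True)`, so record per-module flags if that matters.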

Source (file lines 50–109):

50 self.testdev = testdev
51
52 def evaluate(
53 self,
54 model,
55 distributed=False,
56 half=False,
57 trt_file=None,
58 decoder=None,
59 test_size=None,
60 ):
61 """
62 COCO average precision (AP) Evaluation. Iterate inference on the test dataset
63 and the results are evaluated by COCO API.
64
65 NOTE: This function will change training mode to False, please save states if needed.
66
67 Args:
68 model : model to evaluate.
69
70 Returns:
71 ap50_95 (float) : COCO AP of IoU=50:95
72 ap50 (float) : COCO AP of IoU=50
73 summary (sr): summary info of evaluation.
74 """
75 # TODO half to amp_test
76 tensor_type = torch.cuda.HalfTensor if half else torch.cuda.FloatTensor
77 model = model.eval()
78 if half:
79 model = model.half()
80 ids = []
81 data_list = []
82 progress_bar = tqdm if is_main_process() else iter
83
84 inference_time = 0
85 nms_time = 0
86 n_samples = len(self.dataloader) - 1
87
88 if trt_file is not None:
89 from torch2trt import TRTModule
90
91 model_trt = TRTModule()
92 model_trt.load_state_dict(torch.load(trt_file))
93
94 x = torch.ones(1, 3, test_size[0], test_size[1]).cuda()
95 model(x)
96 model = model_trt
97
98 for cur_iter, (imgs, _, info_imgs, ids) in enumerate(
99 progress_bar(self.dataloader)
100 ):
101 with torch.no_grad():
102 imgs = imgs.type(tensor_type)
103
104 # skip the the last iters since batchsize might be not enough for batch inference
105 is_time_record = cur_iter < len(self.dataloader) - 1
106 if is_time_record:
107 start = time.time()
108
109 outputs = model(imgs)

Callers (1)

evaluate_prediction (method) · 0.45

Calls (8)

evaluate_prediction (method) · 0.95
is_main_process (function) · 0.90
time_synchronized (function) · 0.90
postprocess (function) · 0.90
gather (function) · 0.90
synchronize (function) · 0.90
eval (method) · 0.45

Tested by

no test coverage detected