MCPcopy Index your code
hub / github.com/FoundationVision/ByteTrack / Predictor

Class Predictor

tools/demo_track.py:117–175  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

115
116
117class Predictor(object):
118 def __init__(
119 self,
120 model,
121 exp,
122 trt_file=None,
123 decoder=None,
124 device=torch.device("cpu"),
125 fp16=False
126 ):
127 self.model = model
128 self.decoder = decoder
129 self.num_classes = exp.num_classes
130 self.confthre = exp.test_conf
131 self.nmsthre = exp.nmsthre
132 self.test_size = exp.test_size
133 self.device = device
134 self.fp16 = fp16
135 if trt_file is not None:
136 from torch2trt import TRTModule
137
138 model_trt = TRTModule()
139 model_trt.load_state_dict(torch.load(trt_file))
140
141 x = torch.ones((1, 3, exp.test_size[0], exp.test_size[1]), device=device)
142 self.model(x)
143 self.model = model_trt
144 self.rgb_means = (0.485, 0.456, 0.406)
145 self.std = (0.229, 0.224, 0.225)
146
147 def inference(self, img, timer):
148 img_info = {"id": 0}
149 if isinstance(img, str):
150 img_info["file_name"] = osp.basename(img)
151 img = cv2.imread(img)
152 else:
153 img_info["file_name"] = None
154
155 height, width = img.shape[:2]
156 img_info["height"] = height
157 img_info["width"] = width
158 img_info["raw_img"] = img
159
160 img, ratio = preproc(img, self.test_size, self.rgb_means, self.std)
161 img_info["ratio"] = ratio
162 img = torch.from_numpy(img).unsqueeze(0).float().to(self.device)
163 if self.fp16:
164 img = img.half() # to FP16
165
166 with torch.no_grad():
167 timer.tic()
168 outputs = self.model(img)
169 if self.decoder is not None:
170 outputs = self.decoder(outputs, dtype=outputs.type())
171 outputs = postprocess(
172 outputs, self.num_classes, self.confthre, self.nmsthre
173 )
174 #logger.info("Infer time: {:.4f}s".format(time.time() - t0))

Callers 1

mainFunction · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected