MCPcopy Index your code
hub / github.com/FoundationVision/ByteTrack / imageflow_demo

Function imageflow_demo

tools/demo_track.py:236–298  ·  view source on GitHub ↗
(predictor, vis_folder, current_time, args)

Source from the content-addressed store, hash-verified

234
235
def imageflow_demo(predictor, vis_folder, current_time, args, test_size=None):
    """Track objects frame by frame in a video file or camera stream.

    The source is ``args.path`` when ``args.demo == "video"`` and camera
    index ``args.camid`` otherwise. Every frame is passed to
    ``predictor.inference``. When it yields detections, they are fed to a
    ``BYTETracker`` and the kept tracks are drawn with ``plot_tracking``.
    If ``args.save_result`` is true, the frames are written to
    ``<vis_folder>/<timestamp>/<input file name>`` (``camera.mp4`` for a
    camera source). Processing stops at the end of the stream, on a failed
    read, or when ``cv2.waitKey`` reports Esc / ``q`` / ``Q``.

    Args:
        predictor: object whose ``inference(frame, timer)`` returns
            ``(outputs, img_info)``. ``img_info`` must provide the keys
            ``"height"``, ``"width"`` and ``"raw_img"``.
        vis_folder: root directory under which the output folder is created.
        current_time: value accepted by ``time.strftime`` (e.g. a
            ``time.struct_time``), used to name the output folder.
        args: CLI namespace. Read here: ``demo``, ``path``, ``camid``,
            ``aspect_ratio_thresh``, ``min_box_area``, ``save_result``.
            It is also handed to ``BYTETracker``.
        test_size: forwarded as the third argument of ``tracker.update``.
            When omitted, falls back to ``exp.test_size``, where ``exp`` is
            looked up as a module-level global (the original behaviour).
    """
    if test_size is None:
        # `exp` is not a parameter; it is resolved from module globals.
        test_size = exp.test_size

    cap = cv2.VideoCapture(args.path if args.demo == "video" else args.camid)
    width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)  # float
    height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)  # float
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps > 0:
        # VideoCapture.get returns 0 when the backend cannot report a
        # property; give the writer a positive rate instead. 30 matches the
        # tracker's frame_rate below. `not fps > 0` also catches NaN.
        fps = 30.0
    timestamp = time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
    save_folder = osp.join(vis_folder, timestamp)
    os.makedirs(save_folder, exist_ok=True)
    if args.demo == "video":
        # basename() honours the platform separator. path.split("/")[-1]
        # returned a whole Windows path, which osp.join treats as absolute,
        # so the writer would have targeted the input file, not save_folder.
        save_path = osp.join(save_folder, osp.basename(args.path))
    else:
        save_path = osp.join(save_folder, "camera.mp4")
    logger.info(f"video save_path is {save_path}")
    vid_writer = cv2.VideoWriter(
        save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (int(width), int(height))
    )
    # NOTE(review): the tracker is configured for 30 FPS regardless of the
    # source rate read above -- confirm this is intentional.
    tracker = BYTETracker(args, frame_rate=30)
    timer = Timer()
    frame_id = 0
    # One text row per kept track.
    # NOTE(review): `results` is not consumed in the code shown here; confirm
    # that its consumer (e.g. a results writer after the loop) still exists.
    results = []
    try:
        while True:
            if frame_id % 20 == 0:
                logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1. / max(1e-5, timer.average_time)))
            ret_val, frame = cap.read()
            if not ret_val:
                break  # end of stream or read failure
            outputs, img_info = predictor.inference(frame, timer)
            if outputs[0] is not None:
                online_targets = tracker.update(
                    outputs[0], [img_info['height'], img_info['width']], test_size
                )
                online_tlwhs = []
                online_ids = []
                for t in online_targets:
                    tlwh = t.tlwh
                    tid = t.track_id
                    # Reject tracks whose tlwh[2] / tlwh[3] ratio exceeds
                    # aspect_ratio_thresh or whose tlwh[2] * tlwh[3] area is
                    # not above min_box_area.
                    vertical = tlwh[2] / tlwh[3] > args.aspect_ratio_thresh
                    if tlwh[2] * tlwh[3] > args.min_box_area and not vertical:
                        online_tlwhs.append(tlwh)
                        online_ids.append(tid)
                        # NOTE(review): frame_id is 0-based here but drawn as
                        # frame_id + 1 below. The row layout looks like the
                        # MOTChallenge text format (1-based frames) -- confirm.
                        results.append(
                            f"{frame_id},{tid},{tlwh[0]:.2f},{tlwh[1]:.2f},{tlwh[2]:.2f},{tlwh[3]:.2f},{t.score:.2f},-1,-1,-1\n"
                        )
                timer.toc()
                online_im = plot_tracking(
                    img_info['raw_img'],
                    online_tlwhs,
                    online_ids,
                    frame_id=frame_id + 1,
                    # Same zero guard as the progress log above.
                    fps=1. / max(1e-5, timer.average_time),
                )
            else:
                timer.toc()
                online_im = img_info['raw_img']
            if args.save_result:
                vid_writer.write(online_im)
            ch = cv2.waitKey(1)
            if ch in (27, ord("q"), ord("Q")):  # Esc / q / Q
                break
            frame_id += 1
    finally:
        # Free the capture device and close the writer even if inference or
        # tracking raises.
        cap.release()
        vid_writer.release()

Callers 1

main (function) · 0.70

Calls 7

update (method) · 0.95
toc (method) · 0.95
BYTETracker (class) · 0.90
Timer (class) · 0.90
plot_tracking (function) · 0.90
write (method) · 0.80
inference (method) · 0.45

Tested by

no test coverage detected