()
| 1193 | |
| 1194 | |
| 1195 | def main(): |
| 1196 | if FLAGS.use_fd_format: |
| 1197 | deploy_file = os.path.join(FLAGS.model_dir, 'inference.yml') |
| 1198 | else: |
| 1199 | deploy_file = os.path.join(FLAGS.model_dir, 'infer_cfg.yml') |
| 1200 | with open(deploy_file) as f: |
| 1201 | yml_conf = yaml.safe_load(f) |
| 1202 | arch = yml_conf['arch'] |
| 1203 | detector_func = 'Detector' |
| 1204 | if arch == 'SOLOv2': |
| 1205 | detector_func = 'DetectorSOLOv2' |
| 1206 | elif arch == 'PicoDet': |
| 1207 | detector_func = 'DetectorPicoDet' |
| 1208 | elif arch == "CLRNet": |
| 1209 | detector_func = 'DetectorCLRNet' |
| 1210 | |
| 1211 | detector = eval(detector_func)( |
| 1212 | FLAGS.model_dir, |
| 1213 | device=FLAGS.device, |
| 1214 | run_mode=FLAGS.run_mode, |
| 1215 | batch_size=FLAGS.batch_size, |
| 1216 | trt_min_shape=FLAGS.trt_min_shape, |
| 1217 | trt_max_shape=FLAGS.trt_max_shape, |
| 1218 | trt_opt_shape=FLAGS.trt_opt_shape, |
| 1219 | trt_calib_mode=FLAGS.trt_calib_mode, |
| 1220 | cpu_threads=FLAGS.cpu_threads, |
| 1221 | enable_mkldnn=FLAGS.enable_mkldnn, |
| 1222 | enable_mkldnn_bfloat16=FLAGS.enable_mkldnn_bfloat16, |
| 1223 | threshold=FLAGS.threshold, |
| 1224 | output_dir=FLAGS.output_dir, |
| 1225 | use_fd_format=FLAGS.use_fd_format) |
| 1226 | |
| 1227 | # predict from video file or camera video stream |
| 1228 | if FLAGS.video_file is not None or FLAGS.camera_id != -1: |
| 1229 | detector.predict_video(FLAGS.video_file, FLAGS.camera_id) |
| 1230 | else: |
| 1231 | # predict from image |
| 1232 | if FLAGS.image_dir is None and FLAGS.image_file is not None: |
| 1233 | assert FLAGS.batch_size == 1, "batch_size should be 1, when image_file is not None" |
| 1234 | img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file) |
| 1235 | if FLAGS.slice_infer: |
| 1236 | detector.predict_image_slice( |
| 1237 | img_list, |
| 1238 | FLAGS.slice_size, |
| 1239 | FLAGS.overlap_ratio, |
| 1240 | FLAGS.combine_method, |
| 1241 | FLAGS.match_threshold, |
| 1242 | FLAGS.match_metric, |
| 1243 | visual=FLAGS.save_images, |
| 1244 | save_results=FLAGS.save_results) |
| 1245 | else: |
| 1246 | detector.predict_image( |
| 1247 | img_list, |
| 1248 | FLAGS.run_benchmark, |
| 1249 | repeats=100, |
| 1250 | visual=FLAGS.save_images, |
| 1251 | save_results=FLAGS.save_results) |
| 1252 | if not FLAGS.run_benchmark: |
no test coverage detected