MCPcopy Index your code
hub / github.com/FoundationVision/ByteTrack / main

Function main

tools/trt.py:31–70  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

29
@logger.catch
def main():
    """Convert a trained checkpoint into a TensorRT module and engine file.

    Steps, in order:

    1. Parse the command-line arguments and build the experiment and its
       model.
    2. Load the checkpoint weights (``args.ckpt``, or ``best_ckpt.pth.tar``
       inside the experiment output directory when no checkpoint is given).
    3. Convert the model with ``torch2trt`` in FP16 mode.
    4. Save the converted module's state dict as ``model_trt.pth`` and the
       serialized engine as ``model_trt.engine`` in the experiment output
       directory.
    5. Copy the engine to ``deploy/TensorRT/cpp/model_trt.engine`` (relative
       to the current working directory) for C++ inference.

    Requires a CUDA device: both the model and the example input are moved
    to the GPU before conversion.
    """
    args = make_parser().parse_args()
    exp = get_exp(args.exp_file, args.name)
    if not args.experiment_name:
        # Fall back to the experiment's own name when none was supplied.
        args.experiment_name = exp.exp_name

    model = exp.get_model()
    # NOTE: despite its name, ``file_name`` is the output *directory*.
    file_name = os.path.join(exp.output_dir, args.experiment_name)
    os.makedirs(file_name, exist_ok=True)
    if args.ckpt is None:
        ckpt_file = os.path.join(file_name, "best_ckpt.pth.tar")
    else:
        ckpt_file = args.ckpt

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only pass checkpoints from a trusted source.
    ckpt = torch.load(ckpt_file, map_location="cpu")
    # Load the model state dict.
    model.load_state_dict(ckpt["model"])
    logger.info("loaded checkpoint done.")
    model.eval()
    model.cuda()
    # NOTE(review): presumably this leaves output decoding to whoever
    # consumes the engine -- confirm against the head implementation.
    model.head.decode_in_inference = False
    # Example input (batch of 1, 3 channels, experiment test size) handed to
    # torch2trt for the conversion.
    x = torch.ones(1, 3, exp.test_size[0], exp.test_size[1]).cuda()
    model_trt = torch2trt(
        model,
        [x],
        fp16_mode=True,
        log_level=trt.Logger.INFO,
        max_workspace_size=(1 << 32),  # 2**32, i.e. 4 GiB if counted in bytes
    )
    torch.save(model_trt.state_dict(), os.path.join(file_name, "model_trt.pth"))
    logger.info("Converted TensorRT model done.")
    engine_file = os.path.join(file_name, "model_trt.engine")
    engine_file_demo = os.path.join("deploy", "TensorRT", "cpp", "model_trt.engine")
    with open(engine_file, "wb") as f:
        f.write(model_trt.engine.serialize())

    # The demo path is relative to the current working directory. Create it
    # first so the copy does not fail with FileNotFoundError (after the
    # expensive conversion has already finished) when the script is started
    # from somewhere other than the repository root.
    os.makedirs(os.path.dirname(engine_file_demo), exist_ok=True)
    shutil.copyfile(engine_file, engine_file_demo)

    logger.info("Converted TensorRT model engine file is saved for C++ inference.")
71
72
73if __name__ == "__main__":

Callers 1

trt.py — File · 0.70

Calls 5

get_exp — Function · 0.90
write — Method · 0.80
make_parser — Function · 0.70
get_model — Method · 0.45
eval — Method · 0.45

Tested by

no test coverage detected