MCPcopy Index your code
hub / github.com/FoundationVision/ByteTrack / main

Function main

tools/export_onnx.py:50–98  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

48
@logger.catch
def main():
    """Export an experiment's model to ONNX, optionally simplifying it with onnxsim.

    Parses CLI arguments from ``make_parser()`` and builds the model described by
    the experiment file. It then loads weights from ``args.ckpt``. When that is
    None, it uses ``<exp.output_dir>/<experiment_name>/best_ckpt.pth.tar``.

    Before export it replaces every ``nn.SiLU`` with the module-level ``SiLU``
    (presumably an ONNX-exportable variant — confirm). It also sets
    ``model.head.decode_in_inference = False``.

    The model is traced with a random ``(1, 3, exp.test_size[0],
    exp.test_size[1])`` input and written to ``args.output_name``. Unless
    ``args.no_onnxsim`` is set, the file is then overwritten with the
    onnx-simplifier output.

    Raises:
        RuntimeError: if onnx-simplifier reports that the simplified model could
            not be validated. Exceptions are logged by ``@logger.catch``.
    """
    args = make_parser().parse_args()
    logger.info("args value: {}".format(args))
    exp = get_exp(args.exp_file, args.name)
    exp.merge(args.opts)

    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    model = exp.get_model()
    if args.ckpt is None:
        file_name = os.path.join(exp.output_dir, args.experiment_name)
        ckpt_file = os.path.join(file_name, "best_ckpt.pth.tar")
    else:
        ckpt_file = args.ckpt

    # load the model state dict
    # NOTE(review): torch.load unpickles the file, so it can execute arbitrary
    # code; only export checkpoints from trusted sources.
    ckpt = torch.load(ckpt_file, map_location="cpu")

    model.eval()
    # The checkpoint may wrap the state dict under a "model" key (presumably a
    # training checkpoint). Otherwise the loaded object is used directly.
    if "model" in ckpt:
        ckpt = ckpt["model"]
    model.load_state_dict(ckpt)
    model = replace_module(model, nn.SiLU, SiLU)
    model.head.decode_in_inference = False

    logger.info("loading checkpoint done.")
    dummy_input = torch.randn(1, 3, exp.test_size[0], exp.test_size[1])
    # Use the public exporter. The private torch.onnx._export helper is not
    # part of PyTorch's stable API and is missing from newer releases.
    torch.onnx.export(
        model,
        dummy_input,
        args.output_name,
        input_names=[args.input],
        output_names=[args.output],
        opset_version=args.opset,
    )
    logger.info("generated onnx model named {}".format(args.output_name))

    if not args.no_onnxsim:
        import onnx

        from onnxsim import simplify

        # use onnx-simplifier to remove redundant nodes from the exported model.
        onnx_model = onnx.load(args.output_name)
        model_simp, check = simplify(onnx_model)
        # Explicit raise instead of assert so the check still runs under `python -O`.
        if not check:
            raise RuntimeError("Simplified ONNX model could not be validated")
        onnx.save(model_simp, args.output_name)
        logger.info("generated simplified onnx model named {}".format(args.output_name))
99
100
101if __name__ == "__main__":

Callers 1

export_onnx.py · File · 0.70

Calls 6

get_exp · Function · 0.90
replace_module · Function · 0.90
merge · Method · 0.80
make_parser · Function · 0.70
get_model · Method · 0.45
eval · Method · 0.45

Tested by

no test coverage detected