()
| 48 | |
@logger.catch
def main():
    """Export an experiment's model to ONNX and optionally simplify it.

    Parses CLI arguments from ``make_parser()``, builds the model described by
    the experiment file, and loads its weights. The checkpoint is
    ``args.ckpt`` or, when that is not given,
    ``<exp.output_dir>/<experiment_name>/best_ckpt.pth.tar``. The model is
    traced with a single dummy input of shape
    ``(1, 3, exp.test_size[0], exp.test_size[1])`` and written to
    ``args.output_name``. Unless ``args.no_onnxsim`` is set, that file is then
    overwritten with the onnx-simplifier output.

    Raises:
        RuntimeError: if onnx-simplifier reports that the simplified model
            failed its validation check.
    """
    args = make_parser().parse_args()
    logger.info("args value: {}".format(args))
    exp = get_exp(args.exp_file, args.name)
    exp.merge(args.opts)

    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    model = exp.get_model()
    if args.ckpt is None:
        file_name = os.path.join(exp.output_dir, args.experiment_name)
        ckpt_file = os.path.join(file_name, "best_ckpt.pth.tar")
    else:
        ckpt_file = args.ckpt

    # Load the checkpoint with all tensors mapped to CPU, so export does not
    # depend on the device the checkpoint was saved from.
    ckpt = torch.load(ckpt_file, map_location="cpu")

    model.eval()
    # Some checkpoints wrap the state dict under a "model" key; otherwise the
    # loaded object is used directly as the state dict.
    if "model" in ckpt:
        ckpt = ckpt["model"]
    model.load_state_dict(ckpt)
    # Swap every nn.SiLU for the SiLU class imported by this file —
    # presumably an export-friendly implementation; TODO confirm.
    model = replace_module(model, nn.SiLU, SiLU)
    # Turn off the head's inference-time decoding for the exported graph;
    # NOTE(review): presumably decoding is done by the consumer of the ONNX
    # model — confirm.
    model.head.decode_in_inference = False

    logger.info("loading checkpoint done.")
    dummy_input = torch.randn(1, 3, exp.test_size[0], exp.test_size[1])
    # NOTE(review): ``_export`` is a private torch.onnx entry point. The public
    # ``torch.onnx.export`` is the supported API, but its defaults differ
    # between torch releases; verify the exported graph before switching.
    torch.onnx._export(
        model,
        dummy_input,
        args.output_name,
        input_names=[args.input],
        output_names=[args.output],
        opset_version=args.opset,
    )
    logger.info("generated onnx model named {}".format(args.output_name))

    if not args.no_onnxsim:
        # Imported here so onnx/onnxsim are only required when simplification
        # is enabled.
        import onnx

        from onnxsim import simplify

        # Use onnx-simplifier to remove redundant parts of the model.
        onnx_model = onnx.load(args.output_name)
        model_simp, check = simplify(onnx_model)
        # Explicit raise instead of ``assert``: asserts are stripped under
        # ``python -O``, which would silently save an unvalidated model.
        if not check:
            raise RuntimeError("Simplified ONNX model could not be validated")
        onnx.save(model_simp, args.output_name)
        logger.info("generated simplified onnx model named {}".format(args.output_name))
| 99 | |
| 100 | |
| 101 | if __name__ == "__main__": |
no test coverage detected