| 29 | |
@logger.catch
def main():
    """Convert a trained experiment model to TensorRT and save the results.

    Parses CLI options from ``make_parser()``, loads the experiment via
    ``get_exp``, restores weights from a checkpoint, converts the model with
    ``torch2trt`` (FP16 mode), and writes two files into
    ``<exp.output_dir>/<experiment_name>/``:

    * ``model_trt.pth``    -- ``state_dict`` of the torch2trt module;
    * ``model_trt.engine`` -- the serialized TensorRT engine.

    The checkpoint defaults to ``best_ckpt.pth.tar`` in that same directory
    when ``--ckpt`` is not given. The engine is additionally copied to
    ``deploy/TensorRT/cpp/model_trt.engine`` (resolved relative to the
    current working directory) for the C++ demo, but only when that
    directory exists; otherwise a warning is logged.
    """
    args = make_parser().parse_args()
    exp = get_exp(args.exp_file, args.name)
    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    model = exp.get_model()
    file_name = os.path.join(exp.output_dir, args.experiment_name)
    os.makedirs(file_name, exist_ok=True)
    if args.ckpt is None:
        ckpt_file = os.path.join(file_name, "best_ckpt.pth.tar")
    else:
        ckpt_file = args.ckpt

    # Map to CPU so the checkpoint loads regardless of the device it was
    # saved from; the model itself is moved to the GPU below.
    ckpt = torch.load(ckpt_file, map_location="cpu")
    # load the model state dict
    model.load_state_dict(ckpt["model"])
    logger.info("loaded checkpoint done.")
    model.eval()
    model.cuda()
    # NOTE(review): presumably turns off output decoding inside the head so
    # the TensorRT graph emits raw head outputs and decoding happens in the
    # consumer -- confirm against the head implementation.
    model.head.decode_in_inference = False
    # Example input that torch2trt runs through the model to build the
    # engine; shape is (1, 3, exp.test_size[0], exp.test_size[1]).
    x = torch.ones(1, 3, exp.test_size[0], exp.test_size[1]).cuda()
    model_trt = torch2trt(
        model,
        [x],
        fp16_mode=True,
        log_level=trt.Logger.INFO,
        max_workspace_size=(1 << 32),  # 1 << 32 bytes = 4 GiB builder workspace
    )
    torch.save(model_trt.state_dict(), os.path.join(file_name, "model_trt.pth"))
    logger.info("Converted TensorRT model done.")

    engine_file = os.path.join(file_name, "model_trt.engine")
    with open(engine_file, "wb") as f:
        f.write(model_trt.engine.serialize())

    # The C++ demo path is relative to the CWD, so it only exists when the
    # script is launched from the repository root. Skip the copy (with a
    # warning) rather than raising FileNotFoundError after the conversion
    # and engine serialization have already succeeded.
    engine_file_demo = os.path.join("deploy", "TensorRT", "cpp", "model_trt.engine")
    demo_dir = os.path.dirname(engine_file_demo)
    if os.path.isdir(demo_dir):
        shutil.copyfile(engine_file, engine_file_demo)
    else:
        logger.warning(
            f"Directory '{demo_dir}' not found; skipped copying the engine "
            f"for the C++ demo (engine is at '{engine_file}')."
        )

    logger.info("Converted TensorRT model engine file is saved for C++ inference.")
| 71 | |
| 72 | |
| 73 | if __name__ == "__main__": |