| 41 | # https://github.com/NVIDIA/TensorRT/blob/main/tools/pytorch-quantization/examples/finetune_quant_resnet50.ipynb |
| 42 | |
def export_onnx(model, onnx_filename, batch_onnx, input_shape=(3, 224, 224)):
    """Export ``model`` to an ONNX file, tracing it with a random dummy input.

    The model is put into eval mode, and the quantizers are switched to
    PyTorch's fake-quant ops for the duration of the export.

    Args:
        model: The module handed to ``torch.onnx.export``.
            NOTE(review): the dummy input is created on ``'cuda'``, so the
            model presumably has to live on a CUDA device -- confirm.
        onnx_filename: Path of the ONNX file to write.
        batch_onnx: Batch dimension of the dummy input used for tracing.
        input_shape: Per-sample dimensions of the dummy input. Defaults to
            ``(3, 224, 224)``, the shape that was previously hard-coded.

    Returns:
        ``True`` once the export call has returned.
    """
    model.eval()
    opset_version = 13

    # We have to shift to PyTorch's fake quant ops before exporting the model
    # to ONNX. The flag is set on the class, so remember the old value and put
    # it back afterwards instead of leaking the change to the rest of the run.
    previous_fake_quant = quant_nn.TensorQuantizer.use_fb_fake_quant
    quant_nn.TensorQuantizer.use_fb_fake_quant = True
    try:
        print("Creating ONNX file: " + onnx_filename)
        dummy_input = torch.randn(batch_onnx, *input_shape, device='cuda')
        # NOTE(review): `enable_onnx_checker` is kept as in the original;
        # verify it is still accepted by the installed torch version.
        torch.onnx.export(
            model,
            dummy_input,
            onnx_filename,
            verbose=False,
            opset_version=opset_version,
            enable_onnx_checker=False,
            do_constant_folding=True,
        )
    finally:
        # Restore the flag even when the export raises.
        quant_nn.TensorQuantizer.use_fb_fake_quant = previous_fake_quant
    return True