(model, im, file, opset, dynamic, simplify, prefix=colorstr('ONNX:'))
| 131 | |
@try_export
def export_onnx(model, im, file, opset, dynamic, simplify, prefix=colorstr('ONNX:')):
    """Export ``model`` to an ONNX file next to ``file`` and return the result.

    Steps: trace/export with ``torch.onnx.export``, reload and validate the file
    with ``onnx.checker``, attach ``stride`` and ``names`` as ONNX metadata
    properties, and optionally run onnx-simplifier on the graph.

    Args:
        model: Model to export; must expose ``stride`` and ``names``.
        im: Example input passed to ``torch.onnx.export``.
        file: Path-like object with ``with_suffix``; the output is written to
            ``file.with_suffix('.onnx')``.
        opset: Value forwarded as ``opset_version``.
        dynamic: When truthy, export with dynamic axes (model and input are
            moved to CPU for the export).
        simplify: When truthy, attempt to simplify the exported graph.
        prefix: Label prepended to log messages.

    Returns:
        Tuple ``(f, model_onnx)``: the output path and the ONNX model object
        that matches what was last saved to that path.
    """
    # YOLOv5 ONNX export
    check_requirements('onnx')
    import onnx

    LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
    f = file.with_suffix('.onnx')

    is_segmentation = isinstance(model, SegmentationModel)
    output_names = ['output0', 'output1'] if is_segmentation else ['output0']

    dynamic_axes = None
    if dynamic:
        dynamic_axes = {'images': {0: 'batch', 2: 'height', 3: 'width'}}  # e.g. shape(1,3,640,640)
        if is_segmentation:
            dynamic_axes['output0'] = {0: 'batch', 1: 'anchors'}  # e.g. shape(1,25200,85)
            dynamic_axes['output1'] = {0: 'batch', 2: 'mask_height', 3: 'mask_width'}  # e.g. shape(1,32,160,160)
        elif isinstance(model, DetectionModel):
            dynamic_axes['output0'] = {0: 'batch', 1: 'anchors'}  # e.g. shape(1,25200,85)

    torch.onnx.export(
        model.cpu() if dynamic else model,  # --dynamic only compatible with cpu
        im.cpu() if dynamic else im,
        f,
        verbose=False,
        opset_version=opset,
        do_constant_folding=True,
        input_names=['images'],
        output_names=output_names,
        dynamic_axes=dynamic_axes)

    # Checks
    model_onnx = onnx.load(f)  # load onnx model
    onnx.checker.check_model(model_onnx)  # check onnx model

    # Metadata
    metadata = {'stride': int(max(model.stride)), 'names': model.names}
    for key, value in metadata.items():
        meta = model_onnx.metadata_props.add()
        meta.key, meta.value = key, str(value)
    onnx.save(model_onnx, f)

    # Simplify (best effort: any failure is logged and the unsimplified export is kept)
    if simplify:
        try:
            cuda = torch.cuda.is_available()
            check_requirements(('onnxruntime-gpu' if cuda else 'onnxruntime', 'onnx-simplifier>=0.4.1'))
            import onnxsim

            LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
            simplified, check = onnxsim.simplify(model_onnx)
            # Explicit raise instead of `assert` so the check survives `python -O`.
            if not check:
                raise RuntimeError('assert check failed')
            onnx.save(simplified, f)
            # Rebind only after a successful save so the returned model always
            # matches the file on disk, even when simplification fails.
            model_onnx = simplified
        except Exception as e:
            LOGGER.info(f'{prefix} simplifier failure: {e}')
    return f, model_onnx
| 186 | |
| 187 | |
| 188 | @try_export |
no test coverage detected