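Export model to ONNX format. Creates a deep copy of the model to isolate ONNX operator monkey-patching, then runs torch.onnx.export. The original model remains usable after export. Args: input: Sample input for tracing (auto-generated if None). **cfg: Export parameters — type (str): export format, "onnx" (default); quantize (bool): whether to quantize the model; device (str): device for export. Returns: str: Path to the exported model directory.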
(self, input=None, **cfg)
| 929 | return results_ret_list |
| 930 | |
| 931 | def export(self, input=None, **cfg): |
| 932 | """Export model to ONNX format. |
| 933 | |
| 934 | Creates a deep copy of the model to isolate ONNX operator monkey-patching, |
| 935 | then runs torch.onnx.export. The original model remains usable after export. |
| 936 | |
| 937 | Args: |
| 938 | input: Sample input for tracing (auto-generated if None). |
| 939 | **cfg: Export parameters: |
| 940 | - type (str): Export format, "onnx" (default). |
| 941 | - quantize (bool): Whether to quantize the model. |
| 942 | - device (str): Device for export. |
| 943 | |
| 944 | Returns: |
| 945 | str: Path to the exported model directory. |
| 946 | """ |
| 947 | """ |
| 948 | |
| 949 | :param input: |
| 950 | :param type: |
| 951 | :param quantize: |
| 952 | :param fallback_num: |
| 953 | :param calib_num: |
| 954 | :param opset_version: |
| 955 | :param cfg: |
| 956 | :return: |
| 957 | """ |
| 958 | |
| 959 | device = cfg.get("device", "cpu") |
| 960 | |
| 961 | # 对模型进行深拷贝,隔离 ONNX 算子替换(Monkey-patching)对原模型的破坏 |
| 962 | # Implement deep copy of the model to isolate ONNX operator monkey-patching |
| 963 | # and prevent corruption of the original model |
| 964 | model = copy.deepcopy(self.model).to(device=device) |
| 965 | |
| 966 | # 对配置参数进行深拷贝,隔离 deep_update 和 del 的引用污染 |
| 967 | # Implement deep copy of configuration parameters to isolate reference pollution caused by deep_update and del. |
| 968 | kwargs = copy.deepcopy(self.kwargs) |
| 969 | |
| 970 | deep_update(kwargs, cfg) |
| 971 | kwargs["device"] = device |
| 972 | |
| 973 | # Safely delete keys that may cause issues during export |
| 974 | if "model" in kwargs: |
| 975 | del kwargs["model"] |
| 976 | |
| 977 | model.eval() |
| 978 | |
| 979 | type = kwargs.get("type", "onnx") |
| 980 | |
| 981 | key_list, data_list = prepare_data_iterator( |
| 982 | input, input_len=None, data_type=kwargs.get("data_type", None), key=None |
| 983 | ) |
| 984 | |
| 985 | with torch.no_grad(): |
| 986 | # 这里的导出操作只会魔改 model 副本,原实例的 self.model 依然是纯洁的 PyTorch 图 |
| 987 | # This export operation only mutates the model copy; |
| 988 | # the original self.model instance remains an intact PyTorch graph. |