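Internal: trace a model with TorchScript and save it to disk. Args: model: Model instance providing `export_dummy_inputs()` and `export_name`. path: Directory the traced TorchScript file is saved into. device: Target device ("cuda", "cpu", etc.).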
(model, path, device="cuda")
| 137 | |
| 138 | |
def _torchscripts(model, path, device="cuda"):
    """Trace ``model`` with TorchScript and save the result under ``path``.

    Args:
        model: Module to trace. It must provide ``export_dummy_inputs()``
            (the example inputs used for tracing) and ``export_name``, which
            is either a string or a zero-argument callable returning one.
        path: Directory the traced module is written into. It is not
            created here.
        device: Target device ("cuda", "cuda:0", "cpu", ...). For any CUDA
            device the model and the example inputs are moved to that exact
            device before tracing. For other devices both are traced where
            they already are.

    The output file name is the export name with every occurrence of
    ``"onnx"`` replaced by ``"torchscript"``.
    """
    dummy_input = model.export_dummy_inputs()

    # Compare the device *type*, not the literal string. Previously only
    # "cuda" matched, so "cuda:0"/"cuda:1" silently skipped the move.
    if torch.device(device).type == "cuda":
        model = model.to(device)
        if isinstance(dummy_input, torch.Tensor):
            dummy_input = dummy_input.to(device)
        else:
            # torch.jit.trace takes a tuple of positional example inputs.
            dummy_input = tuple(t.to(device) for t in dummy_input)

    # NOTE(review): tracing records the ops as executed in the model's
    # current train/eval mode; callers presumably call eval() first - confirm.
    model_script = torch.jit.trace(model, dummy_input)

    # export_name may be a plain attribute or a method; resolve it once.
    export_name = model.export_name
    if not isinstance(export_name, str):
        export_name = export_name()
    file_name = f"{export_name}".replace("onnx", "torchscript")
    model_script.save(os.path.join(path, file_name))
| 163 | |
| 164 | |
| 165 | def _bladedisc_opt(model, model_inputs, enable_fp16=True): |
no test coverage detected
searching dependent graphs…