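Internal: optimize a model with BladeDISC (torch_blade). Args: model: model instance to optimize (its `.eval()` is called, so a model name string is not accepted). model_inputs: example inputs forwarded to `torch_blade.optimize` for tracing. enable_fp16: whether to set `enable_fp16` on the torch_blade config (default True).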
(model, model_inputs, enable_fp16=True)
| 163 | |
| 164 | |
| 165 | def _bladedisc_opt(model, model_inputs, enable_fp16=True): |
| 166 | """Internal: bladedisc opt. |
| 167 | |
| 168 | Args: |
| 169 | model: Model instance or model name. |
| 170 | model_inputs: TODO. |
| 171 | enable_fp16: TODO. |
| 172 | """ |
| 173 | model = model.eval() |
| 174 | try: |
| 175 | import torch_blade |
| 176 | except Exception as e: |
| 177 | print( |
| 178 | f"Warning, if you are exporting bladedisc, please install it and try it again: pip install -U torch_blade\n" |
| 179 | ) |
| 180 | torch_config = torch_blade.config.Config() |
| 181 | torch_config.enable_fp16 = enable_fp16 |
| 182 | with torch.no_grad(), torch_config: |
| 183 | opt_model = torch_blade.optimize( |
| 184 | model, |
| 185 | allow_tracing=True, |
| 186 | model_inputs=model_inputs, |
| 187 | ) |
| 188 | return opt_model |
| 189 | |
| 190 | |
| 191 | def _rescale_input_hook(m, x, scale): |
no test coverage detected
searching dependent graphs…