(
modelobj: UNetModel,
onnx_path: str,
weights_map_path: str,
lora_name: str,
profile: ProfileSettings,
)
| 86 | |
| 87 | |
def export_lora(
    modelobj: UNetModel,
    onnx_path: str,
    weights_map_path: str,
    lora_name: str,
    profile: ProfileSettings,
) -> dict:
    """Apply a LoRA to ``modelobj.unet`` and return the refit weights for it.

    The UNet is patched via ``apply_lora`` using the LoRA name with its file
    extension stripped. Its resulting state dict is then passed, together with
    ``onnx_path`` and the name/shape mappings loaded from ``weights_map_path``,
    to ``get_refit_weights``.

    Args:
        modelobj: Model wrapper providing ``get_sample_input`` and a ``unet``
            attribute. NOTE: ``modelobj.unet`` is replaced in place with the
            LoRA-applied module, so the caller's object is mutated.
        onnx_path: Path forwarded unchanged to ``get_refit_weights``.
        weights_map_path: Path to a JSON file containing a two-element array
            ``[weights_name_mapping, weights_shape_mapping]``.
        lora_name: LoRA file name; only the part before the extension is
            passed to ``apply_lora``.
        profile: Shape profile whose ``bs_opt``/``h_opt``/``w_opt``/``t_opt``
            values size the sample input.

    Returns:
        The dict returned by ``get_refit_weights``.

    Raises:
        ValueError: If the weights-map JSON is not a two-element array.
    """
    # NOTE(review): this message mentions an ONNX export, but the visible body
    # only applies the LoRA and gathers refit weights — confirm the wording.
    info("Exporting to ONNX...")

    # Sample inputs at the profile's optimal shape. The batch is doubled
    # (presumably conditional + unconditional passes — TODO confirm) and the
    # spatial dims are divided by 8 (presumably the latent scale — TODO confirm).
    inputs = modelobj.get_sample_input(
        profile.bs_opt * 2,
        profile.h_opt // 8,
        profile.w_opt // 8,
        profile.t_opt,
    )

    # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
    with open(weights_map_path, "r", encoding="utf-8") as fp_wts:
        print(f"[I] Loading weights map: {weights_map_path} ")
        weights_map = json.load(fp_wts)

    # Validate before unpacking: a JSON object with two keys would otherwise
    # unpack silently into two key strings instead of the two mappings.
    if not isinstance(weights_map, list) or len(weights_map) != 2:
        raise ValueError(
            f"Weights map {weights_map_path!r} must be a JSON array of "
            "[weights_name_mapping, weights_shape_mapping]"
        )
    weights_name_mapping, weights_shape_mapping = weights_map

    lora_stem = os.path.splitext(lora_name)[0]
    with torch.inference_mode(), torch.autocast("cuda"):
        modelobj.unet = apply_lora(modelobj.unet, lora_stem, inputs)

        refit_dict = get_refit_weights(
            modelobj.unet.state_dict(),
            onnx_path,
            weights_name_mapping,
            weights_shape_mapping,
        )

    return refit_dict
| 120 | |
| 121 | |
| 122 | def swap_sdpa(func): |
no test coverage detected