(args)
| 112 | |
@torch.no_grad()
def main(args):
    """Extract a low-rank LoRA from the difference between two transformer checkpoints.

    Loads a fine-tuned and a base ``CogVideoXTransformer3DModel`` (both in
    bfloat16). For every state-dict entry with two or more dimensions, it
    factorises ``finetuned - base`` with ``extract_lora`` at rank ``RANK``.

    The two returned factors are stored under these keys:

    * ``<prefix>.<stem>.lora_B.weight`` for the first factor;
    * ``<prefix>.<stem>.lora_A.weight`` for the second factor.

    Here ``<stem>`` is the parameter name with a trailing ``.weight`` removed.
    ``<prefix>`` is ``"transformer"`` if the base model's class name contains
    "transformer", otherwise ``"unet"``. Entries with fewer than two
    dimensions are not exported.

    The resulting tensors are written with ``save_file`` to
    ``args.lora_out_path``.

    Args:
        args: Parsed CLI namespace. Reads ``finetune_ckpt_path``,
            ``finetune_subfolder``, ``base_ckpt_path``, ``base_subfolder``
            and ``lora_out_path``.
    """
    finetuned_model = CogVideoXTransformer3DModel.from_pretrained(
        args.finetune_ckpt_path,
        subfolder=args.finetune_subfolder,
        torch_dtype=torch.bfloat16,
    )
    finetuned_weights = finetuned_model.state_dict()

    # Adjust the `subfolder` arguments to match your checkpoint layout.
    base_model = CogVideoXTransformer3DModel.from_pretrained(
        args.base_ckpt_path,
        subfolder=args.base_subfolder,
        torch_dtype=torch.bfloat16,
    )
    base_weights = base_model.state_dict()

    # Output keys are namespaced by the pipeline component the LoRA targets.
    model_kind = base_model.__class__.__name__.lower()
    prefix = "transformer" if "transformer" in model_kind else "unet"

    lora_state = {}
    for key in tqdm(base_weights, desc="Extracting LoRA..."):
        base_tensor = base_weights[key]
        tuned_tensor = finetuned_weights[key]
        if len(base_tensor.shape) < 2:
            # Only entries with at least two dimensions are factorised.
            continue

        # Upcast from bfloat16 before subtracting so the delta is computed in float32.
        delta = tuned_tensor.float() - base_tensor.float()
        factors = extract_lora(delta, RANK)

        stem = key[: -len(".weight")] if key.endswith(".weight") else key
        target_dtype = tuned_tensor.dtype
        # NOTE(review): assumes extract_lora returns (up, down), i.e. factor 0
        # is the B/up matrix and factor 1 the A/down matrix. Confirm against
        # its definition.
        lora_state[f"{prefix}.{stem}.lora_B.weight"] = factors[0].contiguous().to(target_dtype)
        lora_state[f"{prefix}.{stem}.lora_A.weight"] = factors[1].contiguous().to(target_dtype)

    save_file(lora_state, args.lora_out_path)
    print(f"LoRA saved and it contains {len(lora_state)} keys.")
| 147 | |
| 148 | |
| 149 | if __name__ == "__main__": |
no test coverage detected
searching dependent graphs…