MCPcopy Index your code
hub / github.com/huggingface/diffusers / main

Function main

scripts/extract_lora_from_model.py:114–146  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

112
@torch.no_grad()
def main(args):
    """Extract a LoRA from the weight difference between a fine-tuned and a base transformer.

    Both checkpoints are loaded as ``CogVideoXTransformer3DModel`` in bfloat16.

    For every parameter with at least two dimensions, the code does the following:

    - It computes the difference ``finetuned - base`` in float32.
    - It passes that difference to ``extract_lora`` together with the rank.
    - It stores the first returned factor (``out[0]``) under ``<name>.lora_B.weight``
      and the second (``out[1]``) under ``<name>.lora_A.weight``.
    - It casts both factors back to the fine-tuned parameter's dtype.

    Parameters with fewer than two dimensions are skipped. Every output key is
    prefixed with ``"transformer."``, or with ``"unet."`` when the base model's
    class name does not contain "transformer". The resulting dict is written with
    ``save_file``.

    Args:
        args: Parsed command-line namespace. Reads ``finetune_ckpt_path``,
            ``finetune_subfolder``, ``base_ckpt_path``, ``base_subfolder`` and
            ``lora_out_path``. An optional ``rank`` attribute overrides the
            module-level ``RANK`` default when present and not ``None``.

    Raises:
        KeyError: If a parameter of the base model is missing from the
            fine-tuned model.
        ValueError: If a (>= 2-D) base parameter and its fine-tuned counterpart
            differ in shape.
    """
    # Let callers supply the rank on the command line; otherwise keep the module default.
    rank = getattr(args, "rank", None)
    if rank is None:
        rank = RANK

    model_finetuned = CogVideoXTransformer3DModel.from_pretrained(
        args.finetune_ckpt_path, subfolder=args.finetune_subfolder, torch_dtype=torch.bfloat16
    )
    state_dict_ft = model_finetuned.state_dict()

    base_model = CogVideoXTransformer3DModel.from_pretrained(
        args.base_ckpt_path, subfolder=args.base_subfolder, torch_dtype=torch.bfloat16
    )
    state_dict = base_model.state_dict()
    output_dict = {}

    for k in tqdm(state_dict, desc="Extracting LoRA..."):
        original_param = state_dict[k]
        if k not in state_dict_ft:
            raise KeyError(f"Parameter {k!r} exists in the base model but not in the fine-tuned model.")
        finetuned_param = state_dict_ft[k]

        # Only matrices and higher-rank tensors are factorised; 0-D/1-D tensors are skipped.
        if len(original_param.shape) < 2:
            continue
        # Without this check, mismatched shapes could silently broadcast in the subtraction.
        if finetuned_param.shape != original_param.shape:
            raise ValueError(
                f"Shape mismatch for {k!r}: base {tuple(original_param.shape)} "
                f"vs fine-tuned {tuple(finetuned_param.shape)}."
            )

        # Subtract in float32 so the delta is not rounded to bfloat16 precision.
        diff = finetuned_param.float() - original_param.float()
        out = extract_lora(diff, rank)

        name = k[: -len(".weight")] if k.endswith(".weight") else k
        down_key = f"{name}.lora_A.weight"
        up_key = f"{name}.lora_B.weight"

        output_dict[up_key] = out[0].contiguous().to(finetuned_param.dtype)
        output_dict[down_key] = out[1].contiguous().to(finetuned_param.dtype)

    prefix = "transformer" if "transformer" in base_model.__class__.__name__.lower() else "unet"
    output_dict = {f"{prefix}.{k}": v for k, v in output_dict.items()}
    save_file(output_dict, args.lora_out_path)
    print(f"LoRA saved and it contains {len(output_dict)} keys.")
147
148
149if __name__ == "__main__":

Callers 1

Calls 5

extract_lora — Function · 0.85
float — Method · 0.80
from_pretrained — Method · 0.45
state_dict — Method · 0.45
to — Method · 0.45

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…