()
| 73 | |
| 74 | |
def parse_args(argv=None) -> argparse.Namespace:
    """Parse and validate the command-line options of this script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. Leaving it as ``None`` keeps the original
            behavior of reading the process command line.

    Returns:
        The populated ``argparse.Namespace`` with the attributes
        ``base_ckpt_path``, ``base_subfolder``, ``finetune_ckpt_path``,
        ``finetune_subfolder``, ``rank`` and ``lora_out_path``.

    Raises:
        ValueError: If ``lora_out_path`` does not end with ``.safetensors``
            or if ``rank`` is not a positive integer.
        SystemExit: Raised by ``argparse`` itself when a required option is
            missing or a value cannot be converted to its declared type.
    """
    parser = argparse.ArgumentParser()
    # `required=True` options never fall back to a default, so none is given.
    parser.add_argument(
        "--base_ckpt_path",
        type=str,
        required=True,
        help="Base checkpoint path from which the model was finetuned. Can be a model ID on the Hub.",
    )
    parser.add_argument(
        "--base_subfolder",
        default="transformer",
        type=str,
        help="subfolder to load the base checkpoint from if any.",
    )
    parser.add_argument(
        "--finetune_ckpt_path",
        type=str,
        required=True,
        help="Fully fine-tuned checkpoint path. Can be a model ID on the Hub.",
    )
    parser.add_argument(
        "--finetune_subfolder",
        default=None,
        type=str,
        help="subfolder to load the fully fine-tuned checkpoint from if any.",
    )
    parser.add_argument(
        "--rank",
        default=64,
        type=int,
        help="LoRA rank; must be a positive integer.",
    )
    parser.add_argument(
        "--lora_out_path",
        type=str,
        required=True,
        help="Output file path; must end with `.safetensors`.",
    )
    args = parser.parse_args(argv)

    if not args.lora_out_path.endswith(".safetensors"):
        raise ValueError("`lora_out_path` must end with `.safetensors`.")

    # Reject unusable ranks up front instead of failing later in the run.
    if args.rank < 1:
        raise ValueError("`rank` must be a positive integer.")

    return args
| 111 | |
| 112 | |
| 113 | @torch.no_grad() |
no outgoing calls
no test coverage detected
searching dependent graphs…