| 882 | |
| 883 | |
| 884 | def load_hub_or_local_checkpoint(repo_id: str | None = None, filename: str | None = None) -> dict[str, Any]: |
| 885 | if repo_id is None and filename is None: |
| 886 | raise ValueError("Please supply at least one of `repo_id` or `filename`") |
| 887 | |
| 888 | if repo_id is not None: |
| 889 | if filename is None: |
| 890 | raise ValueError("If repo_id is specified, filename must also be specified.") |
| 891 | ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename) |
| 892 | else: |
| 893 | ckpt_path = filename |
| 894 | |
| 895 | _, ext = os.path.splitext(ckpt_path) |
| 896 | if ext in [".safetensors", ".sft"]: |
| 897 | state_dict = safetensors.torch.load_file(ckpt_path) |
| 898 | else: |
| 899 | state_dict = torch.load(ckpt_path, map_location="cpu") |
| 900 | |
| 901 | return state_dict |
| 902 | |
| 903 | |
| 904 | def get_model_state_dict_from_combined_ckpt(combined_ckpt: dict[str, Any], prefix: str) -> dict[str, Any]: |