(file_path, torch_dtype=None)
| 53 | setattr(torch, torch_function_name, old_torch_function) |
| 54 | |
def load_state_dict_from_folder(file_path, torch_dtype=None, device="cpu"):
    """Merge the checkpoint files directly inside a folder into one dict.

    Entries of the folder (no recursion) whose name ends in one of the
    extensions ``safetensors``, ``bin``, ``ckpt``, ``pth`` or ``pt`` are each
    loaded with ``load_state_dict``. Their contents are merged into a single
    dictionary.

    Args:
        file_path: Path of the folder to scan.
        torch_dtype: Forwarded unchanged to ``load_state_dict`` for every file.
        device: Forwarded unchanged to ``load_state_dict`` for every file.
            Defaults to ``"cpu"``, the same default ``load_state_dict`` uses.

    Returns:
        A dict holding the merged entries of all matching files. If several
        files contain the same key, the file whose name sorts last wins.
    """
    checkpoint_exts = {"safetensors", "bin", "ckpt", "pth", "pt"}
    state_dict = {}
    # os.listdir order is arbitrary. Sort it so that key collisions between
    # files resolve the same way on every platform and filesystem.
    for file_name in sorted(os.listdir(file_path)):
        # Equivalent to the "has a dot and last dot-segment is a known
        # extension" test, without splitting the whole name.
        _, dot, ext = file_name.rpartition(".")
        if dot and ext in checkpoint_exts:
            state_dict.update(
                load_state_dict(
                    os.path.join(file_path, file_name),
                    torch_dtype=torch_dtype,
                    device=device,
                )
            )
    return state_dict
| 63 | |
| 64 | |
| 65 | def load_state_dict(file_path, torch_dtype=None, device="cpu"): |
no test coverage detected