Reads a checkpoint file and returns its contents, raising a properly formatted error if loading fails.
(
checkpoint_file: str | os.PathLike,
dduf_entries: dict[str, DDUFEntry] | None = None,
disable_mmap: bool = False,
map_location: str | torch.device = "cpu",
)
| 153 | |
| 154 | |
def load_state_dict(
    checkpoint_file: str | os.PathLike,
    dduf_entries: dict[str, DDUFEntry] | None = None,
    disable_mmap: bool = False,
    map_location: str | torch.device = "cpu",
):
    """
    Reads a checkpoint file, raising properly formatted errors if they arise.

    Args:
        checkpoint_file (`str` or `os.PathLike`):
            Path of the checkpoint to load. The loader is selected from the text after the last `.` in the file
            name. If a `dict` is passed instead, it is returned unchanged.
        dduf_entries (`dict[str, DDUFEntry]`, *optional*):
            When non-empty and the checkpoint is a safetensors file, the checkpoint is read from
            `dduf_entries[checkpoint_file]` instead of from disk.
        disable_mmap (`bool`, defaults to `False`):
            If `True`, safetensors files are read fully into memory and `torch.load` is not asked to memory-map.
        map_location (`str` or `torch.device`, defaults to `"cpu"`):
            Device forwarded to `safetensors.torch.load_file` and `torch.load`. It is not forwarded on the DDUF,
            `disable_mmap` safetensors, or GGUF paths.

    Returns:
        Whatever the selected loader returns for `checkpoint_file`, or `checkpoint_file` itself if it is a `dict`.

    Raises:
        OSError: If loading fails. The message points at a missing git-lfs setup when the file is a text file
            starting with `"version"`, and is a generic load failure otherwise. The original loader exception is
            attached as the cause. Errors from re-opening the file (e.g. `FileNotFoundError`) propagate as-is.
    """
    # TODO: maybe refactor a bit this part where we pass a dict here
    if isinstance(checkpoint_file, dict):
        return checkpoint_file
    try:
        file_extension = os.path.basename(checkpoint_file).split(".")[-1]
        if file_extension == SAFETENSORS_FILE_EXTENSION:
            if dduf_entries:
                # tensors are loaded on cpu
                with dduf_entries[checkpoint_file].as_mmap() as mm:
                    return safetensors.torch.load(mm)
            if disable_mmap:
                # Read through a context manager so the file handle is closed deterministically.
                with open(checkpoint_file, "rb") as f:
                    return safetensors.torch.load(f.read())
            return safetensors.torch.load_file(checkpoint_file, device=map_location)
        elif file_extension == GGUF_FILE_EXTENSION:
            return load_gguf_checkpoint(checkpoint_file)
        else:
            extra_args = {}
            weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
            # mmap can only be used with files serialized with zipfile-based format.
            if (
                isinstance(checkpoint_file, str)
                and map_location != "meta"
                and is_torch_version(">=", "2.1.0")
                and is_zipfile(checkpoint_file)
                and not disable_mmap
            ):
                extra_args = {"mmap": True}
            return torch.load(checkpoint_file, map_location=map_location, **weights_only_kwarg, **extra_args)
    except Exception as e:
        # Diagnose the failure: a text file starting with "version" is treated as an un-pulled git-lfs pointer.
        lfs_marker = "version"
        try:
            with open(checkpoint_file) as f:
                # Only the prefix matters; avoid decoding a potentially very large checkpoint as text.
                looks_like_lfs_pointer = f.read(len(lfs_marker)).startswith(lfs_marker)
        except (UnicodeDecodeError, ValueError):
            # Not readable as text, so it is not a pointer file; fall through to the generic error.
            looks_like_lfs_pointer = False

        if looks_like_lfs_pointer:
            raise OSError(
                "You seem to have cloned a repository without having git-lfs installed. Please install "
                "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
                "you cloned."
            ) from e
        raise OSError(
            f"Unable to load weights from checkpoint file for '{checkpoint_file}' at '{checkpoint_file}'. "
        ) from e
| 211 | |
| 212 |
no test coverage detected
searching dependent graphs…