MCPcopy Index your code
hub / github.com/huggingface/diffusers / load_state_dict

Function load_state_dict

src/diffusers/models/model_loading_utils.py:155–210  ·  view source on GitHub ↗

Reads a checkpoint file, returning properly formatted errors if they arise.

(
    checkpoint_file: str | os.PathLike,
    dduf_entries: dict[str, DDUFEntry] | None = None,
    disable_mmap: bool = False,
    map_location: str | torch.device = "cpu",
)

Source from the content-addressed store, hash-verified

153
154
def load_state_dict(
    checkpoint_file: str | os.PathLike,
    dduf_entries: dict[str, DDUFEntry] | None = None,
    disable_mmap: bool = False,
    map_location: str | torch.device = "cpu",
):
    """
    Reads a checkpoint file, returning properly formatted errors if they arise.

    The loader is chosen from the file extension: safetensors, GGUF, or anything else via `torch.load`.

    Args:
        checkpoint_file (`str` or `os.PathLike`):
            Path of the checkpoint to load. A `dict` passed here is returned unchanged.
        dduf_entries (`dict[str, DDUFEntry]`, *optional*):
            When given (and non-empty), a safetensors checkpoint is read through the memory map of
            `dduf_entries[checkpoint_file]` instead of from disk.
        disable_mmap (`bool`, defaults to `False`):
            Read safetensors files fully into memory instead of calling `load_file`, and never pass
            `mmap=True` to `torch.load`.
        map_location (`str` or `torch.device`, defaults to `"cpu"`):
            Forwarded to `safetensors.torch.load_file` and `torch.load`. It is not forwarded on the DDUF,
            `disable_mmap` safetensors, or GGUF paths.

    Returns:
        The object returned by the selected loader (expected to be a state dict).

    Raises:
        OSError: If loading fails. When the file starts with `version` (an un-pulled git-lfs pointer), the
            message tells the user to run git-lfs. Otherwise a generic load error is raised. In both cases
            the original loader exception is attached as `__cause__`. If opening the file for that check
            fails with an `OSError` (e.g. it does not exist), that error propagates unchanged.
    """
    # TODO: maybe refactor a bit this part where we pass a dict here
    if isinstance(checkpoint_file, dict):
        return checkpoint_file
    try:
        file_extension = os.path.basename(checkpoint_file).split(".")[-1]
        if file_extension == SAFETENSORS_FILE_EXTENSION:
            if dduf_entries:
                # tensors are loaded on cpu
                with dduf_entries[checkpoint_file].as_mmap() as mm:
                    return safetensors.torch.load(mm)
            if disable_mmap:
                # Read eagerly into memory; the `with` block guarantees the handle is closed.
                with open(checkpoint_file, "rb") as f:
                    return safetensors.torch.load(f.read())
            return safetensors.torch.load_file(checkpoint_file, device=map_location)
        elif file_extension == GGUF_FILE_EXTENSION:
            return load_gguf_checkpoint(checkpoint_file)
        else:
            extra_args = {}
            weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
            # mmap can only be used with files serialized with zipfile-based format.
            if (
                isinstance(checkpoint_file, str)
                and map_location != "meta"
                and is_torch_version(">=", "2.1.0")
                and is_zipfile(checkpoint_file)
                and not disable_mmap
            ):
                extra_args = {"mmap": True}
            return torch.load(checkpoint_file, map_location=map_location, **weights_only_kwarg, **extra_args)
    except Exception as e:
        # Only a short prefix is needed to recognise a git-lfs pointer file. Reading it in binary mode
        # avoids loading a possibly huge weights file, and avoids UnicodeDecodeError on binary content.
        try:
            with open(checkpoint_file, "rb") as f:
                is_lfs_pointer = f.read(len(b"version")) == b"version"
        except ValueError:
            # `open` rejected the path (e.g. embedded NUL byte); report the generic load error below.
            is_lfs_pointer = False

        if is_lfs_pointer:
            raise OSError(
                "You seem to have cloned a repository without having git-lfs installed. Please install "
                "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
                "you cloned."
            ) from e
        raise OSError(
            f"Unable to load weights from checkpoint file for '{checkpoint_file}' at '{checkpoint_file}'. "
        ) from e
211
212

Callers 11

load_attn_procs (method) · 0.85
_fetch_state_dict (function) · 0.85
load_ip_adapter (method) · 0.85
load_ip_adapter (method) · 0.85
load_ip_adapter (method) · 0.85
load_ip_adapter (method) · 0.85
from_pretrained (method) · 0.85
_load_shard_file (function) · 0.85
from_pretrained (method) · 0.85

Calls 4

load_gguf_checkpoint (function) · 0.85
is_torch_version (function) · 0.85
split (method) · 0.80
load (method) · 0.45

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…