(
shard_file,
model,
model_state_dict,
device_map=None,
dtype=None,
hf_quantizer=None,
keep_in_fp32_modules=None,
dduf_entries=None,
loaded_keys=None,
unexpected_keys=None,
offload_index=None,
offload_folder=None,
state_dict_index=None,
state_dict_folder=None,
ignore_mismatched_sizes=False,
low_cpu_mem_usage=False,
disable_mmap=False,
)
| 338 | |
| 339 | |
def _load_shard_file(
    shard_file,
    model,
    model_state_dict,
    device_map=None,
    dtype=None,
    hf_quantizer=None,
    keep_in_fp32_modules=None,
    dduf_entries=None,
    loaded_keys=None,
    unexpected_keys=None,
    offload_index=None,
    offload_folder=None,
    state_dict_index=None,
    state_dict_folder=None,
    ignore_mismatched_sizes=False,
    low_cpu_mem_usage=False,
    disable_mmap=False,
):
    """Load a single checkpoint shard and apply its contents to ``model``.

    The shard is read with ``load_state_dict`` and compared against
    ``model_state_dict`` through ``_find_mismatched_keys``. The weights are then
    handed to one of two loaders:

    * ``low_cpu_mem_usage=True``: ``load_model_dict_into_meta`` receives the
      device map, dtype, quantizer and offload arguments, and its return value
      replaces ``offload_index`` and ``state_dict_index``.
    * otherwise: ``_load_state_dict_into_model`` is used and whatever it returns
      is collected as error messages; ``offload_index`` and ``state_dict_index``
      are passed back exactly as they were received.

    Returns:
        A 4-tuple ``(offload_index, state_dict_index, mismatched_keys, error_msgs)``.
        ``error_msgs`` is always an empty list on the low-memory path.
    """
    shard_weights = load_state_dict(shard_file, dduf_entries=dduf_entries, disable_mmap=disable_mmap)
    size_mismatches = _find_mismatched_keys(shard_weights, model_state_dict, loaded_keys, ignore_mismatched_sizes)

    if low_cpu_mem_usage:
        # Low-memory path: the helper hands back the updated index objects and
        # this function reports no error messages for this branch.
        offload_index, state_dict_index = load_model_dict_into_meta(
            model,
            shard_weights,
            device_map=device_map,
            dtype=dtype,
            hf_quantizer=hf_quantizer,
            keep_in_fp32_modules=keep_in_fp32_modules,
            unexpected_keys=unexpected_keys,
            offload_folder=offload_folder,
            offload_index=offload_index,
            state_dict_index=state_dict_index,
            state_dict_folder=state_dict_folder,
        )
        return offload_index, state_dict_index, size_mismatches, []

    # Regular path: ask the helper whether direct assignment is supported, then
    # load and gather any messages the loader reports.
    can_assign = check_support_param_buffer_assignment(model, shard_weights)
    load_errors = []
    load_errors.extend(_load_state_dict_into_model(model, shard_weights, can_assign))
    return offload_index, state_dict_index, size_mismatches, load_errors
| 386 | |
| 387 | |
| 388 | def _load_shard_files_with_threadpool( |
nothing calls this directly
no test coverage detected
searching dependent graphs…