Helper function to load the module `name` from `library_name` and `class_name`.
(
library_name: str,
class_name: str,
importable_classes: list[Any],
pipelines: Any,
is_pipeline_module: bool,
pipeline_class: Any,
torch_dtype: torch.dtype,
provider: Any,
sess_options: Any,
device_map: dict[str, torch.device] | str | None,
max_memory: dict[int | str, int | str] | None,
offload_folder: str | os.PathLike | None,
offload_state_dict: bool,
model_variants: dict[str, str],
name: str,
from_flax: bool,
variant: str,
low_cpu_mem_usage: bool,
cached_folder: str | os.PathLike,
use_safetensors: bool,
dduf_entries: dict[str, DDUFEntry] | None,
provider_options: Any,
disable_mmap: bool,
quantization_config: Any | None = None,
use_flashpack: bool = False,
trust_remote_code: bool = False,
)
| 755 | |
| 756 | |
| 757 | def load_sub_model( |
| 758 | library_name: str, |
| 759 | class_name: str, |
| 760 | importable_classes: list[Any], |
| 761 | pipelines: Any, |
| 762 | is_pipeline_module: bool, |
| 763 | pipeline_class: Any, |
| 764 | torch_dtype: torch.dtype, |
| 765 | provider: Any, |
| 766 | sess_options: Any, |
| 767 | device_map: dict[str, torch.device] | str | None, |
| 768 | max_memory: dict[int | str, int | str] | None, |
| 769 | offload_folder: str | os.PathLike | None, |
| 770 | offload_state_dict: bool, |
| 771 | model_variants: dict[str, str], |
| 772 | name: str, |
| 773 | from_flax: bool, |
| 774 | variant: str, |
| 775 | low_cpu_mem_usage: bool, |
| 776 | cached_folder: str | os.PathLike, |
| 777 | use_safetensors: bool, |
| 778 | dduf_entries: dict[str, DDUFEntry] | None, |
| 779 | provider_options: Any, |
| 780 | disable_mmap: bool, |
| 781 | quantization_config: Any | None = None, |
| 782 | use_flashpack: bool = False, |
| 783 | trust_remote_code: bool = False, |
| 784 | ): |
| 785 | """Helper method to load the module `name` from `library_name` and `class_name`""" |
| 786 | from ..quantizers import PipelineQuantizationConfig |
| 787 | |
| 788 | # retrieve class candidates |
| 789 | |
| 790 | class_obj, class_candidates = get_class_obj_and_candidates( |
| 791 | library_name, |
| 792 | class_name, |
| 793 | importable_classes, |
| 794 | pipelines, |
| 795 | is_pipeline_module, |
| 796 | component_name=name, |
| 797 | cache_dir=cached_folder, |
| 798 | trust_remote_code=trust_remote_code, |
| 799 | ) |
| 800 | |
| 801 | load_method_name = None |
| 802 | # retrieve load method name |
| 803 | for class_name, class_candidate in class_candidates.items(): |
| 804 | if class_candidate is not None and issubclass(class_obj, class_candidate): |
| 805 | load_method_name = importable_classes[class_name][1] |
| 806 | |
| 807 | # if load method name is None, then we have a dummy module -> raise Error |
| 808 | if load_method_name is None: |
| 809 | none_module = class_obj.__module__ |
| 810 | is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith( |
| 811 | TRANSFORMERS_DUMMY_MODULES_FOLDER |
| 812 | ) |
| 813 | if is_dummy_path and "dummy" in none_module: |
| 814 | # call class_obj for nice error message of missing requirements |
no test coverage detected
searching dependent graphs…