Read state dict as fake tensors
gguf_sd_loader(path, handle_prefix="model.diffusion_model.", is_text_model=False)
| 68 | return metadata |
| 69 | |
| 70 | def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", is_text_model=False): |
| 71 | """ |
| 72 | Read state dict as fake tensors |
| 73 | """ |
| 74 | reader = gguf.GGUFReader(path) |
| 75 | |
| 76 | # filter and strip prefix |
| 77 | has_prefix = False |
| 78 | if handle_prefix is not None: |
| 79 | prefix_len = len(handle_prefix) |
| 80 | tensor_names = set(tensor.name for tensor in reader.tensors) |
| 81 | has_prefix = any(s.startswith(handle_prefix) for s in tensor_names) |
| 82 | |
| 83 | tensors = [] |
| 84 | for tensor in reader.tensors: |
| 85 | sd_key = tensor_name = tensor.name |
| 86 | if has_prefix: |
| 87 | if not tensor_name.startswith(handle_prefix): |
| 88 | continue |
| 89 | sd_key = tensor_name[prefix_len:] |
| 90 | tensors.append((sd_key, tensor)) |
| 91 | |
| 92 | # detect and verify architecture |
| 93 | compat = None |
| 94 | arch_str = get_field(reader, "general.architecture", str) |
| 95 | type_str = get_field(reader, "general.type", str) |
| 96 | if arch_str in [None, "pig", "cow"]: |
| 97 | if is_text_model: |
| 98 | raise ValueError(f"This gguf file is incompatible with llama.cpp!\nConsider using safetensors or a compatible gguf file\n({path})") |
| 99 | compat = "sd.cpp" if arch_str is None else arch_str |
| 100 | # import here to avoid changes to convert.py breaking regular models |
| 101 | from .tools.convert import detect_arch |
| 102 | try: |
| 103 | arch_str = detect_arch(set(val[0] for val in tensors)).arch |
| 104 | except Exception as e: |
| 105 | raise ValueError(f"This model is not currently supported - ({e})") |
| 106 | elif arch_str not in TXT_ARCH_LIST and is_text_model: |
| 107 | if type_str not in VIS_TYPE_LIST: |
| 108 | raise ValueError(f"Unexpected text model architecture type in GGUF file: {arch_str!r}") |
| 109 | elif arch_str not in IMG_ARCH_LIST and not is_text_model: |
| 110 | raise ValueError(f"Unexpected architecture type in GGUF file: {arch_str!r}") |
| 111 | |
| 112 | if compat: |
| 113 | logging.warning(f"Warning: This gguf model file is loaded in compatibility mode '{compat}' [arch:{arch_str}]") |
| 114 | |
| 115 | # main loading loop |
| 116 | state_dict = {} |
| 117 | qtype_dict = {} |
| 118 | for sd_key, tensor in tensors: |
| 119 | tensor_name = tensor.name |
| 120 | # torch_tensor = torch.from_numpy(tensor.data) # mmap |
| 121 | |
| 122 | # NOTE: line above replaced with this block to avoid persistent numpy warning about mmap |
| 123 | with warnings.catch_warnings(): |
| 124 | warnings.filterwarnings("ignore", message="The given NumPy array is not writable") |
| 125 | torch_tensor = torch.from_numpy(tensor.data) # mmap |
| 126 | |
| 127 | shape = get_orig_shape(reader, tensor_name) |
no test coverage detected