(path)
| 269 | return name |
| 270 | |
| 271 | def gguf_mmproj_loader(path): |
| 272 | # Reverse version of Qwen2VLVisionModel.modify_tensors |
| 273 | logging.info("Attenpting to find mmproj file for text encoder...") |
| 274 | |
| 275 | # get name to match w/o quant suffix |
| 276 | tenc_fname = os.path.basename(path) |
| 277 | tenc = os.path.splitext(tenc_fname)[0].lower() |
| 278 | tenc = strip_quant_suffix(tenc) |
| 279 | |
| 280 | # try and find matching mmproj |
| 281 | target = [] |
| 282 | root = os.path.dirname(path) |
| 283 | for fname in os.listdir(root): |
| 284 | name, ext = os.path.splitext(fname) |
| 285 | if ext.lower() != ".gguf": |
| 286 | continue |
| 287 | if "mmproj" not in name.lower(): |
| 288 | continue |
| 289 | if tenc in name.lower(): |
| 290 | target.append(fname) |
| 291 | |
| 292 | if len(target) == 0: |
| 293 | logging.error(f"Error: Can't find mmproj file for '{tenc_fname}' (matching:'{tenc}')! Qwen-Image-Edit will be broken!") |
| 294 | return {} |
| 295 | if len(target) > 1: |
| 296 | logging.error(f"Ambiguous mmproj for text encoder '{tenc_fname}', will use first match.") |
| 297 | |
| 298 | logging.info(f"Using mmproj '{target[0]}' for text encoder '{tenc_fname}'.") |
| 299 | target = os.path.join(root, target[0]) |
| 300 | vsd, _ = gguf_sd_loader(target, is_text_model=True) |
| 301 | |
| 302 | # concat 4D to 5D |
| 303 | if "v.patch_embd.weight.1" in vsd: |
| 304 | w1 = dequantize_tensor(vsd.pop("v.patch_embd.weight"), dtype=torch.float32) |
| 305 | w2 = dequantize_tensor(vsd.pop("v.patch_embd.weight.1"), dtype=torch.float32) |
| 306 | vsd["v.patch_embd.weight"] = torch.stack([w1, w2], dim=2) |
| 307 | |
| 308 | # run main replacement |
| 309 | vsd = sd_map_replace(vsd, CLIP_VISION_SD_MAP) |
| 310 | |
| 311 | # handle split Q/K/V |
| 312 | if "visual.blocks.0.attn_q.weight" in vsd: |
| 313 | attns = {} |
| 314 | # filter out attentions + group |
| 315 | for k,v in vsd.items(): |
| 316 | if any(x in k for x in ["attn_q", "attn_k", "attn_v"]): |
| 317 | k_attn, k_name = k.rsplit(".attn_", 1) |
| 318 | k_attn += ".attn.qkv." + k_name.split(".")[-1] |
| 319 | if k_attn not in attns: |
| 320 | attns[k_attn] = {} |
| 321 | attns[k_attn][k_name] = dequantize_tensor( |
| 322 | v, dtype=(torch.bfloat16 if is_quantized(v) else torch.float16) |
| 323 | ) |
| 324 | |
| 325 | # recombine |
| 326 | for k,v in attns.items(): |
| 327 | suffix = k.split(".")[-1] |
| 328 | vsd[k] = torch.cat([ |
no test coverage detected