(self, ipadapter, image, weight, mask=None, clip_vision=None)
CATEGORY = "ipadapter/embeds"

def encode(self, ipadapter, image, weight, mask=None, clip_vision=None):
    """Encode an image into conditional / unconditional IPAdapter embeds.

    Args:
        ipadapter: Either a pipeline dict containing ``'ipadapter'`` (whose
            ``['model']`` is the IPAdapter model) and optionally
            ``'clipvision'`` (whose ``['model']`` is a CLIP Vision model), or
            the IPAdapter model dict itself, indexed with ``"image_proj"``.
        image: Image batch handed to ``encode_image_masked``.
        weight: Multiplier applied to the conditional embeds only; skipped
            when it equals 1.
        mask: Optional mask. When ``mask.shape[1:3]`` differs from the CLIP
            Vision input size it is center-cropped to a square and resized.
        clip_vision: Optional CLIP Vision model. When given it takes
            precedence over one bundled in a pipeline dict.

    Returns:
        ``(img_cond_embeds, img_uncond_embeds)``.

    Raises:
        Exception: "Missing CLIPVision model." when no CLIP Vision model is
            passed and none is available from ``ipadapter``.
    """
    if 'ipadapter' in ipadapter:
        ipadapter_model = ipadapter['ipadapter']['model']
        if clip_vision is None:
            # Fall back to the bundled model. A missing/empty 'clipvision'
            # entry yields None and reaches the explicit error below instead
            # of surfacing as an opaque KeyError.
            clip_vision = (ipadapter.get('clipvision') or {}).get('model')
    else:
        ipadapter_model = ipadapter

    if clip_vision is None:
        raise Exception("Missing CLIPVision model.")

    image_proj = ipadapter_model["image_proj"]

    # Any of these keys marks a "plus" adapter; those are fed the penultimate
    # CLIP Vision hidden states instead of the pooled image embeds (see below).
    is_plus = (
        "proj.3.weight" in image_proj
        or "latents" in image_proj
        or "perceiver_resampler.proj_in.weight" in image_proj
    )
    # A plus adapter whose to_out weight has shape[0] == 2048 is treated as
    # Kwai Kolors, which uses a 336px CLIP Vision input instead of 224px.
    is_kwai_kolors = (
        is_plus
        and "layers.0.0.to_out.weight" in image_proj
        and image_proj["layers.0.0.to_out.weight"].shape[0] == 2048
    )

    clipvision_size = 336 if is_kwai_kolors else 224

    # Center-crop the mask to a square and resize it to the CLIP Vision input
    # size (224 or 336) unless it already matches.
    # NOTE(review): assumes mask is (B, H, W) — confirm against callers.
    if mask is not None and mask.shape[1:3] != torch.Size([clipvision_size, clipvision_size]):
        mask = mask.unsqueeze(1)
        transforms = T.Compose([
            T.CenterCrop(min(mask.shape[2], mask.shape[3])),
            T.Resize((clipvision_size, clipvision_size), interpolation=T.InterpolationMode.BICUBIC, antialias=True),
        ])
        mask = transforms(mask).squeeze(1)

    img_cond_embeds = encode_image_masked(clip_vision, image, mask, clipvision_size=clipvision_size)

    if is_plus:
        img_cond_embeds = img_cond_embeds.penultimate_hidden_states
        # Plus adapters use the encoding of an all-zero image as "unconditional".
        img_uncond_embeds = encode_image_masked(clip_vision, torch.zeros([1, clipvision_size, clipvision_size, 3]), clipvision_size=clipvision_size).penultimate_hidden_states
    else:
        img_cond_embeds = img_cond_embeds.image_embeds
        img_uncond_embeds = torch.zeros_like(img_cond_embeds)

    if weight != 1:
        img_cond_embeds = img_cond_embeds * weight

    return (img_cond_embeds, img_uncond_embeds, )
| 1426 | |
| 1427 | class IPAdapterCombineEmbeds: |
| 1428 | @classmethod |
nothing calls this directly
no test coverage detected