MCPcopy
hub / github.com/cubiq/ComfyUI_InstantID / patch_attention

Method patch_attention

InstantID.py:448–516  ·  view source on GitHub ↗
(self, instantid, insightface, image, model, weight, start_at, end_at, noise=0.0, mask=None)

Source from the content-addressed store, hash-verified

446 CATEGORY = "InstantID"
447
448 def patch_attention(self, instantid, insightface, image, model, weight, start_at, end_at, noise=0.0, mask=None):
449 self.dtype = torch.float16 if comfy.model_management.should_use_fp16() else torch.float32
450 self.device = comfy.model_management.get_torch_device()
451
452 face_embed = extractFeatures(insightface, image)
453 if face_embed is None:
454 raise Exception('Reference Image: No face detected.')
455
456 clip_embed = face_embed
457 # InstantID works better with averaged embeds (TODO: needs testing)
458 if clip_embed.shape[0] > 1:
459 clip_embed = torch.mean(clip_embed, dim=0).unsqueeze(0)
460
461 if noise > 0:
462 seed = int(torch.sum(clip_embed).item()) % 1000000007
463 torch.manual_seed(seed)
464 clip_embed_zeroed = noise * torch.rand_like(clip_embed)
465 else:
466 clip_embed_zeroed = torch.zeros_like(clip_embed)
467
468 # 1: patch the attention
469 self.instantid = instantid
470 self.instantid.to(self.device, dtype=self.dtype)
471
472 image_prompt_embeds, uncond_image_prompt_embeds = self.instantid.get_image_embeds(clip_embed.to(self.device, dtype=self.dtype), clip_embed_zeroed.to(self.device, dtype=self.dtype))
473
474 image_prompt_embeds = image_prompt_embeds.to(self.device, dtype=self.dtype)
475 uncond_image_prompt_embeds = uncond_image_prompt_embeds.to(self.device, dtype=self.dtype)
476
477 if weight == 0:
478 return (model, { "cond": image_prompt_embeds, "uncond": uncond_image_prompt_embeds } )
479
480 work_model = model.clone()
481
482 sigma_start = model.get_model_object("model_sampling").percent_to_sigma(start_at)
483 sigma_end = model.get_model_object("model_sampling").percent_to_sigma(end_at)
484
485 if mask is not None:
486 mask = mask.to(self.device)
487
488 patch_kwargs = {
489 "weight": weight,
490 "ipadapter": self.instantid,
491 "cond": image_prompt_embeds,
492 "uncond": uncond_image_prompt_embeds,
493 "mask": mask,
494 "sigma_start": sigma_start,
495 "sigma_end": sigma_end,
496 }
497
498 number = 0
499 for id in [4,5,7,8]: # id of input_blocks that have cross attention
500 block_indices = range(2) if id in [4, 5] else range(10) # transformer_depth
501 for index in block_indices:
502 patch_kwargs["module_key"] = str(number*2+1)
503 _set_model_patch_replace(work_model, patch_kwargs, ("input", id, index))
504 number += 1
505 for id in range(6): # id of output_blocks that have cross attention

Callers

Nothing calls this directly.

Calls (3)

extractFeatures (function) · 0.85
_set_model_patch_replace (function) · 0.85
get_image_embeds (method) · 0.80

Tested by

No test coverage detected.