(self, instantid, insightface, image, model, weight, start_at, end_at, noise=0.0, mask=None)
| 446 | CATEGORY = "InstantID" |
| 447 | |
| 448 | def patch_attention(self, instantid, insightface, image, model, weight, start_at, end_at, noise=0.0, mask=None): |
| 449 | self.dtype = torch.float16 if comfy.model_management.should_use_fp16() else torch.float32 |
| 450 | self.device = comfy.model_management.get_torch_device() |
| 451 | |
| 452 | face_embed = extractFeatures(insightface, image) |
| 453 | if face_embed is None: |
| 454 | raise Exception('Reference Image: No face detected.') |
| 455 | |
| 456 | clip_embed = face_embed |
| 457 | # InstantID works better with averaged embeds (TODO: needs testing) |
| 458 | if clip_embed.shape[0] > 1: |
| 459 | clip_embed = torch.mean(clip_embed, dim=0).unsqueeze(0) |
| 460 | |
| 461 | if noise > 0: |
| 462 | seed = int(torch.sum(clip_embed).item()) % 1000000007 |
| 463 | torch.manual_seed(seed) |
| 464 | clip_embed_zeroed = noise * torch.rand_like(clip_embed) |
| 465 | else: |
| 466 | clip_embed_zeroed = torch.zeros_like(clip_embed) |
| 467 | |
| 468 | # 1: patch the attention |
| 469 | self.instantid = instantid |
| 470 | self.instantid.to(self.device, dtype=self.dtype) |
| 471 | |
| 472 | image_prompt_embeds, uncond_image_prompt_embeds = self.instantid.get_image_embeds(clip_embed.to(self.device, dtype=self.dtype), clip_embed_zeroed.to(self.device, dtype=self.dtype)) |
| 473 | |
| 474 | image_prompt_embeds = image_prompt_embeds.to(self.device, dtype=self.dtype) |
| 475 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.to(self.device, dtype=self.dtype) |
| 476 | |
| 477 | if weight == 0: |
| 478 | return (model, { "cond": image_prompt_embeds, "uncond": uncond_image_prompt_embeds } ) |
| 479 | |
| 480 | work_model = model.clone() |
| 481 | |
| 482 | sigma_start = model.get_model_object("model_sampling").percent_to_sigma(start_at) |
| 483 | sigma_end = model.get_model_object("model_sampling").percent_to_sigma(end_at) |
| 484 | |
| 485 | if mask is not None: |
| 486 | mask = mask.to(self.device) |
| 487 | |
| 488 | patch_kwargs = { |
| 489 | "weight": weight, |
| 490 | "ipadapter": self.instantid, |
| 491 | "cond": image_prompt_embeds, |
| 492 | "uncond": uncond_image_prompt_embeds, |
| 493 | "mask": mask, |
| 494 | "sigma_start": sigma_start, |
| 495 | "sigma_end": sigma_end, |
| 496 | } |
| 497 | |
| 498 | number = 0 |
| 499 | for id in [4,5,7,8]: # id of input_blocks that have cross attention |
| 500 | block_indices = range(2) if id in [4, 5] else range(10) # transformer_depth |
| 501 | for index in block_indices: |
| 502 | patch_kwargs["module_key"] = str(number*2+1) |
| 503 | _set_model_patch_replace(work_model, patch_kwargs, ("input", id, index)) |
| 504 | number += 1 |
| 505 | for id in range(6): # id of output_blocks that have cross attention |
nothing calls this directly
no test coverage detected