(self, instantid, insightface, control_net, image, model, positive, negative, start_at, end_at, weight=.8, ip_weight=None, cn_strength=None, noise=0.35, image_kps=None, mask=None, combine_embeds='average')
| 279 | CATEGORY = "InstantID" |
| 280 | |
| 281 | def apply_instantid(self, instantid, insightface, control_net, image, model, positive, negative, start_at, end_at, weight=.8, ip_weight=None, cn_strength=None, noise=0.35, image_kps=None, mask=None, combine_embeds='average'): |
| 282 | dtype = comfy.model_management.unet_dtype() |
| 283 | if dtype not in [torch.float32, torch.float16, torch.bfloat16]: |
| 284 | dtype = torch.float16 if comfy.model_management.should_use_fp16() else torch.float32 |
| 285 | |
| 286 | self.dtype = dtype |
| 287 | self.device = comfy.model_management.get_torch_device() |
| 288 | |
| 289 | ip_weight = weight if ip_weight is None else ip_weight |
| 290 | cn_strength = weight if cn_strength is None else cn_strength |
| 291 | |
| 292 | face_embed = extractFeatures(insightface, image) |
| 293 | if face_embed is None: |
| 294 | raise Exception('Reference Image: No face detected.') |
| 295 | |
| 296 | # if no keypoints image is provided, use the image itself (only the first one in the batch) |
| 297 | face_kps = extractFeatures(insightface, image_kps if image_kps is not None else image[0].unsqueeze(0), extract_kps=True) |
| 298 | |
| 299 | if face_kps is None: |
| 300 | face_kps = torch.zeros_like(image) if image_kps is None else image_kps |
| 301 | print(f"\033[33mWARNING: No face detected in the keypoints image!\033[0m") |
| 302 | |
| 303 | clip_embed = face_embed |
| 304 | # InstantID works better with averaged embeds (TODO: needs testing) |
| 305 | if clip_embed.shape[0] > 1: |
| 306 | if combine_embeds == 'average': |
| 307 | clip_embed = torch.mean(clip_embed, dim=0).unsqueeze(0) |
| 308 | elif combine_embeds == 'norm average': |
| 309 | clip_embed = torch.mean(clip_embed / torch.norm(clip_embed, dim=0, keepdim=True), dim=0).unsqueeze(0) |
| 310 | |
| 311 | if noise > 0: |
| 312 | seed = int(torch.sum(clip_embed).item()) % 1000000007 |
| 313 | torch.manual_seed(seed) |
| 314 | clip_embed_zeroed = noise * torch.rand_like(clip_embed) |
| 315 | #clip_embed_zeroed = add_noise(clip_embed, noise) |
| 316 | else: |
| 317 | clip_embed_zeroed = torch.zeros_like(clip_embed) |
| 318 | |
| 319 | # 1: patch the attention |
| 320 | self.instantid = instantid |
| 321 | self.instantid.to(self.device, dtype=self.dtype) |
| 322 | |
| 323 | image_prompt_embeds, uncond_image_prompt_embeds = self.instantid.get_image_embeds(clip_embed.to(self.device, dtype=self.dtype), clip_embed_zeroed.to(self.device, dtype=self.dtype)) |
| 324 | |
| 325 | image_prompt_embeds = image_prompt_embeds.to(self.device, dtype=self.dtype) |
| 326 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.to(self.device, dtype=self.dtype) |
| 327 | |
| 328 | work_model = model.clone() |
| 329 | |
| 330 | sigma_start = model.get_model_object("model_sampling").percent_to_sigma(start_at) |
| 331 | sigma_end = model.get_model_object("model_sampling").percent_to_sigma(end_at) |
| 332 | |
| 333 | if mask is not None: |
| 334 | mask = mask.to(self.device) |
| 335 | |
| 336 | patch_kwargs = { |
| 337 | "ipadapter": self.instantid, |
| 338 | "weight": ip_weight, |
nothing calls this directly
no test coverage detected