MCPcopy
hub / github.com/cubiq/ComfyUI_InstantID / instantid_attention

Function instantid_attention

CrossAttentionPatch.py:30–190  ·  view source on GitHub ↗
(out, q, k, v, extra_options, module_key='', ipadapter=None, weight=1.0, cond=None, cond_alt=None, uncond=None, weight_type="linear", mask=None, sigma_start=0.0, sigma_end=1.0, unfold_batch=False, embeds_scaling='V only', **kwargs)

Source from the content-addressed store, hash-verified

28 return out.to(dtype=dtype)
29
30def instantid_attention(out, q, k, v, extra_options, module_key='', ipadapter=None, weight=1.0, cond=None, cond_alt=None, uncond=None, weight_type="linear", mask=None, sigma_start=0.0, sigma_end=1.0, unfold_batch=False, embeds_scaling='V only', **kwargs):
31 dtype = q.dtype
32 cond_or_uncond = extra_options["cond_or_uncond"]
33 block_type = extra_options["block"][0]
34 #block_id = extra_options["block"][1]
35 t_idx = extra_options["transformer_index"]
36 layers = 11 if '101_to_k_ip' in ipadapter.ip_layers.to_kvs else 16
37 k_key = module_key + "_to_k_ip"
38 v_key = module_key + "_to_v_ip"
39
40 # extra options for AnimateDiff
41 ad_params = extra_options['ad_params'] if "ad_params" in extra_options else None
42
43 b = q.shape[0]
44 seq_len = q.shape[1]
45 batch_prompt = b // len(cond_or_uncond)
46 _, _, oh, ow = extra_options["original_shape"]
47
48 if weight_type == 'ease in':
49 weight = weight * (0.05 + 0.95 * (1 - t_idx / layers))
50 elif weight_type == 'ease out':
51 weight = weight * (0.05 + 0.95 * (t_idx / layers))
52 elif weight_type == 'ease in-out':
53 weight = weight * (0.05 + 0.95 * (1 - abs(t_idx - (layers/2)) / (layers/2)))
54 elif weight_type == 'reverse in-out':
55 weight = weight * (0.05 + 0.95 * (abs(t_idx - (layers/2)) / (layers/2)))
56 elif weight_type == 'weak input' and block_type == 'input':
57 weight = weight * 0.2
58 elif weight_type == 'weak middle' and block_type == 'middle':
59 weight = weight * 0.2
60 elif weight_type == 'weak output' and block_type == 'output':
61 weight = weight * 0.2
62 elif weight_type == 'strong middle' and (block_type == 'input' or block_type == 'output'):
63 weight = weight * 0.2
64 elif isinstance(weight, dict):
65 if t_idx not in weight:
66 return 0
67
68 weight = weight[t_idx]
69
70 if cond_alt is not None and t_idx in cond_alt:
71 cond = cond_alt[t_idx]
72 del cond_alt
73
74 if unfold_batch:
75 # Check AnimateDiff context window
76 if ad_params is not None and ad_params["sub_idxs"] is not None:
77 if isinstance(weight, torch.Tensor):
78 weight = tensor_to_size(weight, ad_params["full_length"])
79 weight = torch.Tensor(weight[ad_params["sub_idxs"]])
80 if torch.all(weight == 0):
81 return 0
82 weight = weight.repeat(len(cond_or_uncond), 1, 1) # repeat for cond and uncond
83 elif weight == 0:
84 return 0
85
86 # if image length matches or exceeds full_length get sub_idx images
87 if cond.shape[0] >= ad_params["full_length"]:

Callers

nothing calls this directly

Calls 1

tensor_to_size · Function · 0.85

Tested by

no test coverage detected