(out, q, k, v, extra_options, module_key='', ipadapter=None, weight=1.0, cond=None, cond_alt=None, uncond=None, weight_type="linear", mask=None, sigma_start=0.0, sigma_end=1.0, unfold_batch=False, embeds_scaling='V only', **kwargs)
| 28 | return out.to(dtype=dtype) |
| 29 | |
| 30 | def instantid_attention(out, q, k, v, extra_options, module_key='', ipadapter=None, weight=1.0, cond=None, cond_alt=None, uncond=None, weight_type="linear", mask=None, sigma_start=0.0, sigma_end=1.0, unfold_batch=False, embeds_scaling='V only', **kwargs): |
| 31 | dtype = q.dtype |
| 32 | cond_or_uncond = extra_options["cond_or_uncond"] |
| 33 | block_type = extra_options["block"][0] |
| 34 | #block_id = extra_options["block"][1] |
| 35 | t_idx = extra_options["transformer_index"] |
| 36 | layers = 11 if '101_to_k_ip' in ipadapter.ip_layers.to_kvs else 16 |
| 37 | k_key = module_key + "_to_k_ip" |
| 38 | v_key = module_key + "_to_v_ip" |
| 39 | |
| 40 | # extra options for AnimateDiff |
| 41 | ad_params = extra_options['ad_params'] if "ad_params" in extra_options else None |
| 42 | |
| 43 | b = q.shape[0] |
| 44 | seq_len = q.shape[1] |
| 45 | batch_prompt = b // len(cond_or_uncond) |
| 46 | _, _, oh, ow = extra_options["original_shape"] |
| 47 | |
| 48 | if weight_type == 'ease in': |
| 49 | weight = weight * (0.05 + 0.95 * (1 - t_idx / layers)) |
| 50 | elif weight_type == 'ease out': |
| 51 | weight = weight * (0.05 + 0.95 * (t_idx / layers)) |
| 52 | elif weight_type == 'ease in-out': |
| 53 | weight = weight * (0.05 + 0.95 * (1 - abs(t_idx - (layers/2)) / (layers/2))) |
| 54 | elif weight_type == 'reverse in-out': |
| 55 | weight = weight * (0.05 + 0.95 * (abs(t_idx - (layers/2)) / (layers/2))) |
| 56 | elif weight_type == 'weak input' and block_type == 'input': |
| 57 | weight = weight * 0.2 |
| 58 | elif weight_type == 'weak middle' and block_type == 'middle': |
| 59 | weight = weight * 0.2 |
| 60 | elif weight_type == 'weak output' and block_type == 'output': |
| 61 | weight = weight * 0.2 |
| 62 | elif weight_type == 'strong middle' and (block_type == 'input' or block_type == 'output'): |
| 63 | weight = weight * 0.2 |
| 64 | elif isinstance(weight, dict): |
| 65 | if t_idx not in weight: |
| 66 | return 0 |
| 67 | |
| 68 | weight = weight[t_idx] |
| 69 | |
| 70 | if cond_alt is not None and t_idx in cond_alt: |
| 71 | cond = cond_alt[t_idx] |
| 72 | del cond_alt |
| 73 | |
| 74 | if unfold_batch: |
| 75 | # Check AnimateDiff context window |
| 76 | if ad_params is not None and ad_params["sub_idxs"] is not None: |
| 77 | if isinstance(weight, torch.Tensor): |
| 78 | weight = tensor_to_size(weight, ad_params["full_length"]) |
| 79 | weight = torch.Tensor(weight[ad_params["sub_idxs"]]) |
| 80 | if torch.all(weight == 0): |
| 81 | return 0 |
| 82 | weight = weight.repeat(len(cond_or_uncond), 1, 1) # repeat for cond and uncond |
| 83 | elif weight == 0: |
| 84 | return 0 |
| 85 | |
| 86 | # if image length matches or exceeds full_length get sub_idx images |
| 87 | if cond.shape[0] >= ad_params["full_length"]: |
nothing calls this directly
no test coverage detected