| 121 | |
| 122 | # context manager that replaces the attention function in a transformer block |
class PatchAttention(contextlib.AbstractContextManager):
    """Temporarily swap the attention functions in ``comfy.ldm.modules.attention``.

    While the context is active, ``optimized_attention`` and
    ``optimized_attention_masked`` are replaced by wrappers that count every
    attention call. When the running call index is listed in ``attn_idx`` the
    wrapper returns ``v`` unchanged instead of computing attention; otherwise
    it forwards the call to the original function.

    Both wrappers share a single counter, so masked and unmasked calls are
    numbered together. The counter is initialised once in ``__init__`` and is
    not reset on entry or exit.

    Args:
        attn_idx: Index, or iterable of indices, of the attention calls to
            skip. ``None`` selects the first call (``[0]``).
    """

    def __init__(self, attn_idx: Optional[Union[int, List[int]]] = None):
        # Incremented before every check, so the first patched call is index 0.
        self.current_idx = -1

        # Populated in __enter__ and cleared again in __exit__.
        self.original_attention = None
        self.original_attention_masked = None

        if attn_idx is None:
            self.attn_idx = [0]
        elif isinstance(attn_idx, int):
            self.attn_idx = [attn_idx]
        else:
            self.attn_idx = list(attn_idx)

    def __enter__(self):
        """Install the patched attention functions and return this manager."""
        attention_module = comfy.ldm.modules.attention

        # Remember the current functions so __exit__ can put them back.
        self.original_attention = attention_module.optimized_attention
        self.original_attention_masked = attention_module.optimized_attention_masked

        attention_module.optimized_attention = self.stg_attention
        attention_module.optimized_attention_masked = self.stg_attention_masked

        # Returning self keeps ``with PatchAttention(...) as patch:`` usable,
        # matching the default AbstractContextManager.__enter__ behaviour.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Restore the original attention functions.

        Returns ``None``, so an exception raised inside the ``with`` body is
        not suppressed.
        """
        attention_module = comfy.ldm.modules.attention
        attention_module.optimized_attention = self.original_attention
        attention_module.optimized_attention_masked = self.original_attention_masked

        self.original_attention = None
        self.original_attention_masked = None

    def _skip_this_call(self) -> bool:
        """Advance the shared call counter and report whether to skip this call."""
        self.current_idx += 1
        return self.current_idx in self.attn_idx

    def stg_attention(self, q, k, v, heads, *args, **kwargs):
        """Stand-in for ``optimized_attention``: return ``v`` as-is on skipped calls."""
        if self._skip_this_call():
            return v
        return self.original_attention(q, k, v, heads, *args, **kwargs)

    def stg_attention_masked(self, q, k, v, heads, *args, **kwargs):
        """Stand-in for ``optimized_attention_masked``: return ``v`` as-is on skipped calls."""
        if self._skip_this_call():
            return v
        return self.original_attention_masked(q, k, v, heads, *args, **kwargs)
| 167 | |
| 168 | |
| 169 | class STGBlockWrapper: |