MCPcopy Index your code
hub / github.com/Lightricks/ComfyUI-LTXVideo / PatchAttention

Class PatchAttention

stg.py:123–166  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

121
# context manager that replaces the attention function in a transformer block
class PatchAttention(contextlib.AbstractContextManager):
    """Temporarily replace ComfyUI's attention functions with skipping wrappers.

    While the context is active, ``comfy.ldm.modules.attention.optimized_attention``
    and ``comfy.ldm.modules.attention.optimized_attention_masked`` are swapped for
    wrappers that count every attention call. The masked and unmasked paths share a
    single counter. A call whose 0-based running index is listed in ``attn_idx``
    does not run attention and returns ``v`` unchanged. Every other call is forwarded
    to the original function with its arguments untouched. On exit the original
    functions are put back; because this is a context manager, that also happens
    when the ``with`` body raises.

    Args:
        attn_idx: Index, or iterable of indices, of the attention calls to skip.
            ``None`` is treated as ``[0]``, i.e. only the first call is skipped.
    """

    def __init__(self, attn_idx: Optional[Union[int, List[int]]] = None):
        # Running index of the most recent attention call. The wrappers increment
        # before checking, so -1 means "no call seen yet".
        self.current_idx = -1

        if isinstance(attn_idx, int):
            self.attn_idx = [attn_idx]
        elif attn_idx is None:
            self.attn_idx = [0]
        else:
            self.attn_idx = list(attn_idx)

    def __enter__(self):
        """Install the counting wrappers and return this instance."""
        # Restart the count on every entry, so a reused instance skips the same
        # calls each time instead of carrying over the previous run's count.
        self.current_idx = -1

        self.original_attention = comfy.ldm.modules.attention.optimized_attention
        self.original_attention_masked = (
            comfy.ldm.modules.attention.optimized_attention_masked
        )

        comfy.ldm.modules.attention.optimized_attention = self.stg_attention
        comfy.ldm.modules.attention.optimized_attention_masked = (
            self.stg_attention_masked
        )
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Restore the original attention functions; exceptions are not suppressed."""
        comfy.ldm.modules.attention.optimized_attention = self.original_attention
        comfy.ldm.modules.attention.optimized_attention_masked = (
            self.original_attention_masked
        )

        # Drop this object's references to the originals once they are restored.
        self.original_attention = None
        self.original_attention_masked = None
        # Implicitly returns None (falsy), so any exception from the body propagates.

    def _skip_or_call(self, original, q, k, v, heads, *args, **kwargs):
        """Count one attention call; return ``v`` if it is selected, else delegate to ``original``."""
        self.current_idx += 1
        if self.current_idx in self.attn_idx:
            # NOTE(review): returning ``v`` as the attention output assumes the call
            # site expects an output laid out like ``v`` (e.g. no ``skip_reshape``
            # style layout change is requested) — confirm for every caller.
            return v
        return original(q, k, v, heads, *args, **kwargs)

    def stg_attention(self, q, k, v, heads, *args, **kwargs):
        """Stand-in for ``optimized_attention`` while the context is active."""
        return self._skip_or_call(
            self.original_attention, q, k, v, heads, *args, **kwargs
        )

    def stg_attention_masked(self, q, k, v, heads, *args, **kwargs):
        """Stand-in for ``optimized_attention_masked`` while the context is active."""
        return self._skip_or_call(
            self.original_attention_masked, q, k, v, heads, *args, **kwargs
        )
167
168
169class STGBlockWrapper:

Callers 1

__call__Method · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected