Wraps transformer blocks so that their attention layers can be skipped.
| 167 | |
| 168 | |
class STGBlockWrapper:
    """Wraps transformer blocks to be able to skip attention layers.

    An instance is a callable taking ``(args, extra_args)``. It always
    delegates to ``extra_args["original_block"]``. The call runs inside
    ``PatchAttention(stg_indexes)`` only when the shared flag's ``do_skip``
    is truthy and this wrapper's ``idx`` appears in ``skip_layers``.
    Otherwise the call runs inside ``contextlib.nullcontext()``.
    """

    def __init__(self, block, stg_flag: STGFlag, idx: int):
        # Shared flag object; do_skip / skip_layers are re-read on every call.
        self.flag = stg_flag
        # This wrapper's layer index, checked against flag.skip_layers.
        self.idx = idx
        # Stored on the instance; __call__ does not use it and calls
        # extra_args["original_block"] instead.
        self.block = block

    def __call__(self, args, extra_args):
        """Run the original block, patching attention if this layer is skipped.

        Args:
            args: block arguments. Must contain a "transformer_options"
                mapping, whose optional "stg_indexes" entry defaults to [0].
            extra_args: must provide the callable under "original_block".

        Returns:
            The value returned by ``extra_args["original_block"](args)``.
        """
        # Read unconditionally, so a missing "transformer_options" key raises
        # KeyError whether or not this layer is being skipped.
        options = args["transformer_options"]
        indexes = options.get("stg_indexes", [0])

        # Only the truthiness of `skipping` is used below.
        skipping = self.flag.do_skip and self.idx in self.flag.skip_layers
        guard = PatchAttention(indexes) if skipping else contextlib.nullcontext()

        with guard:
            output = extra_args["original_block"](args)
        return output
| 187 | |
| 188 | |
| 189 | class STGGuider(comfy.samplers.CFGGuider): |
no outgoing calls
no test coverage detected