()
# Cache for the "uncond" side: one two-item [keys, values] entry per slot,
# the same entry layout clear_cache() unpacks for kv_cond and kv_uncond.
# Each slot gets its own freshly built pair of lists, so no two slots alias.
kv_uncond = [[[], []] for _slot in range(attn_counter)]
| 702 | |
def clear_cache():
    """Empty every cached key/value list in ``kv_cond`` and ``kv_uncond``.

    Each storage is a sequence of two-item ``[keys, values]`` entries (the
    layout used when ``kv_uncond`` is built). The lists are cleared in place
    rather than rebound, so the entry structure is kept and any reference to
    an individual list elsewhere sees it emptied too.
    """
    for storage in (kv_cond, kv_uncond):
        for keys, values in storage:
            keys.clear()
            values.clear()
| 708 | |
# Branch count: one per entry in `conditions`, plus two more
# (presumably fixed non-condition branches — TODO confirm against the caller).
branch_n = 2 + len(conditions)
# Square branch_n x branch_n boolean mask, every entry initialised to True.
# NOTE(review): presumably entry (i, j) says whether branch i may attend to
# branch j — confirm against the code that consumes group_mask.
group_mask = torch.ones(branch_n, branch_n, dtype=torch.bool)