(B, LEN, modalities, device, inverted=True)
| 54 | |
| 55 | |
def omni_attn_mask_naive(B, LEN, modalities, device, inverted=True):
    """Build a causal mask that is bidirectional inside each modality span.

    Reference ("naive") implementation: start from a lower-triangular
    (causal) 0/1 mask and, for every ``(offset, length)`` span listed for a
    batch element, allow every position in that span to attend to every
    other position in the same span.

    Args:
        B: Batch size.
        LEN: Sequence length.
        modalities: Indexable by batch element; ``modalities[b]`` is read for
            every ``b`` in ``range(B)`` and must yield ``(offset, length)``
            pairs. Spans extending past ``LEN`` are clipped by slicing.
        device: Device on which the mask is allocated.
        inverted: If True, return a mask that is 0 where attention is allowed
            and ``torch.iinfo(torch.long).min`` where it is blocked; otherwise
            return the 0/1 mask (1 = allowed).

    Returns:
        A ``torch.long`` tensor of shape ``(B, 1, LEN, LEN)``.
    """
    # Allocate directly on the target device rather than building the mask on
    # the host and copying it over with ``.to(device)``.
    attention_mask = torch.tril(
        torch.ones((B, 1, LEN, LEN), dtype=torch.long, device=device)
    )
    for b in range(B):
        for offset, length in modalities[b]:
            span = slice(offset, offset + length)
            # Full (non-causal) attention within the span.
            attention_mask[b, :, span, span] = 1

    if not inverted:
        return attention_mask

    # 1 where attention is blocked, 0 where allowed; blocked entries are then
    # replaced with the most negative long value.
    # NOTE(review): presumably consumed as an additive mask on attention
    # scores -- confirm with callers, especially for fp16 score tensors.
    inverted_attention_mask = 1 - attention_mask
    return inverted_attention_mask.masked_fill(
        inverted_attention_mask.to(torch.bool), torch.iinfo(torch.long).min
    )
| 71 | |
| 72 | |
| 73 | def full_attn_mask_naive(B, LEN, device, inverted=True): |
no test coverage detected