(seq_len: int, device: torch.device)
| 53 | |
| 54 | |
def _create_causal_mask(seq_len: int, device: torch.device):
    """Build a square boolean causal mask of side ``seq_len``.

    Entry ``[row, col]`` is True when ``col <= row`` (on or below the main
    diagonal) and False strictly above it.

    Args:
        seq_len: Side length of the square mask.
        device: Device on which the mask tensor is allocated.

    Returns:
        A ``torch.bool`` tensor of shape ``(seq_len, seq_len)``.
    """
    positions = torch.arange(seq_len, device=device)
    # Broadcast a column of row indices against a row of column indices;
    # the comparison yields a bool tensor that is True where row >= col.
    return positions.unsqueeze(1) >= positions.unsqueeze(0)
| 57 | |
| 58 | |
| 59 | def _index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor): |