Args: mask: (max_seq_len, max_seq_len); input_pos: (batch_size, seq_len). Returns: (batch_size, seq_len, max_seq_len).
(mask: torch.Tensor, input_pos: torch.Tensor)
| 57 | |
| 58 | |
def _index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor):
    """Select the rows of ``mask`` addressed by ``input_pos``.

    Each entry of ``input_pos`` is used as a row index into ``mask``; the
    whole row (all columns) is returned for that entry.

    Args:
        mask: (max_seq_len, max_seq_len)
        input_pos: (batch_size, seq_len)

    Returns:
        (batch_size, seq_len, max_seq_len)
    """
    # Index the first axis with the position tensor and keep every column.
    return mask[input_pos, :]
| 70 | |
| 71 | |
| 72 | def _multinomial_sample_one_no_sync(probs): # Does multinomial sampling without a cuda synchronization |