This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
(hidden_states: torch.Tensor, n_rep: int)
| 264 | |
| 265 | |
| 266 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: |
| 267 | """ |
| 268 | This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, |
| 269 | num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) |
| 270 | """ |
| 271 | batch, num_key_value_heads, slen, head_dim = hidden_states.shape |
| 272 | if n_rep == 1: |
| 273 | return hidden_states |
| 274 | hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) |
| 275 | return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) |
| 276 | |
| 277 | |
| 278 | class LlamaAttention(nn.Module): |