MCPcopy
hub / github.com/THUDM/LongWriter / repeat_kv

Function repeat_kv

train/patch/modeling_llama.py:266–275  ·  view source on GitHub ↗

This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)

(hidden_states: torch.Tensor, n_rep: int)

Source from the content-addressed store, hash-verified

264
265
266def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
267 """
268 This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
269 num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
270 """
271 batch, num_key_value_heads, slen, head_dim = hidden_states.shape
272 if n_rep == 1:
273 return hidden_states
274 hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
275 return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
276
277
278class LlamaAttention(nn.Module):

Callers 3

forwardMethod · 0.85
forwardMethod · 0.85
forwardMethod · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected