MCPcopy
hub / github.com/lm-sys/FastChat / forward

Function forward

fastchat/model/monkey_patch_non_inplace.py:30–114  ·  view source on GitHub ↗
(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    padding_mask: Optional[torch.LongTensor] = None,
)

Source from the content-addressed store, hash-verified

28
29
30def forward(
31 self,
32 hidden_states: torch.Tensor,
33 attention_mask: Optional[torch.Tensor] = None,
34 position_ids: Optional[torch.LongTensor] = None,
35 past_key_value: Optional[Tuple[torch.Tensor]] = None,
36 output_attentions: bool = False,
37 use_cache: bool = False,
38 padding_mask: Optional[torch.LongTensor] = None,
39) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
40 bsz, q_len, _ = hidden_states.size()
41
42 query_states = (
43 self.q_proj(hidden_states)
44 .view(bsz, q_len, self.num_heads, self.head_dim)
45 .transpose(1, 2)
46 )
47 key_states = (
48 self.k_proj(hidden_states)
49 .view(bsz, q_len, self.num_heads, self.head_dim)
50 .transpose(1, 2)
51 )
52 value_states = (
53 self.v_proj(hidden_states)
54 .view(bsz, q_len, self.num_heads, self.head_dim)
55 .transpose(1, 2)
56 )
57
58 kv_seq_len = key_states.shape[-2]
59 if past_key_value is not None:
60 kv_seq_len += past_key_value[0].shape[-2]
61 cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
62 query_states, key_states = apply_rotary_pos_emb(
63 query_states, key_states, cos, sin, position_ids
64 )
65 # [bsz, nh, t, hd]
66
67 if past_key_value is not None:
68 # reuse k, v, self_attention
69 key_states = torch.cat([past_key_value[0], key_states], dim=2)
70 value_states = torch.cat([past_key_value[1], value_states], dim=2)
71
72 past_key_value = (key_states, value_states) if use_cache else None
73
74 attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
75 self.head_dim
76 )
77
78 if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
79 raise ValueError(
80 f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
81 f" {attn_weights.size()}"
82 )
83
84 if attention_mask is not None:
85 if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
86 raise ValueError(
87 f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"

Callers

nothing calls this directly

Calls 2

to (Method) · 0.80
apply_rotary_pos_emb (Function) · 0.70

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…