MCPcopy
hub / github.com/OpenMOSS/MOSS / execute

Method execute

models_jittor/model.py:104–175  ·  view source on GitHub ↗
(
        self,
        hidden_states: Optional[jt.Var],
        attention_mask: Optional[jt.Var] = None,
        layer_past: Optional[Tuple[jt.Var]] = None,
        head_mask: Optional[jt.Var] = None,
        use_cache: Optional[bool] = False,
    )

Source from the content-addressed store, hash-verified

102 return attn_output, attn_weights
103
104 def execute(
105 self,
106 hidden_states: Optional[jt.Var],
107 attention_mask: Optional[jt.Var] = None,
108 layer_past: Optional[Tuple[jt.Var]] = None,
109 head_mask: Optional[jt.Var] = None,
110 use_cache: Optional[bool] = False,
111 ) -> Union[
112 Tuple[jt.Var, Tuple[jt.Var]],
113 Optional[Tuple[jt.Var, Tuple[jt.Var], Tuple[jt.Var, ...]]],
114 ]:
115 qkv = self.qkv_proj(hidden_states)
116 mp_num = 4
117 qkv_split = qkv.reshape(qkv.shape[:-1] + (mp_num, -1))
118
119 local_dim = self.head_dim * self.num_attention_heads // mp_num
120 query, value, key = jt.split(qkv_split, local_dim, dim=-1)
121 query = self._split_heads(query, self.num_attention_heads, self.head_dim, mp_num=mp_num)
122 key = self._split_heads(key, self.num_attention_heads, self.head_dim, mp_num=mp_num)
123
124 value = self._split_heads(value, self.num_attention_heads, self.head_dim, mp_num=mp_num)
125 value = value.permute(0, 2, 1, 3)
126
127 seq_len = key.shape[1]
128 offset = 0
129
130 if layer_past is not None:
131 offset = layer_past[0].shape[-2]
132 seq_len += offset
133
134 if self.rotary_dim is not None:
135 k_rot = key[:, :, :, : self.rotary_dim]
136 k_pass = key[:, :, :, self.rotary_dim :]
137
138 q_rot = query[:, :, :, : self.rotary_dim]
139 q_pass = query[:, :, :, self.rotary_dim :]
140
141 sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len)
142 k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset)
143 q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset)
144
145 key = jt.cat([k_rot, k_pass], dim=-1)
146 query = jt.cat([q_rot, q_pass], dim=-1)
147 else:
148 sincos = fixed_pos_embedding(key, 1, seq_len=seq_len)
149 key = apply_rotary_pos_emb(key, sincos, offset=offset)
150 query = apply_rotary_pos_emb(query, sincos, offset=offset)
151
152 key = key.permute(0, 2, 1, 3)
153 query = query.permute(0, 2, 1, 3)
154
155 if layer_past is not None:
156 past_key = layer_past[0]
157 past_value = layer_past[1]
158 key = jt.cat((past_key, key), dim=-2)
159 value = jt.cat((past_value, value), dim=-2)
160
161 if use_cache is True:

Callers

nothing calls this directly

Calls 5

_split_heads · Method · 0.95
_attn · Method · 0.95
_merge_heads · Method · 0.95
fixed_pos_embedding · Function · 0.85
apply_rotary_pos_emb · Function · 0.70

Tested by

no test coverage detected