MCP · copy
hub / github.com/showlab/Show-o / forward

Method forward

models/phi.py:302–397  ·  view source on GitHub ↗
(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    )

Source retrieved from the content-addressed store (hash-verified). Note: the listing below stops at line 359, although the method spans lines 302–397.

300 raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
301
302 def forward(
303 self,
304 hidden_states: torch.Tensor,
305 attention_mask: Optional[torch.Tensor] = None,
306 position_ids: Optional[torch.LongTensor] = None,
307 past_key_value: Optional[Cache] = None,
308 output_attentions: bool = False,
309 use_cache: bool = False,
310 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
311 bsz, q_len, _ = hidden_states.size()
312
313 query_states = self.q_proj(hidden_states)
314 key_states = self.k_proj(hidden_states)
315 value_states = self.v_proj(hidden_states)
316
317 query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
318 key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
319 value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
320
321 if self.qk_layernorm:
322 query_states = self.q_layernorm(query_states)
323 key_states = self.k_layernorm(key_states)
324
325 kv_seq_len = key_states.shape[-2]
326 if past_key_value is not None:
327 if self.layer_idx is None:
328 raise ValueError(
329 f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
330 "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
331 "with a layer index."
332 )
333 kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
334 cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
335
336 # Partial rotary embedding
337 query_rot, query_pass = (
338 query_states[..., : self.rotary_emb.dim],
339 query_states[..., self.rotary_emb.dim :],
340 )
341 key_rot, key_pass = (
342 key_states[..., : self.rotary_emb.dim],
343 key_states[..., self.rotary_emb.dim :],
344 )
345 # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
346 query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
347
348 # [batch_size, seq_length, num_heads, head_dim]
349 query_states = torch.cat((query_rot, query_pass), dim=-1)
350 key_states = torch.cat((key_rot, key_pass), dim=-1)
351
352 if past_key_value is not None:
353 cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
354 key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
355
356 key_states = repeat_kv(key_states, self.num_key_value_groups)
357 value_states = repeat_kv(value_states, self.num_key_value_groups)
358
359 # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow

Callers

nothing calls this directly

Calls (4)

to — Method · 0.80
apply_rotary_pos_emb — Function · 0.70
repeat_kv — Function · 0.70
update — Method · 0.45

Tested by

no test coverage detected