MCPcopy
hub / github.com/THUDM/LongWriter / forward

Method forward

train/patch/modeling_llama.py:537–625  ·  view source on GitHub ↗
(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
        **kwargs,
    )

Source from the content-addressed store, hash-verified

535
536 # Adapted from LlamaAttention.forward
537 def forward(
538 self,
539 hidden_states: torch.Tensor,
540 attention_mask: Optional[torch.Tensor] = None,
541 position_ids: Optional[torch.LongTensor] = None,
542 past_key_value: Optional[Cache] = None,
543 output_attentions: bool = False,
544 use_cache: bool = False,
545 cache_position: Optional[torch.LongTensor] = None,
546 position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
547 **kwargs,
548 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
549 if output_attentions:
550 # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
551 logger.warning_once(
552 "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
553 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
554 )
555 return super().forward(
556 hidden_states=hidden_states,
557 attention_mask=attention_mask,
558 position_ids=position_ids,
559 past_key_value=past_key_value,
560 output_attentions=output_attentions,
561 use_cache=use_cache,
562 cache_position=cache_position,
563 position_embeddings=position_embeddings,
564 )
565
566 bsz, q_len, _ = hidden_states.size()
567
568 query_states = self.q_proj(hidden_states)
569 key_states = self.k_proj(hidden_states)
570 value_states = self.v_proj(hidden_states)
571
572 query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
573 key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
574 value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
575
576 if position_embeddings is None:
577 logger.warning_once(
578 "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
579 "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
580 "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
581 "removed and `position_embeddings` will be mandatory."
582 )
583 cos, sin = self.rotary_emb(value_states, position_ids)
584 else:
585 cos, sin = position_embeddings
586 query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
587
588 if past_key_value is not None:
589 # sin and cos are specific to RoPE models; cache_position needed for the static cache
590 cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
591 key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
592
593 key_states = repeat_kv(key_states, self.num_key_value_groups)
594 value_states = repeat_kv(value_states, self.num_key_value_groups)

Callers

nothing calls this directly

Calls 3

repeat_kvFunction · 0.85
apply_rotary_pos_embFunction · 0.70
forwardMethod · 0.45

Tested by

no test coverage detected