MCPcopy
hub / github.com/THUDM/LongWriter / forward

Method forward

train/patch/modeling_llama.py:425–526  ·  view source on GitHub ↗
(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
    )

Source from the content-addressed store, hash-verified

423 self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
424
425 def forward(
426 self,
427 hidden_states: torch.Tensor,
428 attention_mask: Optional[torch.LongTensor] = None,
429 position_ids: Optional[torch.LongTensor] = None,
430 past_key_value: Optional[Cache] = None,
431 output_attentions: bool = False,
432 use_cache: bool = False,
433 cache_position: Optional[torch.LongTensor] = None,
434 position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
435 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
436 if isinstance(past_key_value, StaticCache):
437 raise ValueError(
438 "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
439 "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
440 )
441
442 output_attentions = False
443
444 bsz, q_len, _ = hidden_states.size()
445
446 query_states = self.q_proj(hidden_states)
447 key_states = self.k_proj(hidden_states)
448 value_states = self.v_proj(hidden_states)
449
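450 # Flash attention requires the input to have the shape
451 # batch_size x seq_length x num_heads x head_dim.
452 # The projections are first viewed in that shape and transposed to batch_size x num_heads x seq_length x head_dim for the rotary embedding and cache update below; they are transposed back before the attention call.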
453 query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
454 key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
455 value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
456
457 if position_embeddings is None:
458 logger.warning_once(
459 "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
460 "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
461 "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
462 "removed and `position_embeddings` will be mandatory."
463 )
464 cos, sin = self.rotary_emb(value_states, position_ids)
465 else:
466 cos, sin = position_embeddings
467 query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
468
469 if past_key_value is not None:
470 # sin and cos are specific to RoPE models; cache_position needed for the static cache
471 cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
472 key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
473
474 # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
475 # to be able to avoid many of these transpose/reshape/view.
476 query_states = query_states.transpose(1, 2)
477 key_states = key_states.transpose(1, 2)
478 value_states = value_states.transpose(1, 2)
479
480 dropout_rate = self.attention_dropout if self.training else 0.0
481
482 # In PEFT, usually we cast the layer norms in float32 for training stability reasons

Callers

nothing calls this directly

Calls 1

apply_rotary_pos_emb — Function · 0.70

Tested by

no test coverage detected