MCPcopy
hub / github.com/THUDM/LongWriter / forward

Method forward

train/patch/modeling_llama.py:316–407  ·  view source on GitHub ↗
(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
        **kwargs,
    )

Source from the content-addressed store, hash-verified

314 self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
315
316 def forward(
317 self,
318 hidden_states: torch.Tensor,
319 attention_mask: Optional[torch.Tensor] = None,
320 position_ids: Optional[torch.LongTensor] = None,
321 past_key_value: Optional[Cache] = None,
322 output_attentions: bool = False,
323 use_cache: bool = False,
324 cache_position: Optional[torch.LongTensor] = None,
325 position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
326 **kwargs,
327 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
328 bsz, q_len, _ = hidden_states.size()
329
330 if self.config.pretraining_tp > 1:
331 key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
332 query_slices = self.q_proj.weight.split(
333 (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
334 )
335 key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
336 value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
337
338 query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
339 query_states = torch.cat(query_states, dim=-1)
340
341 key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
342 key_states = torch.cat(key_states, dim=-1)
343
344 value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
345 value_states = torch.cat(value_states, dim=-1)
346
347 else:
348 query_states = self.q_proj(hidden_states)
349 key_states = self.k_proj(hidden_states)
350 value_states = self.v_proj(hidden_states)
351
352 query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
353 key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
354 value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
355
356 if position_embeddings is None:
357 logger.warning_once(
358 "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
359 "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
360 "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
361 "removed and `position_embeddings` will be mandatory."
362 )
363 cos, sin = self.rotary_emb(value_states, position_ids)
364 else:
365 cos, sin = position_embeddings
366 query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
367
368 if past_key_value is not None:
369 # sin and cos are specific to RoPE models; cache_position needed for the static cache
370 cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
371 key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
372
373 key_states = repeat_kv(key_states, self.num_key_value_groups)

Callers

nothing calls this directly

Calls 2

repeat_kv · Function · 0.85
apply_rotary_pos_emb · Function · 0.70

Tested by

no test coverage detected