(self,
input_ids: Tensor,
position_ids=None,
use_cache=False,
last_token_ids=None,
attention_mask=None,
kv_cache_params=None,
attention_params=None,
mrope_params=None,
hidden_states=None,
prompt_embedding_table: Optional[Tensor] = None,
prompt_tasks: Optional[Tensor] = None,
prompt_vocab_size: Optional[Tensor] = None,
lora_params=None,
spec_decoding_params=None)
| 1035 | self.position_embedding_type = config.position_embedding_type |
| 1036 | |
| 1037 | def forward(self, |
| 1038 | input_ids: Tensor, |
| 1039 | position_ids=None, |
| 1040 | use_cache=False, |
| 1041 | last_token_ids=None, |
| 1042 | attention_mask=None, |
| 1043 | kv_cache_params=None, |
| 1044 | attention_params=None, |
| 1045 | mrope_params=None, |
| 1046 | hidden_states=None, |
| 1047 | prompt_embedding_table: Optional[Tensor] = None, |
| 1048 | prompt_tasks: Optional[Tensor] = None, |
| 1049 | prompt_vocab_size: Optional[Tensor] = None, |
| 1050 | lora_params=None, |
| 1051 | spec_decoding_params=None): |
| 1052 | |
| 1053 | # fill attention params. |
| 1054 | attention_params = Attention.fill_attention_params( |
| 1055 | self, attention_params) |
| 1056 | |
| 1057 | # split the sequence for context parallelism |
| 1058 | if self.config.mapping.cp_size > 1: |
| 1059 | if len(input_ids.shape) == 1: |
| 1060 | # input shape is [-1] |
| 1061 | input_ids, cp_join_index = cp_split_plugin( |
| 1062 | input_ids, |
| 1063 | attention_params.host_request_types, |
| 1064 | attention_params.host_context_lengths, |
| 1065 | self.config.mapping.cp_size, |
| 1066 | self.config.mapping.cp_rank, |
| 1067 | ) |
| 1068 | else: |
| 1069 | assert False, "Context parallelism with non-remove-padding is not supported yet." |
| 1070 | |
| 1071 | is_gemma_2_cg = self.config.has_config_group(Gemma2ConfigGroup) |
| 1072 | is_gemma_3_cg = self.config.has_config_group(Gemma3ConfigGroup) |
| 1073 | |
| 1074 | kwargs = { |
| 1075 | 'input_ids': input_ids, |
| 1076 | 'position_ids': position_ids, |
| 1077 | 'use_cache': use_cache, |
| 1078 | 'attention_mask': attention_mask, |
| 1079 | 'kv_cache_params': kv_cache_params, |
| 1080 | 'attention_params': attention_params, |
| 1081 | } |
| 1082 | if lora_params is not None: |
| 1083 | kwargs['lora_params'] = lora_params |
| 1084 | if hidden_states is not None: |
| 1085 | kwargs['hidden_states'] = hidden_states |
| 1086 | if prompt_embedding_table is not None: |
| 1087 | kwargs['prompt_embedding_table'] = prompt_embedding_table |
| 1088 | if prompt_tasks is not None: |
| 1089 | kwargs['prompt_tasks'] = prompt_tasks |
| 1090 | if prompt_vocab_size is not None: |
| 1091 | kwargs['prompt_vocab_size'] = prompt_vocab_size |
| 1092 | |
| 1093 | if spec_decoding_params is not None: |
| 1094 | kwargs['spec_decoding_params'] = spec_decoding_params |
Nothing calls this method directly.
No test coverage detected.