(self,
decoder_input_ids: Tensor,
encoder_output: Tensor,
position_ids=None,
token_type_ids=None,
use_cache=False,
attention_mask_params=None,
last_token_ids=None,
kv_cache_params=None,
attention_params=None,
hidden_states=None,
lora_params: LoraParams = None,
cross_kv_cache_gen: Optional[Tensor] = None,
cross_kv_reuse: Optional[Tensor] = None,
language_adapter_routings: Optional[Tensor] = None)
| 1213 | config.set_if_not_exist('relative_attention', False) |
| 1214 | |
| 1215 | def forward(self, |
| 1216 | decoder_input_ids: Tensor, |
| 1217 | encoder_output: Tensor, |
| 1218 | position_ids=None, |
| 1219 | token_type_ids=None, |
| 1220 | use_cache=False, |
| 1221 | attention_mask_params=None, |
| 1222 | last_token_ids=None, |
| 1223 | kv_cache_params=None, |
| 1224 | attention_params=None, |
| 1225 | hidden_states=None, |
| 1226 | lora_params: LoraParams = None, |
| 1227 | cross_kv_cache_gen: Optional[Tensor] = None, |
| 1228 | cross_kv_reuse: Optional[Tensor] = None, |
| 1229 | language_adapter_routings: Optional[Tensor] = None): |
| 1230 | if self.mapping.is_first_pp_rank(): |
| 1231 | assert isinstance(decoder_input_ids, Tensor) |
| 1232 | else: |
| 1233 | assert isinstance(hidden_states, Tensor) |
| 1234 | |
| 1235 | # In PP, layer 0 has ids as inputs, all other layers have hidden_states as inputs |
| 1236 | if self.mapping.is_first_pp_rank(): |
| 1237 | hidden_states = self.transformer.embedding(decoder_input_ids, |
| 1238 | position_ids, |
| 1239 | token_type_ids) |
| 1240 | self.register_network_output('embedding_layer_output', |
| 1241 | hidden_states) |
| 1242 | else: |
| 1243 | hidden_states = recv(hidden_states, self.mapping.prev_pp_rank()) |
| 1244 | |
| 1245 | kv_cache_params.fill_none_tensor_list(len(self.transformer.layers)) |
| 1246 | |
| 1247 | if use_cache: |
| 1248 | presents = [] |
| 1249 | |
| 1250 | for i, (decoder_layer, past) in enumerate( |
| 1251 | zip(self.transformer.layers, kv_cache_params.past_key_value)): |
| 1252 | |
| 1253 | lora_layer_params = None |
| 1254 | if lora_params is not None and lora_params.lora_ranks is not None: |
| 1255 | lora_layer_params = lora_params.get_layer_params(i) |
| 1256 | |
| 1257 | hidden_states = decoder_layer( |
| 1258 | hidden_states, |
| 1259 | encoder_output=encoder_output, |
| 1260 | attention_mask_params=attention_mask_params, |
| 1261 | use_cache=use_cache, |
| 1262 | kv_cache_params=KeyValueCacheParams( |
| 1263 | past_key_value=past, |
| 1264 | host_past_key_value_lengths=kv_cache_params. |
| 1265 | host_past_key_value_lengths, |
| 1266 | host_max_attention_window_sizes=kv_cache_params. |
| 1267 | host_max_attention_window_sizes, |
| 1268 | host_sink_token_length=kv_cache_params. |
| 1269 | host_sink_token_length, |
| 1270 | cache_indirection=kv_cache_params.cache_indirection, |
| 1271 | kv_cache_block_offsets=kv_cache_params. |
| 1272 | kv_cache_block_offsets, |
nothing calls this directly
no test coverage detected