Method forward

tensorrt_llm/models/bert/model.py:456–482 (github.com/NVIDIA/TensorRT-LLM)
forward(self, hidden_states, input_lengths, remove_input_padding)


```python
        self.out_proj = Linear(hidden_size, num_labels)

    def forward(self, hidden_states, input_lengths, remove_input_padding):

        if not remove_input_padding:
            # We "pool" the model by simply taking the hidden state corresponding
            # to the first token.
            first_token_tensor = select(hidden_states, 1, 0)
        else:
            # when remove_input_padding is enabled, the shape of hidden_states is [num_tokens, hidden_size]
            # We can take the first token of each sequence according to input_lengths,
            # and then do pooling similar to padding mode.
            # For example, if input_lengths is [8, 5, 6], then the indices of first tokens
            # should be [0, 8, 13]
            first_token_indices = cumsum(
                concat([
                    0,
                    slice(input_lengths,
                          starts=[0],
                          sizes=(shape(input_lengths) -
                                 constant(np.array([1], dtype=np.int32))))
                ]), 0)
            first_token_tensor = index_select(hidden_states, 0,
                                              first_token_indices)

        x = self.dense(first_token_tensor)
        x = ACT2FN['tanh'](x)
        x = self.out_proj(x)
        return x
```
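The packed-mode branch computes an exclusive prefix sum of `input_lengths`: `slice(input_lengths, starts=[0], sizes=shape(input_lengths) - 1)` keeps every length except the last, `concat([0, ...])` prepends a zero, and `cumsum(..., 0)` turns the lengths into start offsets. The following is a minimal NumPy sketch of that arithmetic only, assuming `input_lengths` is a 1-D integer array; `first_token_indices_np` is a hypothetical helper name, not part of TensorRT-LLM.

```python
import numpy as np


def first_token_indices_np(input_lengths: np.ndarray) -> np.ndarray:
    """Hypothetical helper: row offset of each sequence's first token
    in a packed [num_tokens, hidden_size] tensor.

    Mirrors cumsum(concat([0, input_lengths[:-1]]), 0), i.e. an
    exclusive prefix sum of the sequence lengths.
    """
    lengths = np.asarray(input_lengths, dtype=np.int32)
    # Drop the last length (it would only give the end offset of the final
    # sequence) and prepend 0 so the first sequence starts at row 0.
    shifted = np.concatenate([np.zeros(1, dtype=np.int32), lengths[:-1]])
    return np.cumsum(shifted, axis=0)


print(first_token_indices_np(np.array([8, 5, 6])))  # [ 0  8 13]
```

The same offsets can also be written as `np.cumsum(lengths) - lengths`; the slice-then-concat form in the method expresses the exclusive scan directly with the graph ops it already has.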
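To check that both branches pick the same first-token rows, here is a NumPy reference of the whole method. It is a sketch under stated assumptions, not the TensorRT-LLM implementation: each `Linear` is modeled as an affine map `x @ w + b`, the padded layout is taken to be `[batch, max_seq_len, hidden_size]`, the weights are random placeholders, and `classification_head_np` plus its weight arguments are hypothetical names.

```python
import numpy as np


def classification_head_np(hidden_states, input_lengths, remove_input_padding,
                           dense_w, dense_b, out_w, out_b):
    """Hypothetical NumPy sketch of forward(): first-token pooling,
    dense, tanh, out_proj. Linear layers are assumed to be x @ w + b."""
    if not remove_input_padding:
        # Padded layout [batch, max_seq_len, hidden]: token 0 of every row,
        # analogous to select(hidden_states, 1, 0).
        first_token = hidden_states[:, 0, :]
    else:
        # Packed layout [num_tokens, hidden]: gather rows at the exclusive
        # prefix sum of the lengths, analogous to index_select(..., 0, idx).
        lengths = np.asarray(input_lengths)
        idx = np.cumsum(lengths) - lengths
        first_token = hidden_states[idx]
    x = first_token @ dense_w + dense_b   # self.dense: hidden -> hidden
    x = np.tanh(x)                        # ACT2FN['tanh']
    return x @ out_w + out_b              # self.out_proj: hidden -> num_labels


rng = np.random.default_rng(0)
hidden, num_labels, lengths = 4, 3, np.array([8, 5, 6])
max_len = int(lengths.max())
padded = rng.standard_normal((len(lengths), max_len, hidden))
# Pack: keep only the real tokens of each sequence, concatenated along axis 0.
packed = np.concatenate([padded[i, :n] for i, n in enumerate(lengths)], axis=0)

params = (rng.standard_normal((hidden, hidden)), rng.standard_normal(hidden),
          rng.standard_normal((hidden, num_labels)), rng.standard_normal(num_labels))
logits_padded = classification_head_np(padded, lengths, False, *params)
logits_packed = classification_head_np(packed, lengths, True, *params)
print(logits_padded.shape, np.allclose(logits_padded, logits_packed))  # (3, 3) True
```

The padded branch needs no lengths because every sequence starts at column 0 of its own row; the packed tensor has no batch axis, so the row offsets must be recovered from `input_lengths` before the shared dense, tanh, and projection steps run.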

Callers

nothing calls this directly

Calls (7)

select (function) · 0.85
cumsum (function) · 0.85
concat (function) · 0.85
slice (function) · 0.85
constant (function) · 0.85
index_select (function) · 0.85
shape (function) · 0.50

Tested by

no test coverage detected