MCPcopy Index your code
hub / github.com/NVIDIA/TensorRT-LLM / forward

Method forward

tensorrt_llm/models/mamba/model.py:136–172  ·  view source on GitHub ↗
(self,
                input_ids,
                conv_states,
                ssm_states,
                host_request_types,
                last_token_ids,
                host_context_lengths,
                slot_mapping: Optional[Tensor] = None)

Source from the content-addressed store, hash-verified

134 dtype=config.dtype)
135
def forward(self,
            input_ids,
            conv_states,
            ssm_states,
            host_request_types,
            last_token_ids,
            host_context_lengths,
            slot_mapping: Optional[Tensor] = None):
    """Embed tokens, run them through every layer, then apply ``self.ln_f``.

    Args:
        input_ids: token ids passed to ``self.vocab_embedding``.
        conv_states: per-layer past conv states, consumed in lock-step with
            ``self.layers`` (``zip`` stops at the shortest of the three).
        ssm_states: per-layer past SSM states, consumed like ``conv_states``.
        host_request_types: forwarded unchanged to every layer.
        last_token_ids: forwarded to every layer; when the conv1d plugin is
            disabled it also offsets the conv-state gather indices.
        host_context_lengths: forwarded unchanged to every layer.
        slot_mapping: optional tensor forwarded unchanged to every layer.

    Returns:
        ``(hidden_states, present_convs, present_ssms)``. The hidden states
        come out of ``self.ln_f``. The two tuples collect the third and
        fourth outputs of each layer, in layer order.
    """
    hidden_states = self.vocab_embedding(input_ids)

    # Gather indices for the conv state. They are only built when the
    # mamba_conv1d plugin is disabled; otherwise every layer receives None.
    # NOTE(review): presumably the layer uses these to gather the last
    # (d_conv - 1) conv inputs per sequence -- confirm in the layer code.
    conv_indices = None
    if not default_net().plugin_config.mamba_conv1d_plugin:
        window = self.d_conv - 1
        # Leading dim of input_ids; assumed to be the batch dim -- TODO confirm
        # for packed (remove-padding) inputs.
        batch_size = shape(input_ids, 0)
        # positions[b, k] == k
        positions = expand(unsqueeze(arange(0, window, dtype='int32'), 0),
                           concat([batch_size, window]))
        # starts[b, k] == last_token_ids[b]
        starts = expand(unsqueeze(last_token_ids, 1),
                        concat([batch_size, window]))
        # Insert a middle axis and broadcast it to d_inner, giving
        # conv_indices[b, c, k] == last_token_ids[b] + k.
        conv_indices = expand(unsqueeze(positions + starts, 1),
                              concat([batch_size, self.d_inner, window]))

    # The residual stream optionally runs in fp32 while hidden states keep
    # the embedding dtype.
    if self.residual_in_fp32:
        residual = cast(hidden_states, 'float32')
    else:
        residual = hidden_states

    convs_out = []
    ssms_out = []
    for layer, past_conv, past_ssm in zip(self.layers, conv_states,
                                          ssm_states):
        outputs = layer(hidden_states, residual, past_conv, past_ssm,
                        host_request_types, last_token_ids,
                        host_context_lengths, slot_mapping, conv_indices)
        # Outputs 0 and 1 feed the next layer; 2 and 3 are this layer's
        # updated conv / SSM states.
        hidden_states, residual = outputs[0], outputs[1]
        convs_out.append(outputs[2])
        ssms_out.append(outputs[3])

    return self.ln_f(hidden_states), tuple(convs_out), tuple(ssms_out)
173
174
175class MambaForCausalLM(PretrainedModel):

Callers

nothing calls this directly

Calls 8

default_net · Function · 0.85
expand · Function · 0.85
unsqueeze · Function · 0.85
arange · Function · 0.85
concat · Function · 0.85
cast · Function · 0.85
shape · Function · 0.50
append · Method · 0.45

Tested by

no test coverage detected