Repository: github.com/NVIDIA/TensorRT-LLM

Method forward

tensorrt_llm/layers/ssm.py:172–228

Source:

def forward(self,
            hidden_states: Tensor,
            conv_state: Tensor,
            ssm_state: Tensor,
            host_request_types: Tensor,
            last_token_ids: Tensor,
            host_context_lengths: Optional[Tensor] = None,
            slot_mapping: Optional[Tensor] = None,
            conv_indices: Optional[Tensor] = None):
    '''
    Parameters:
        hidden_states: [B, L, D] or [T, D]
        conv_state: [B, W, D] or [1] of type int64 for paged state
        ssm_state: [B, N, D] or [1] of type int64 for paged state
        host_request_types: [B]
        last_token_ids: [B]
        host_context_lengths: [B]
        slot_mapping: [B]
        conv_indices: [B]
    '''
    # in_proj: two separate projections, x (the scan input) and z (the gate)
    x = self.in_proj_x(hidden_states)
    z = self.in_proj_z(hidden_states)

    # causal conv1d over x; also returns the updated conv_state
    x_conv, conv_state = self.conv1d(x, conv_state, host_request_types,
                                     last_token_ids, host_context_lengths,
                                     slot_mapping, conv_indices)

    # Get dt, B and C
    # x_dbl packs dt (dt_rank columns) followed by B and C (d_state columns each)
    x_dbl = self.x_proj(x_conv)
    if default_net().plugin_config.gemm_plugin:
        # GEMM plugin path: dt_proj is applied to x_dbl directly; only its
        # leading dt_rank columns are the dt input (presumably handled by a
        # padded leading dimension in the plugin GEMM)
        dt = self.dt_proj(x_dbl)
    else:
        # explicit split: keep dt, discard the B/C columns here
        dt, _ = split(x_dbl, [self.dt_rank, self.d_state * 2], dim=-1)
        dt = self.dt_proj(dt)

    # selective scan; x_dbl is passed through, presumably so the op can read
    # B and C from it; also returns the updated ssm_state
    y, ssm_state = selective_scan(x_conv,
                                  ssm_state,
                                  dt,
                                  self.dt_bias.value,
                                  self.A.value,
                                  x_dbl,
                                  self.D.value,
                                  host_request_types,
                                  last_token_ids,
                                  self.d_inner,
                                  self.d_state,
                                  self.dt_rank,
                                  delta_softplus=True,
                                  dtype=self.dtype,
                                  z=z,
                                  host_context_lengths=host_context_lengths,
                                  slot_mapping=slot_mapping)
    # out_proj
    out = self.out_proj(y)
    return out, conv_state, ssm_state
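To make the data flow of forward concrete, here is a minimal NumPy sketch of the same sequence of steps for a single, unpadded sequence. It is not TensorRT-LLM code: the helpers mamba_forward_ref and make_params and all weight keys are hypothetical, and tensor parallelism, batching/padding (last_token_ids, host_context_lengths) and paged state (slot_mapping, conv_indices) are ignored. It assumes the reference Mamba formulation: SiLU after a causal depthwise conv, softplus on dt + dt_bias, and the scan output gated by SiLU(z). The state shapes drop the batch dimension of the docstring's [B, W, D] and [B, N, D], reading D as the inner width d_inner and W as the conv kernel width minus one; both readings are assumptions.

```python
import numpy as np


def softplus(x):
    # numerically stable log(1 + exp(x))
    return np.logaddexp(0.0, x)


def silu(x):
    return x / (1.0 + np.exp(-x))


def mamba_forward_ref(h, conv_state, ssm_state, p):
    """Hypothetical reference Mamba block for ONE sequence chunk.

    h:          [L, d_model]  hidden states of this chunk
    conv_state: [W, d_inner]  last W pre-conv inputs (W = kernel width - 1, assumed >= 1)
    ssm_state:  [N, d_inner]  recurrent state (N = d_state)
    p:          dict of weights with hypothetical names
    """
    N = ssm_state.shape[0]
    L = h.shape[0]
    # in_proj: separate x and z projections
    x = h @ p["W_in_x"]                          # [L, d_inner]
    z = h @ p["W_in_z"]                          # [L, d_inner]

    # causal depthwise conv: prepend cached inputs, slide a K-wide window
    K = p["conv_w"].shape[0]
    xs = np.concatenate([conv_state, x], axis=0)  # [W + L, d_inner]
    x_conv = np.stack([(xs[t:t + K] * p["conv_w"]).sum(axis=0) for t in range(L)])
    x_conv = silu(x_conv + p["conv_b"])
    new_conv_state = xs[xs.shape[0] - (K - 1):]   # last W pre-conv inputs

    # x_proj -> [dt | B | C], then dt_proj and softplus(dt + dt_bias)
    x_dbl = x_conv @ p["W_x"]                     # [L, dt_rank + 2N]
    dt_rank = p["W_dt"].shape[0]
    dt, B, C = np.split(x_dbl, [dt_rank, dt_rank + N], axis=-1)
    delta = softplus(dt @ p["W_dt"] + p["dt_bias"])  # [L, d_inner]

    # sequential selective scan, one token at a time
    s = ssm_state.copy()
    ys = []
    for t in range(L):
        dA = np.exp(delta[t] * p["A"])                              # [N, d_inner]
        dBx = B[t][:, None] * (delta[t] * x_conv[t])[None, :]        # [N, d_inner]
        s = dA * s + dBx
        ys.append(C[t] @ s + p["D"] * x_conv[t])                     # [d_inner]
    y = np.stack(ys) * silu(z)                     # gate with z
    return y @ p["W_out"], new_conv_state, s       # out_proj


def make_params(rng, d_model, d_inner, d_state, dt_rank, d_conv):
    """Random weights under hypothetical names, for illustration only."""
    def r(*shape):
        return 0.1 * rng.standard_normal(shape)
    return {
        "W_in_x": r(d_model, d_inner), "W_in_z": r(d_model, d_inner),
        "conv_w": r(d_conv, d_inner), "conv_b": r(d_inner),
        "W_x": r(d_inner, dt_rank + 2 * d_state), "W_dt": r(dt_rank, d_inner),
        "dt_bias": r(d_inner),
        "A": -np.exp(r(d_state, d_inner)),   # negative A -> decay factors < 1
        "D": np.ones(d_inner), "W_out": r(d_inner, d_model),
    }


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    p = make_params(rng, d_model=8, d_inner=16, d_state=4, dt_rank=2, d_conv=4)
    h = rng.standard_normal((6, 8))
    conv0, ssm0 = np.zeros((3, 16)), np.zeros((4, 16))
    full, _, _ = mamba_forward_ref(h, conv0, ssm0, p)   # whole prompt at once
    out, cs, ss = mamba_forward_ref(h[:4], conv0, ssm0, p)  # prefix...
    parts = [out]
    for t in range(4, 6):                                # ...then one token at a time
        o, cs, ss = mamba_forward_ref(h[t:t + 1], cs, ss, p)
        parts.append(o)
    print(np.allclose(full, np.concatenate(parts)))      # True
```

Loosely, a context request corresponds to calling the sketch on the whole prompt and a generation request to calling it on one token with the carried states. The demo checks that both routes produce the same outputs, which is what returning conv_state and ssm_state from each call is meant to guarantee.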

Callers

nothing calls this directly

Calls (3)

default_net · Function · 0.85
selective_scan · Function · 0.85
split · Function · 0.50
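selective_scan is called from forward with delta_softplus=True, the gate z, the parameters A, D and dt_bias, and x_dbl. Assuming it implements the selective-scan recurrence from the Mamba paper (this page does not confirm the kernel's exact formulation), for each inner channel d and token t it computes the following, where x_t is the conv output x_conv, dt_t comes from dt_proj, B_t, C_t ∈ R^N are the last 2N columns of x_dbl (shared by all channels), A^{(d)} ∈ R^N is channel d's column of A, h_t^{(d)} ∈ R^N is that channel's slice of ssm_state, and exp is elementwise:

```latex
\begin{aligned}
\Delta_t^{(d)} &= \operatorname{softplus}\!\left(dt_t^{(d)} + b_{dt}^{(d)}\right) \\
h_t^{(d)} &= \exp\!\left(\Delta_t^{(d)} A^{(d)}\right) \odot h_{t-1}^{(d)} + \Delta_t^{(d)}\, x_t^{(d)}\, B_t \\
y_t^{(d)} &= C_t^{\top} h_t^{(d)} + D^{(d)}\, x_t^{(d)} \\
\tilde{y}_t^{(d)} &= y_t^{(d)} \cdot \operatorname{SiLU}\!\left(z_t^{(d)}\right)
\end{aligned}
```

The gated output \tilde{y}_t is what forward then passes to out_proj, and the final h_t is returned as the updated ssm_state.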

Tested by

no test coverage detected