Parameters: hidden_states: [B, L, D] or [T, D]; conv_state: [B, W, D] or [1] of type int64 for paged state; ssm_state: [B, N, D] or [1] of type int64 for paged state; host_request_types: [B]; last_token_ids: [B]; host_context_lengths: [B]; slot_mapping: [B]; conv_indices: [B]
(self,
hidden_states: Tensor,
conv_state: Tensor,
ssm_state: Tensor,
host_request_types: Tensor,
last_token_ids: Tensor,
host_context_lengths: Optional[Tensor] = None,
slot_mapping: Optional[Tensor] = None,
conv_indices: Optional[Tensor] = None)
| 170 | gather_output=False) |
| 171 | |
def forward(self,
            hidden_states: Tensor,
            conv_state: Tensor,
            ssm_state: Tensor,
            host_request_types: Tensor,
            last_token_ids: Tensor,
            host_context_lengths: Optional[Tensor] = None,
            slot_mapping: Optional[Tensor] = None,
            conv_indices: Optional[Tensor] = None):
    '''
    Chain the layer's sub-modules: input projections, conv1d, dt
    projection, selective scan and output projection.

    Parameters:
        hidden_states: [B, L, D] or [T, D]
        conv_state: [B, W, D] or [1] of type int64 for paged state
        ssm_state: [B, N, D] or [1] of type int64 for paged state
        host_request_types: [B]
        last_token_ids: [B]
        host_context_lengths: [B]
        slot_mapping: [B]
        conv_indices: [B]

    Returns:
        A tuple (out, conv_state, ssm_state): the result of out_proj,
        the conv state returned by conv1d and the ssm state returned by
        selective_scan.
    '''
    # Two separate input projections of the same hidden states: the first
    # feeds the conv/scan path, the second is handed to selective_scan as z.
    scan_branch = self.in_proj_x(hidden_states)
    z_branch = self.in_proj_z(hidden_states)

    conv_out, conv_state = self.conv1d(
        scan_branch, conv_state, host_request_types, last_token_ids,
        host_context_lengths, slot_mapping, conv_indices)

    # x_proj output is used twice: as the source for dt and, unsplit, as an
    # argument to selective_scan.
    projected = self.x_proj(conv_out)

    # With the gemm plugin enabled dt_proj receives the full projection;
    # otherwise only the leading dt_rank slice is passed on.
    # NOTE(review): presumably the plugin path slices internally - confirm.
    dt_source = projected
    if not default_net().plugin_config.gemm_plugin:
        dt_source, _ = split(projected, [self.dt_rank, self.d_state * 2],
                             dim=-1)
    dt = self.dt_proj(dt_source)

    # selective scan
    scanned, ssm_state = selective_scan(
        conv_out, ssm_state, dt, self.dt_bias.value, self.A.value,
        projected, self.D.value, host_request_types, last_token_ids,
        self.d_inner, self.d_state, self.dt_rank,
        delta_softplus=True,
        dtype=self.dtype,
        z=z_branch,
        host_context_lengths=host_context_lengths,
        slot_mapping=slot_mapping)

    return self.out_proj(scanned), conv_state, ssm_state
| 229 |
nothing calls this directly
no test coverage detected