| 134 | dtype=config.dtype) |
| 135 | |
def forward(self,
            input_ids,
            conv_states,
            ssm_states,
            host_request_types,
            last_token_ids,
            host_context_lengths,
            slot_mapping: Optional[Tensor] = None):
    """Embed the tokens, run every Mamba layer, then apply the final norm.

    Args:
        input_ids: token ids passed to ``self.vocab_embedding``; its dim 0
            is used as the batch size when building conv-state indices.
        conv_states: per-layer past convolution states, consumed in
            lockstep with ``self.layers`` and ``ssm_states``. Iteration
            uses ``zip``, so it stops at the shortest of the three.
        ssm_states: per-layer past SSM states, paired with ``conv_states``.
        host_request_types: forwarded unchanged to every layer.
        last_token_ids: forwarded to every layer. When the mamba_conv1d
            plugin is disabled, it also offsets the conv-state gather
            indices (presumably one entry per sequence; TODO confirm).
        host_context_lengths: forwarded unchanged to every layer.
        slot_mapping: optional tensor forwarded unchanged to every layer.

    Returns:
        ``(hidden_states, present_convs, present_ssms)``:
        ``hidden_states`` is the output of ``self.ln_f``.
        ``present_convs`` and ``present_ssms`` are tuples of elements 2 and
        3 of each layer's output.
    """
    hidden_states = self.vocab_embedding(input_ids)

    # Gather indices are only built when the conv1d plugin is disabled;
    # with the plugin enabled every layer receives ``None`` instead.
    conv_indices = None
    if not default_net().plugin_config.mamba_conv1d_plugin:
        batch_size = shape(input_ids, 0)
        window = self.d_conv - 1
        # Positions [0, window) as a row, broadcast to (batch, window).
        base = expand(unsqueeze(arange(0, window, dtype='int32'), 0),
                      concat([batch_size, window]))
        # last_token_ids broadcast along the window axis to the same shape.
        shift = expand(unsqueeze(last_token_ids, 1),
                       concat([batch_size, window]))
        # Insert a middle axis and broadcast to (batch, d_inner, window).
        conv_indices = expand(unsqueeze(base + shift, 1),
                              concat([batch_size, self.d_inner, window]))

    if self.residual_in_fp32:
        residual = cast(hidden_states, 'float32')
    else:
        residual = hidden_states

    conv_outputs, ssm_outputs = [], []
    layer_inputs = zip(self.layers, conv_states, ssm_states)
    for layer, past_conv, past_ssm in layer_inputs:
        layer_out = layer(hidden_states, residual, past_conv, past_ssm,
                          host_request_types, last_token_ids,
                          host_context_lengths, slot_mapping, conv_indices)
        # Elements 0/1 feed the next layer; 2/3 are collected as outputs.
        hidden_states, residual = layer_out[0], layer_out[1]
        conv_outputs.append(layer_out[2])
        ssm_outputs.append(layer_out[3])

    return self.ln_f(hidden_states), tuple(conv_outputs), tuple(ssm_outputs)
| 173 | |
| 174 | |
| 175 | class MambaForCausalLM(PretrainedModel): |