Parameters: x: [B, L, D], or [T, D] when input padding is removed; conv_state: [B, W, D], or [1] of type int64 for paged state; host_request_types: [B]; last_token_ids: [B]; host_context_lengths: [B]; slot_mapping: [B]; conv_indices: [B].
(self,
x: Tensor,
conv_state: Tensor,
host_request_types: Tensor,
last_token_ids: Tensor,
host_context_lengths: Optional[Tensor] = None,
slot_mapping: Optional[Tensor] = None,
conv_indices: Optional[Tensor] = None)
| 46 | self.apply_silu = apply_silu |
| 47 | |
| 48 | def forward(self, |
| 49 | x: Tensor, |
| 50 | conv_state: Tensor, |
| 51 | host_request_types: Tensor, |
| 52 | last_token_ids: Tensor, |
| 53 | host_context_lengths: Optional[Tensor] = None, |
| 54 | slot_mapping: Optional[Tensor] = None, |
| 55 | conv_indices: Optional[Tensor] = None): |
| 56 | ''' |
| 57 | Parameters: |
| 58 | x: [B, L, D] or [T, D] |
| 59 | conv_state: [B, W, D] or [1] of type int64 for paged state |
| 60 | host_request_types: [B] |
| 61 | last_token_ids: [B] |
| 62 | host_context_lengths: [B] |
| 63 | slot_mapping: [B] |
| 64 | conv_indices: [B] |
| 65 | ''' |
| 66 | if default_net().plugin_config.mamba_conv1d_plugin: |
| 67 | transposed_weight = permute( |
| 68 | view(self.weight.value, shape=[self.d_inner, 1, self.d_conv]), |
| 69 | (1, 2, 0)) |
| 70 | x_conv, conv_state = mamba_conv1d( |
| 71 | x, conv_state, transposed_weight, self.bias.value, |
| 72 | host_request_types, last_token_ids, self.d_inner, self.d_conv, |
| 73 | self.dtype, self.pre_stride, self.post_stride, |
| 74 | host_context_lengths, slot_mapping, self.apply_silu) |
| 75 | else: |
| 76 | assert not default_net().plugin_config.paged_state |
| 77 | assert len( |
| 78 | x.shape |
| 79 | ) == 3, "remove_input_padding is not supported by OOTB for Mamba." |
| 80 | if self.pre_stride > 0: |
| 81 | _, x = split(x, |
| 82 | [self.pre_stride, self.d_inner + self.post_stride], |
| 83 | dim=-1) |
| 84 | if self.post_stride > 0: |
| 85 | x, _ = split(x, [self.d_inner, self.post_stride], dim=-1) |
| 86 | x = x.permute([0, 2, 1]) |
| 87 | |
| 88 | # In context phase, conv_state is a zero tensor, and it is used for padding |
| 89 | # In generation phase, conv_state is a tensor of the past x |
| 90 | x_pad = concat([conv_state, x], dim=2) |
| 91 | |
| 92 | # Update conv_state |
| 93 | conv_state = gather(x_pad, 2, conv_indices) |
| 94 | |
| 95 | # Convolution |
| 96 | x_pad = x_pad.view( |
| 97 | concat([shape(x_pad, 0), |
| 98 | shape(x_pad, 1), |
| 99 | shape(x_pad, 2), 1])) |
| 100 | x_conv = conv2d(x_pad, |
| 101 | self.weight.value, |
| 102 | self.bias.value, |
| 103 | groups=self.d_inner) |
| 104 | if self.apply_silu: |
| 105 | x_conv = ACT2FN['silu'](x_conv) |
Nothing calls this method directly.
No test coverage detected.