hub / github.com/NVIDIA/TensorRT-LLM / forward

Method forward

tensorrt_llm/layers/ssm.py:48–113

Parameters:
    x: [B, L, D] or [T, D]
    conv_state: [B, W, D] or [1] of type int64 for paged state
    host_request_types: [B]
    last_token_ids: [B]
    host_context_lengths: [B]
    slot_mapping: [B]
    conv_indices: [B]

(self,
                x: Tensor,
                conv_state: Tensor,
                host_request_types: Tensor,
                last_token_ids: Tensor,
                host_context_lengths: Optional[Tensor] = None,
                slot_mapping: Optional[Tensor] = None,
                conv_indices: Optional[Tensor] = None)

Source from the content-addressed store, hash-verified

        self.apply_silu = apply_silu

    def forward(self,
                x: Tensor,
                conv_state: Tensor,
                host_request_types: Tensor,
                last_token_ids: Tensor,
                host_context_lengths: Optional[Tensor] = None,
                slot_mapping: Optional[Tensor] = None,
                conv_indices: Optional[Tensor] = None):
        '''
        Parameters:
            x: [B, L, D] or [T, D]
            conv_state: [B, W, D] or [1] of type int64 for paged state
            host_request_types: [B]
            last_token_ids: [B]
            host_context_lengths: [B]
            slot_mapping: [B]
            conv_indices: [B]
        '''
        if default_net().plugin_config.mamba_conv1d_plugin:
            transposed_weight = permute(
                view(self.weight.value, shape=[self.d_inner, 1, self.d_conv]),
                (1, 2, 0))
            x_conv, conv_state = mamba_conv1d(
                x, conv_state, transposed_weight, self.bias.value,
                host_request_types, last_token_ids, self.d_inner, self.d_conv,
                self.dtype, self.pre_stride, self.post_stride,
                host_context_lengths, slot_mapping, self.apply_silu)
        else:
            assert not default_net().plugin_config.paged_state
            assert len(
                x.shape
            ) == 3, "remove_input_padding is not supported by OOTB for Mamba."
            if self.pre_stride > 0:
                _, x = split(x,
                             [self.pre_stride, self.d_inner + self.post_stride],
                             dim=-1)
            if self.post_stride > 0:
                x, _ = split(x, [self.d_inner, self.post_stride], dim=-1)
            x = x.permute([0, 2, 1])

            # In context phase, conv_state is a zero tensor, and it is used for padding
            # In generation phase, conv_state is a tensor of the past x
            x_pad = concat([conv_state, x], dim=2)

            # Update conv_state
            conv_state = gather(x_pad, 2, conv_indices)

            # Convolution
            x_pad = x_pad.view(
                concat([shape(x_pad, 0),
                        shape(x_pad, 1),
                        shape(x_pad, 2), 1]))
            x_conv = conv2d(x_pad,
                            self.weight.value,
                            self.bias.value,
                            groups=self.d_inner)
            if self.apply_silu:
                x_conv = ACT2FN['silu'](x_conv)
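
The non-plugin branch in the listing can be read as a causal depthwise convolution whose left padding is the carried state. The sketch below restates that idea in NumPy. It is not the TensorRT-LLM implementation: the function name `conv1d_reference`, the `[D, W]` weight layout, and the `[B, D, W - 1]` state shape are assumptions chosen so that the concatenation along dim 2 lines up, and the state update assumes `conv_indices` selects the last `W - 1` positions of the padded input.

```python
import numpy as np


def silu(v):
    # SiLU(v) = v * sigmoid(v)
    return v / (1.0 + np.exp(-v))


def conv1d_reference(x, conv_state, weight, bias, pre_stride=0, apply_silu=True):
    """NumPy stand-in for the non-plugin branch (illustrative only).

    Assumed shapes (hypothetical):
      x:          [B, L, pre_stride + D + post_stride]
      conv_state: [B, D, W - 1]
      weight:     [D, W]  (one filter per channel, i.e. groups == D)
      bias:       [D]
    Returns (x_conv [B, D, L], new conv_state [B, D, W - 1]).
    """
    d_inner, d_conv = weight.shape

    # Keep only the D convolution channels; this drops both stride regions.
    x = x[..., pre_stride:pre_stride + d_inner]
    x = np.transpose(x, (0, 2, 1))                    # [B, D, L]

    # Left-pad with the state: zeros in context phase, past inputs in generation.
    x_pad = np.concatenate([conv_state, x], axis=2)   # [B, D, W - 1 + L]

    # Assumption: the new state is the last W - 1 columns of the padded input.
    start = x_pad.shape[2] - (d_conv - 1)
    new_state = x_pad[:, :, start:]

    # Depthwise cross-correlation: one length-W window per output position.
    x_conv = np.empty(x.shape, dtype=np.result_type(x_pad, weight))
    for t in range(x.shape[2]):
        window = x_pad[:, :, t:t + d_conv]            # [B, D, W]
        x_conv[:, :, t] = (window * weight).sum(axis=2) + bias

    if apply_silu:
        x_conv = silu(x_conv)
    return x_conv, new_state
```

Under these assumptions, running a full sequence from a zero state gives the same result as running a prefix and then feeding the remaining tokens with the returned state, which is the property the context and generation comments in the listing describe.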

Callers

nothing calls this directly

Calls 11

default_net · Function · 0.85
view · Function · 0.85
mamba_conv1d · Function · 0.85
concat · Function · 0.85
conv2d · Function · 0.85
permute · Function · 0.50
split · Function · 0.50
gather · Function · 0.50
shape · Function · 0.50
permute · Method · 0.45
view · Method · 0.45
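
Of the calls listed above, `view` and `permute` are combined on the plugin path to relayout the weight before it is passed to `mamba_conv1d`. The NumPy sketch below only traces the shapes. The sizes are hypothetical, and the starting layout `[d_inner, 1, d_conv, 1]` is an assumption based on the weight also being passed to a grouped `conv2d` in the other branch.

```python
import numpy as np

d_inner, d_conv = 4, 3  # hypothetical sizes, for illustration only

# Assumed original layout: a grouped conv2d filter [d_inner, 1, d_conv, 1].
weight = np.arange(d_inner * d_conv, dtype=np.float32).reshape(d_inner, 1, d_conv, 1)

# view(..., shape=[d_inner, 1, d_conv]) drops the trailing unit axis.
viewed = weight.reshape(d_inner, 1, d_conv)

# permute(..., (1, 2, 0)) moves the channel axis last: [1, d_conv, d_inner].
transposed = np.transpose(viewed, (1, 2, 0))

print(transposed.shape)  # (1, 3, 4)
```

After the permute, the tap for channel `c` at kernel position `k` sits at `transposed[0, k, c]`, so the channel axis is the innermost one.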

Tested by

no test coverage detected