hub / github.com/NVIDIA/TensorRT-LLM / mamba_conv1d

Function mamba_conv1d

tensorrt_llm/functional.py:6839–6957

Parameters:
    input : Tensor (On GPU)
        The input tensor. Its shape is [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding
    conv_state_or_ptr : Tensor (On GPU or CPU)
        The conv state tensor. Its shape is [batch_size, dconv - 1, dim]

(input: Tensor,
                 conv_state_or_ptr: Tensor,
                 conv_weight: Tensor,
                 conv_bias: Tensor,
                 host_request_types: Tensor,
                 last_token_ids: Tensor,
                 dim: int,
                 dconv: int,
                 dtype: str,
                 pre_stride: int = 0,
                 post_stride: int = 0,
                 host_context_lengths: Optional[Tensor] = None,
                 slot_mapping: Optional[Tensor] = None,
                 apply_silu: bool = True)
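
The signature above builds a TensorRT plugin layer, so it cannot be run outside a network definition. As a rough guide to what the tensors mean, here is a minimal pure-Python sketch of a causal depthwise convolution with a cached state, for a single sequence. It assumes the op follows the usual Mamba conv1d semantics and that the last row of the weight multiplies the newest token; neither is confirmed by this page. `reference_conv1d` and `silu` are hypothetical helper names, not part of TensorRT-LLM.

```python
import math


def silu(x):
    # SiLU (swish): x * sigmoid(x)
    return x / (1.0 + math.exp(-x))


def reference_conv1d(tokens, conv_state, weight, bias, apply_silu=True):
    """Illustrative reference for one sequence (not the plugin's code).

    tokens:     [seq_len][dim]   new inputs
    conv_state: [dconv - 1][dim] rows cached from the previous call
    weight:     [dconv][dim]     one filter per channel (assumed layout)
    bias:       [dim]
    Returns (outputs [seq_len][dim], new_state [dconv - 1][dim]).
    """
    dconv = len(weight)
    dim = len(bias)
    # Prepend the cached rows so every output position sees a full window.
    padded = [list(r) for r in conv_state] + [list(r) for r in tokens]
    outputs = []
    for t in range(len(tokens)):
        row = []
        for c in range(dim):
            acc = bias[c]
            for k in range(dconv):
                acc += weight[k][c] * padded[t + k][c]
            row.append(silu(acc) if apply_silu else acc)
        outputs.append(row)
    # The last dconv - 1 rows become the state for the next call.
    new_state = padded[len(padded) - (dconv - 1):]
    return outputs, new_state
```

Under these assumptions, feeding a sequence in one call or in several chunks (passing `new_state` forward each time) gives the same outputs, which is why a state of only `dconv - 1` rows per request is enough.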

Source from the content-addressed store, hash-verified

def mamba_conv1d(input: Tensor,
                 conv_state_or_ptr: Tensor,
                 conv_weight: Tensor,
                 conv_bias: Tensor,
                 host_request_types: Tensor,
                 last_token_ids: Tensor,
                 dim: int,
                 dconv: int,
                 dtype: str,
                 pre_stride: int = 0,
                 post_stride: int = 0,
                 host_context_lengths: Optional[Tensor] = None,
                 slot_mapping: Optional[Tensor] = None,
                 apply_silu: bool = True):
    '''
    Parameters:
        input : Tensor (On GPU)
            The input tensor. Its shape is [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding

        conv_state_or_ptr : Tensor (On GPU or CPU)
            The conv state tensor. Its shape is [batch_size, dconv - 1, dim]
            Or the CPU tensor of shape [1] for the pointer of paged states.

        conv_weight : Tensor (On GPU)
            The weight tensor. Its shape is [1, dconv, dim]

        conv_bias : Tensor (On GPU)
            The bias tensor. Its shape is [dim]

        host_request_types : Tensor (On CPU)
            The tensor on the host that indicates if a request is in context or
            generation phase. Its shape is [batch_size]. See Inflight Batching
            in docs/source/advanced/gpt-attention.md,

        last_token_ids : Tensor (On GPU)
            The inclusive prefix-sum of the lengths or the lengths of the
            sequences in the batch.

        dim : int
            The hidden dimension of conv1d

        dconv : int
            The window size of conv1d

        dtype: str
            data type

        pre_stride : int = 0
            The (pre) stride size of the input tensor.
            The valid values of the input tensor are input[..., pre_stride: dim-post_stride]

        post_stride : int = 0
            The (post) stride size of the input tensor.
            The valid values of the input tensor are input[..., pre_stride: dim-post_stride]

        host_context_lengths: Tensor (On CPU) (Optional)
            A host tensor that contains the lengths of the different inputs,
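
Two of the documented parameters are easier to picture with numbers: the `pre_stride`/`post_stride` slice and the inclusive prefix-sum form of `last_token_ids`. The sketch below uses plain Python lists in place of tensors. It reads `dim` in the slice expression as the width of the input's last axis, and it treats the prefix-sum form as describing a packed `[num_tokens, dim]` input; both readings are assumptions, and all three helper names are hypothetical.

```python
from itertools import accumulate


def valid_channels(row, pre_stride=0, post_stride=0):
    # Mirrors the docstring expression input[..., pre_stride : dim - post_stride],
    # reading `dim` as the width of the last axis (an assumption).
    width = len(row)
    return row[pre_stride:width - post_stride]


def inclusive_prefix_sum(lengths):
    # [3, 1, 4] -> [3, 4, 8]: entry i is the end offset of sequence i.
    return list(accumulate(lengths))


def sequence_bounds(last_token_ids):
    # Recover (start, end) token ranges of each sequence in a packed batch.
    starts = [0] + list(last_token_ids[:-1])
    return list(zip(starts, last_token_ids))
```

For example, with per-request lengths `[3, 1, 4]`, the prefix sum is `[3, 4, 8]` and the packed token ranges are `(0, 3)`, `(3, 4)` and `(4, 8)`.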

Callers 1

forward · Method · 0.85

Calls 8

str_dtype_to_trt · Function · 0.85
default_net · Function · 0.85
default_trtnet · Function · 0.85
_add_plugin_info · Function · 0.85
_create_tensor · Function · 0.85
int8 · Method · 0.80
create_plugin · Method · 0.80
get_output · Method · 0.45

Tested by

no test coverage detected