Repository: github.com/NVIDIA/TensorRT-LLM

Function selective_scan

Source: tensorrt_llm/functional.py:6960–7129


selective_scan(input: Tensor,
               state_or_ptr: Tensor,
               delta: Tensor,
               delta_bias: Tensor,
               A: Tensor,
               BC: Tensor,
               D: Tensor,
               host_request_types: Tensor,
               last_token_ids: Tensor,
               dim: int,
               dstate: int,
               dt_rank: int,
               delta_softplus: bool,
               dtype: str,
               z: Optional[Tensor] = None,
               host_context_lengths: Optional[Tensor] = None,
               slot_mapping: Optional[Tensor] = None,
               nheads: int = 1,
               ngroups: int = 1,
               chunk_size: int = 256,
               mamba_version: str = 'Mamba1')
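The docstring's shape rules for delta, delta_bias, A, BC and D change with the Mamba version and with the remove_input_padding layout. The sketch below gathers them in one place. expected_scan_shapes is a hypothetical helper, not part of TensorRT-LLM; it assumes the Mamba2 version string is 'Mamba2' and that every sequence has exactly seq_len tokens, so num_tokens = batch_size * seq_len.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def expected_scan_shapes(batch_size: int, seq_len: int, dim: int, dstate: int,
                         nheads: int = 1, ngroups: int = 1,
                         mamba_version: str = 'Mamba1',
                         remove_input_padding: bool = False) -> Dict[str, Shape]:
    """Hypothetical helper: version-dependent shapes from the selective_scan docstring."""
    # Assumption: 'Mamba2' is the string used to select the Mamba2 path.
    if mamba_version not in ('Mamba1', 'Mamba2'):
        raise ValueError(f"unknown mamba_version: {mamba_version!r}")
    mamba2 = mamba_version == 'Mamba2'
    # remove_input_padding packs batch and sequence into one token axis.
    # Assumption: all sequences have seq_len tokens, so num_tokens = batch_size * seq_len.
    lead: Shape = (batch_size * seq_len,) if remove_input_padding else (batch_size, seq_len)
    # Mamba1 keeps one delta/bias/D value per channel; Mamba2 keeps one per head.
    per_step = nheads if mamba2 else dim
    bc_width = ngroups * dstate * 2 if mamba2 else dstate * 2
    return {
        'delta': lead + (per_step,),
        'delta_bias': (per_step,),
        'A': (nheads,) if mamba2 else (dstate, dim),
        'BC': lead + (bc_width,),
        'D': (per_step,),
    }
```

For example, with mamba_version='Mamba2', nheads=4, ngroups=2, dstate=16 and packed input, BC carries ngroups * dstate * 2 = 64 channels per token, while A and D shrink to one value per head.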

def selective_scan(input: Tensor,
                   state_or_ptr: Tensor,
                   delta: Tensor,
                   delta_bias: Tensor,
                   A: Tensor,
                   BC: Tensor,
                   D: Tensor,
                   host_request_types: Tensor,
                   last_token_ids: Tensor,
                   dim: int,
                   dstate: int,
                   dt_rank: int,
                   delta_softplus: bool,
                   dtype: str,
                   z: Optional[Tensor] = None,
                   host_context_lengths: Optional[Tensor] = None,
                   slot_mapping: Optional[Tensor] = None,
                   nheads: int = 1,
                   ngroups: int = 1,
                   chunk_size: int = 256,
                   mamba_version: str = 'Mamba1'):
    '''
    Parameters:
        input : Tensor (On GPU)
            The input tensor. Its shape is [batch_size, seq_len, dim]

        state_or_ptr : Tensor (On GPU or CPU)
            The ssm state tensor. Its shape is [batch_size, dstate, dim]
            Or the CPU tensor of shape [1] for the pointer of paged states.

        delta : Tensor (On GPU)
            The delta tensor.
            mamba: Its shape is [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding
            mamba2: Its shape is [batch_size, seq_len, nheads] or [num_tokens, nheads] for remove_input_padding

        delta_bias : Tensor (On GPU)
            The delta bias tensor.
            mamba: Its shape is [dim]
            mamba2: Its shape is [nheads]

        A : Tensor (On GPU)
            A matrix.
            mamba: Its shape is [dstate, dim]
            mamba2: Its shape is [nheads]

        BC : Tensor (On GPU)
            B and C matrix.
            mamba: Its shape is [batch_size, seq_len, dstate * 2] or [num_tokens, dstate * 2] for remove_input_padding
            mamba2: Its shape is [batch_size, seq_len, ngroups * dstate * 2] or [num_tokens, ngroups * dstate * 2] for remove_input_padding

        D : Tensor (On GPU)
            D matrix.
            mamba: Its shape is [dim]
            mamba2: Its shape is [nheads]

        host_request_types : Tensor (On CPU)
            The tensor on the host that indicates if a request is in context or
            generation phase. Its shape is [batch_size]. See Inflight Batching
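To give a feel for what the Mamba1 path computes with these shapes, here is a minimal NumPy sketch of the standard Mamba-1 selective-scan recurrence in the padded layout: h_t = exp(delta_t * A) * h_{t-1} + delta_t * B_t * x_t, then y_t = C_t . h_t + D * x_t, optionally gated by SiLU(z). It is not TensorRT-LLM's kernel. The name selective_scan_ref, the split of BC into B (first dstate channels) and C (last dstate channels), and the order "add delta_bias, then softplus" are assumptions modeled on the common reference formulation. It ignores dt_rank, paged state, slot_mapping and Mamba2.

```python
import numpy as np


def softplus(x: np.ndarray) -> np.ndarray:
    # Numerically stable log(1 + exp(x)).
    return np.logaddexp(0.0, x)


def selective_scan_ref(u, state, delta, delta_bias, A, BC, D, dstate,
                       delta_softplus=True, z=None):
    """Hypothetical NumPy reference for the Mamba-1 recurrence (padded layout).

    u, delta: [batch, seq_len, dim]    state: [batch, dstate, dim]
    delta_bias, D: [dim]               A: [dstate, dim]
    BC: [batch, seq_len, 2 * dstate]
    Returns y [batch, seq_len, dim] and the final state [batch, dstate, dim].
    """
    # Assumption: B is the first dstate channels of BC, C the last dstate.
    B, C = BC[..., :dstate], BC[..., dstate:]
    dt = delta + delta_bias                            # bias broadcasts over dim
    if delta_softplus:
        dt = softplus(dt)
    h = np.array(state, dtype=np.float64)              # copy; the input is not mutated
    y = np.empty(u.shape, dtype=np.float64)
    for t in range(u.shape[1]):
        dt_t = dt[:, t, None, :]                       # [batch, 1, dim]
        dA = np.exp(dt_t * A)                          # discretized A: [batch, dstate, dim]
        dBu = dt_t * B[:, t, :, None] * u[:, t, None, :]   # [batch, dstate, dim]
        h = dA * h + dBu                               # recurrent state update
        # Read out through C and add the D skip connection: [batch, dim].
        y[:, t] = np.einsum('bn,bnd->bd', C[:, t], h) + D * u[:, t]
    if z is not None:
        y = y * (z / (1.0 + np.exp(-z)))               # SiLU gate: y * z * sigmoid(z)
    return y, h
```

Because the final state is returned, scanning a sequence in two pieces and passing the first call's state into the second gives the same outputs as one call over the whole sequence; this is what lets a stored SSM state carry context from one call to the next.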

Callers (2)

forward (method)
forward (method)

Calls (8)

str_dtype_to_trt (function)
default_net (function)
default_trtnet (function)
_add_plugin_info (function)
_create_tensor (function)
int8 (method)
create_plugin (method)
get_output (method)

Tested by

no test coverage detected