selective_scan(input: Tensor,
state_or_ptr: Tensor,
delta: Tensor,
delta_bias: Tensor,
A: Tensor,
BC: Tensor,
D: Tensor,
host_request_types: Tensor,
last_token_ids: Tensor,
dim: int,
dstate: int,
dt_rank: int,
delta_softplus: bool,
dtype: str,
z: Optional[Tensor] = None,
host_context_lengths: Optional[Tensor] = None,
slot_mapping: Optional[Tensor] = None,
nheads: int = 1,
ngroups: int = 1,
chunk_size: int = 256,
mamba_version: str = 'Mamba1')
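To make the roles of `delta`, `delta_bias`, `delta_softplus`, `A`, `BC`, `D`, and `z` concrete, here is a plain-Python sketch of the standard Mamba1 selective-scan recurrence for a single unpadded sequence. It is not the TensorRT-LLM kernel. The function name `selective_scan_ref_mamba1` is hypothetical. Several details are assumptions borrowed from the usual Mamba reference formulation: the layout of `BC` (first `dstate` columns are B, last `dstate` columns are C), applying softplus to `delta + delta_bias`, and gating the output with SiLU of `z`. Shapes follow the docstring with the batch dimension dropped: `x` and `delta` are [seq_len, dim], `A` is [dstate, dim], `BC` is [seq_len, 2 * dstate], `delta_bias` and `D` are [dim], and the state is [dstate, dim].

```python
import math
from typing import List, Optional, Tuple

Matrix = List[List[float]]


def softplus(v: float) -> float:
    # Numerically stable log(1 + exp(v)).
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))


def silu(v: float) -> float:
    # SiLU (swish) gate: v * sigmoid(v).
    return v / (1.0 + math.exp(-v))


def selective_scan_ref_mamba1(
    x: Matrix,                       # [seq_len, dim]
    delta: Matrix,                   # [seq_len, dim]
    delta_bias: List[float],         # [dim]
    A: Matrix,                       # [dstate, dim]
    BC: Matrix,                      # [seq_len, 2 * dstate]; assumed layout: B then C
    D: List[float],                  # [dim]
    dstate: int,
    delta_softplus: bool = True,
    z: Optional[Matrix] = None,      # [seq_len, dim], optional gate
    state: Optional[Matrix] = None,  # [dstate, dim], initial SSM state
) -> Tuple[Matrix, Matrix]:
    """Hypothetical reference for one unpadded sequence; returns (y, final_state)."""
    dim = len(D)
    # Copy the incoming state so the caller's list is not mutated.
    if state is not None:
        h = [row[:] for row in state]
    else:
        h = [[0.0] * dim for _ in range(dstate)]
    y: Matrix = []
    for t, x_t in enumerate(x):
        B_t = BC[t][:dstate]
        C_t = BC[t][dstate:2 * dstate]
        y_t = []
        for d in range(dim):
            # Per-channel step size, optionally passed through softplus.
            dt = delta[t][d] + delta_bias[d]
            if delta_softplus:
                dt = softplus(dt)
            acc = 0.0
            for n in range(dstate):
                # Discretized SSM update: h <- exp(dt * A) * h + dt * B * x.
                h[n][d] = math.exp(dt * A[n][d]) * h[n][d] + dt * B_t[n] * x_t[d]
                # Output accumulates C * h over the state dimension.
                acc += C_t[n] * h[n][d]
            out = acc + D[d] * x_t[d]  # D acts as a skip connection.
            if z is not None:
                out *= silu(z[t][d])
            y_t.append(out)
        y.append(y_t)
    return y, h
```

Passing the returned state back in as `state` continues the same sequence: scanning two tokens in one call gives the same outputs as scanning them one token at a time.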
```python
def selective_scan(input: Tensor,
                   state_or_ptr: Tensor,
                   delta: Tensor,
                   delta_bias: Tensor,
                   A: Tensor,
                   BC: Tensor,
                   D: Tensor,
                   host_request_types: Tensor,
                   last_token_ids: Tensor,
                   dim: int,
                   dstate: int,
                   dt_rank: int,
                   delta_softplus: bool,
                   dtype: str,
                   z: Optional[Tensor] = None,
                   host_context_lengths: Optional[Tensor] = None,
                   slot_mapping: Optional[Tensor] = None,
                   nheads: int = 1,
                   ngroups: int = 1,
                   chunk_size: int = 256,
                   mamba_version: str = 'Mamba1'):
    '''
    Parameters:
        input : Tensor (On GPU)
            The input tensor. Its shape is [batch_size, seq_len, dim]

        state_or_ptr : Tensor (On GPU or CPU)
            The ssm state tensor. Its shape is [batch_size, dstate, dim]
            Or the CPU tensor of shape [1] for the pointer of paged states.

        delta : Tensor (On GPU)
            The delta tensor.
            mamba: Its shape is [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding
            mamba2: Its shape is [batch_size, seq_len, nheads] or [num_tokens, nheads] for remove_input_padding

        delta_bias : Tensor (On GPU)
            The delta bias tensor.
            mamba: Its shape is [dim]
            mamba2: Its shape is [nheads]

        A : Tensor (On GPU)
            A matrix.
            mamba: Its shape is [dstate, dim]
            mamba2: Its shape is [nheads]

        BC : Tensor (On GPU)
            B and C matrix.
            mamba: Its shape is [batch_size, seq_len, dstate * 2] or [num_tokens, dstate * 2] for remove_input_padding
            mamba2: Its shape is [batch_size, seq_len, ngroups * dstate * 2] or [num_tokens, ngroups * dstate * 2] for remove_input_padding

        D : Tensor (On GPU)
            D matrix.
            mamba: Its shape is [dim]
            mamba2: Its shape is [nheads]

        host_request_types : Tensor (On CPU)
            The tensor on the host that indicates if a request is in context or
            generation phase. Its shape is [batch_size]. See Inflight Batching
```
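The docstring above gives different shapes for `delta`, `delta_bias`, `A`, `BC`, and `D` depending on the Mamba version and on whether input padding is removed. The sketch below turns those documented shapes into a small lookup plus a mismatch check that could be run on shapes collected before calling `selective_scan`. Both helper names are hypothetical. The `'Mamba2'` spelling for the second variant is an assumption, since only `'Mamba1'` appears in the listing. `token_dims` stands for the leading dimensions, `(batch_size, seq_len)` for padded input or `(num_tokens,)` with remove_input_padding.

```python
from typing import Dict, Optional, Tuple

Shape = Tuple[int, ...]


def expected_shapes(
    mamba_version: str,
    token_dims: Shape,
    dim: int,
    dstate: int,
    nheads: int = 1,
    ngroups: int = 1,
) -> Dict[str, Shape]:
    """Hypothetical helper returning the shapes listed in the selective_scan docstring.

    token_dims is (batch_size, seq_len) for padded input, or (num_tokens,)
    when remove_input_padding is enabled.
    """
    if mamba_version == "Mamba1":
        return {
            "delta": token_dims + (dim,),
            "delta_bias": (dim,),
            "A": (dstate, dim),
            "BC": token_dims + (dstate * 2,),
            "D": (dim,),
        }
    if mamba_version == "Mamba2":  # Assumed spelling of the Mamba2 option.
        return {
            "delta": token_dims + (nheads,),
            "delta_bias": (nheads,),
            "A": (nheads,),
            "BC": token_dims + (ngroups * dstate * 2,),
            "D": (nheads,),
        }
    raise ValueError(f"unknown mamba_version: {mamba_version!r}")


def shape_mismatches(
    actual: Dict[str, Shape], expected: Dict[str, Shape]
) -> Dict[str, Tuple[Optional[Shape], Shape]]:
    # Map each parameter whose shape differs (or is missing) to (actual, expected).
    return {
        name: (actual.get(name), shape)
        for name, shape in expected.items()
        if actual.get(name) != shape
    }
```

Note that `A` is a full [dstate, dim] matrix in Mamba1 but collapses to one value per head in Mamba2, which is the easiest shape to get wrong when switching versions.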