MCPcopy
hub / github.com/hustvl/Vim / test_mamba_inner_fn

Function test_mamba_inner_fn

mamba-1p1p1/tests/ops/test_selective_scan.py:160–247  ·  view source on GitHub ↗
(is_variable_B, is_variable_C, seqlen, itype, wtype)

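Source retrieved from the content-addressed store (hash-verified). The listing below is truncated after line 217, so it omits the rest of the test, including the `backward` call listed under "Calls".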

158@pytest.mark.parametrize("is_variable_B", [False, True])
159# @pytest.mark.parametrize("is_variable_B", [True])
160def test_mamba_inner_fn(is_variable_B, is_variable_C, seqlen, itype, wtype):
161 device = 'cuda'
162 rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)
163 if itype == torch.bfloat16:
164 rtol, atol = 3e-2, 5e-2
165 rtolw, atolw = (1e-3, 1e-3)
166 # If we have z, the errors on the weights seem higher
167 rtolw = max(rtolw, rtol)
168 atolw = max(atolw, atol)
169 # set seed
170 torch.random.manual_seed(0)
171 batch_size = 2
172 dim = 768
173 dstate = 8
174 dt_rank = 48
175 is_complex = wtype == torch.complex64
176 xz = torch.randn(batch_size, 2 * dim, seqlen, device=device, dtype=itype, requires_grad=True)
177 conv1d_weight = torch.randn(dim, 1, 3, device=device, dtype=torch.float32, requires_grad=True)
178 conv1d_bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
179 x_proj_weight = torch.randn(dt_rank + (bool(is_variable_B) + bool(is_variable_C)) * dstate
180 * (1 if not is_complex else 2),
181 dim, device=device, dtype=itype, requires_grad=True)
182 delta_proj_weight = torch.randn(dim, dt_rank, device=device, dtype=itype, requires_grad=True)
183 out_proj_weight = torch.randn(dim // 2, dim, device=device, dtype=itype, requires_grad=True)
184 out_proj_bias = None
185 A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_()
186 B = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True)
187 if not is_variable_B else None)
188 C = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True)
189 if not is_variable_C else None)
190 D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
191 delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_()
192 B_proj_bias = None
193 C_proj_bias = None
194 xz_ref = xz.detach().clone().requires_grad_()
195 conv1d_weight_ref = conv1d_weight.detach().clone().requires_grad_()
196 conv1d_bias_ref = conv1d_bias.detach().clone().requires_grad_()
197 x_proj_weight_ref = x_proj_weight.detach().clone().requires_grad_()
198 delta_proj_weight_ref = delta_proj_weight.detach().clone().requires_grad_()
199 out_proj_weight_ref = out_proj_weight.detach().clone().requires_grad_()
200 out_proj_bias_ref = (out_proj_bias.detach().clone().requires_grad_()
201 if out_proj_bias is not None else None)
202 A_ref = A.detach().clone().requires_grad_()
203 B_ref = B.detach().clone().requires_grad_() if B is not None else None
204 C_ref = C.detach().clone().requires_grad_() if C is not None else None
205 D_ref = D.detach().clone().requires_grad_()
206 delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None
207 out = mamba_inner_fn(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
208 out_proj_weight, out_proj_bias,
209 A, B, C, D, delta_bias=delta_bias, delta_softplus=True)
210 out_ref = mamba_inner_ref(xz_ref, conv1d_weight_ref, conv1d_bias_ref, x_proj_weight_ref,
211 delta_proj_weight_ref, out_proj_weight_ref, out_proj_bias_ref,
212 A_ref, B_ref, C_ref, D_ref,
213 delta_bias=delta_bias_ref, delta_softplus=True)
214 # dA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
215 # dt_u = delta * u
216
217 print(f'Output max diff: {(out - out_ref).abs().max().item()}')

Callers

nothing calls this directly

Calls 6

mamba_inner_fn · Function · 0.90
mamba_inner_ref · Function · 0.90
print · Function · 0.85
max · Method · 0.80
clone · Method · 0.45
backward · Method · 0.45

Tested by

no test coverage detected