(is_variable_B, is_variable_C, seqlen, itype, wtype)
| 158 | @pytest.mark.parametrize("is_variable_B", [False, True]) |
| 159 | # @pytest.mark.parametrize("is_variable_B", [True]) |
| 160 | def test_mamba_inner_fn(is_variable_B, is_variable_C, seqlen, itype, wtype): |
| 161 | device = 'cuda' |
| 162 | rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3) |
| 163 | if itype == torch.bfloat16: |
| 164 | rtol, atol = 3e-2, 5e-2 |
| 165 | rtolw, atolw = (1e-3, 1e-3) |
| 166 | # If we have z, the errors on the weights seem higher |
| 167 | rtolw = max(rtolw, rtol) |
| 168 | atolw = max(atolw, atol) |
| 169 | # set seed |
| 170 | torch.random.manual_seed(0) |
| 171 | batch_size = 2 |
| 172 | dim = 768 |
| 173 | dstate = 8 |
| 174 | dt_rank = 48 |
| 175 | is_complex = wtype == torch.complex64 |
| 176 | xz = torch.randn(batch_size, 2 * dim, seqlen, device=device, dtype=itype, requires_grad=True) |
| 177 | conv1d_weight = torch.randn(dim, 1, 3, device=device, dtype=torch.float32, requires_grad=True) |
| 178 | conv1d_bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) |
| 179 | x_proj_weight = torch.randn(dt_rank + (bool(is_variable_B) + bool(is_variable_C)) * dstate |
| 180 | * (1 if not is_complex else 2), |
| 181 | dim, device=device, dtype=itype, requires_grad=True) |
| 182 | delta_proj_weight = torch.randn(dim, dt_rank, device=device, dtype=itype, requires_grad=True) |
| 183 | out_proj_weight = torch.randn(dim // 2, dim, device=device, dtype=itype, requires_grad=True) |
| 184 | out_proj_bias = None |
| 185 | A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_() |
| 186 | B = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True) |
| 187 | if not is_variable_B else None) |
| 188 | C = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True) |
| 189 | if not is_variable_C else None) |
| 190 | D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) |
| 191 | delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_() |
| 192 | B_proj_bias = None |
| 193 | C_proj_bias = None |
| 194 | xz_ref = xz.detach().clone().requires_grad_() |
| 195 | conv1d_weight_ref = conv1d_weight.detach().clone().requires_grad_() |
| 196 | conv1d_bias_ref = conv1d_bias.detach().clone().requires_grad_() |
| 197 | x_proj_weight_ref = x_proj_weight.detach().clone().requires_grad_() |
| 198 | delta_proj_weight_ref = delta_proj_weight.detach().clone().requires_grad_() |
| 199 | out_proj_weight_ref = out_proj_weight.detach().clone().requires_grad_() |
| 200 | out_proj_bias_ref = (out_proj_bias.detach().clone().requires_grad_() |
| 201 | if out_proj_bias is not None else None) |
| 202 | A_ref = A.detach().clone().requires_grad_() |
| 203 | B_ref = B.detach().clone().requires_grad_() if B is not None else None |
| 204 | C_ref = C.detach().clone().requires_grad_() if C is not None else None |
| 205 | D_ref = D.detach().clone().requires_grad_() |
| 206 | delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None |
| 207 | out = mamba_inner_fn(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, |
| 208 | out_proj_weight, out_proj_bias, |
| 209 | A, B, C, D, delta_bias=delta_bias, delta_softplus=True) |
| 210 | out_ref = mamba_inner_ref(xz_ref, conv1d_weight_ref, conv1d_bias_ref, x_proj_weight_ref, |
| 211 | delta_proj_weight_ref, out_proj_weight_ref, out_proj_bias_ref, |
| 212 | A_ref, B_ref, C_ref, D_ref, |
| 213 | delta_bias=delta_bias_ref, delta_softplus=True) |
| 214 | # dA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) |
| 215 | # dt_u = delta * u |
| 216 | |
| 217 | print(f'Output max diff: {(out - out_ref).abs().max().item()}') |
nothing calls this directly
no test coverage detected