(dim, seqlen, width, has_bias, silu_activation, itype, channel_last, has_initial_states, return_final_states)
| 36 | @pytest.mark.parametrize('dim', [64, 4096 + 32]) |
| 37 | # @pytest.mark.parametrize('dim', [64]) |
| 38 | def test_causal_conv1d(dim, seqlen, width, has_bias, silu_activation, itype, channel_last, has_initial_states, return_final_states): |
| 39 | if not channel_last and (has_initial_states or return_final_states): |
| 40 | pytest.skip("Only channel_last support initial_states or return_final_states") |
| 41 | device = "cuda" |
| 42 | rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) |
| 43 | if itype == torch.bfloat16: |
| 44 | rtol, atol = 1e-2, 5e-2 |
| 45 | rtolw, atolw = (1e-3, 1e-3) |
| 46 | # set seed |
| 47 | torch.random.manual_seed(0) |
| 48 | batch = 2 |
| 49 | # batch = 1 |
| 50 | if not channel_last: |
| 51 | x = torch.randn(batch, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_() |
| 52 | else: |
| 53 | x = rearrange( |
| 54 | torch.randn(batch, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s" |
| 55 | ).requires_grad_() |
| 56 | weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True) |
| 57 | if has_bias: |
| 58 | bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) |
| 59 | else: |
| 60 | bias = None |
| 61 | if has_initial_states: |
| 62 | initial_states = torch.randn(batch, width - 1, dim, device=device, dtype=itype).transpose(1, 2).requires_grad_() |
| 63 | else: |
| 64 | initial_states = None |
| 65 | x_ref = x.detach().clone().requires_grad_() |
| 66 | weight_ref = weight.detach().clone().requires_grad_() |
| 67 | bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None |
| 68 | initial_states_ref = initial_states.detach().clone().requires_grad_() if initial_states is not None else None |
| 69 | activation = None if not silu_activation else "silu" |
| 70 | out = causal_conv1d_fn(x, weight, bias, initial_states=initial_states, return_final_states=return_final_states, |
| 71 | activation=activation) |
| 72 | out_ref = causal_conv1d_ref(x_ref, weight_ref, bias_ref, initial_states=initial_states_ref, return_final_states=return_final_states, activation=activation) |
| 73 | if return_final_states: |
| 74 | out, final_states = out |
| 75 | out_ref, final_states_ref = out_ref |
| 76 | print(f"Final states max diff: {(final_states - final_states_ref).abs().max().item()}") |
| 77 | print(f"Final states mean diff: {(final_states - final_states_ref).abs().mean().item()}") |
| 78 | assert torch.allclose(final_states, final_states_ref, rtol=rtol, atol=atol) |
| 79 | |
| 80 | print(f"Output max diff: {(out - out_ref).abs().max().item()}") |
| 81 | print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") |
| 82 | assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) |
| 83 | |
| 84 | if return_final_states: |
| 85 | out += F.sigmoid(final_states).sum(dim=-1, keepdim=True) |
| 86 | out_ref += F.sigmoid(final_states_ref).sum(dim=-1, keepdim=True) |
| 87 | |
| 88 | g = torch.randn_like(out) |
| 89 | out.backward(g) |
| 90 | out_ref.backward(g) |
| 91 | |
| 92 | print(f"dx max diff: {(x.grad - x_ref.grad).abs().max().item()}") |
| 93 | print(f"dweight max diff: {(weight.grad - weight_ref.grad).abs().max().item()}") |
| 94 | if has_bias: |
| 95 | print(f"dbias max diff: {(bias.grad - bias_ref.grad).abs().max().item()}") |
nothing calls this directly
no test coverage detected