MCPcopy
hub / github.com/hustvl/Vim / test_causal_conv1d

Function test_causal_conv1d

causal-conv1d/tests/test_causal_conv1d.py:38–104  ·  view source on GitHub ↗
(dim, seqlen, width, has_bias, silu_activation, itype, channel_last, has_initial_states, return_final_states)

Source from the content-addressed store, hash-verified

36@pytest.mark.parametrize('dim', [64, 4096 + 32])
37# @pytest.mark.parametrize('dim', [64])
38def test_causal_conv1d(dim, seqlen, width, has_bias, silu_activation, itype, channel_last, has_initial_states, return_final_states):
39 if not channel_last and (has_initial_states or return_final_states):
40 pytest.skip("Only channel_last support initial_states or return_final_states")
41 device = "cuda"
42 rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
43 if itype == torch.bfloat16:
44 rtol, atol = 1e-2, 5e-2
45 rtolw, atolw = (1e-3, 1e-3)
46 # set seed
47 torch.random.manual_seed(0)
48 batch = 2
49 # batch = 1
50 if not channel_last:
51 x = torch.randn(batch, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_()
52 else:
53 x = rearrange(
54 torch.randn(batch, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s"
55 ).requires_grad_()
56 weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
57 if has_bias:
58 bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
59 else:
60 bias = None
61 if has_initial_states:
62 initial_states = torch.randn(batch, width - 1, dim, device=device, dtype=itype).transpose(1, 2).requires_grad_()
63 else:
64 initial_states = None
65 x_ref = x.detach().clone().requires_grad_()
66 weight_ref = weight.detach().clone().requires_grad_()
67 bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
68 initial_states_ref = initial_states.detach().clone().requires_grad_() if initial_states is not None else None
69 activation = None if not silu_activation else "silu"
70 out = causal_conv1d_fn(x, weight, bias, initial_states=initial_states, return_final_states=return_final_states,
71 activation=activation)
72 out_ref = causal_conv1d_ref(x_ref, weight_ref, bias_ref, initial_states=initial_states_ref, return_final_states=return_final_states, activation=activation)
73 if return_final_states:
74 out, final_states = out
75 out_ref, final_states_ref = out_ref
76 print(f"Final states max diff: {(final_states - final_states_ref).abs().max().item()}")
77 print(f"Final states mean diff: {(final_states - final_states_ref).abs().mean().item()}")
78 assert torch.allclose(final_states, final_states_ref, rtol=rtol, atol=atol)
79
80 print(f"Output max diff: {(out - out_ref).abs().max().item()}")
81 print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
82 assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
83
84 if return_final_states:
85 out += F.sigmoid(final_states).sum(dim=-1, keepdim=True)
86 out_ref += F.sigmoid(final_states_ref).sum(dim=-1, keepdim=True)
87
88 g = torch.randn_like(out)
89 out.backward(g)
90 out_ref.backward(g)
91
92 print(f"dx max diff: {(x.grad - x_ref.grad).abs().max().item()}")
93 print(f"dweight max diff: {(weight.grad - weight_ref.grad).abs().max().item()}")
94 if has_bias:
95 print(f"dbias max diff: {(bias.grad - bias_ref.grad).abs().max().item()}")

Callers

nothing calls this directly

Calls 7

causal_conv1d_fn — Function · 0.90
causal_conv1d_ref — Function · 0.90
print — Function · 0.85
max — Method · 0.80
clone — Method · 0.45
backward — Method · 0.45
to — Method · 0.45

Tested by

no test coverage detected