MCPcopy Index your code
hub / github.com/kyegomez/BitNet / pscan

Method pscan

bitnet/bit_mamba.py:30–72  ·  view source on GitHub ↗
(A, X)

Source from the content-addressed store, hash-verified

28class PScan(torch.autograd.Function):
@staticmethod
def pscan(A, X):
    """Run an in-place parallel scan along dimension 2 of ``X``.

    On return, ``X[:, :, t]`` holds ``H[t]`` of the linear recurrence

        H[t] = A[t] * H[t-1] + X[t]

    where the state before the first step is zero, so ``H[0]`` is simply
    the original ``X[0]``. The recurrence is evaluated with an up sweep
    followed by a down sweep, i.e. about ``2 * log2(L)`` sequential steps
    instead of ``L``. Any ``L >= 0`` is accepted; it does not have to be a
    power of two.

    Args:
        A: tensor of size (B, D, L, N) with the per-step multipliers.
            It is used as scratch space and is OVERWRITTEN with partial
            products, so pass a copy if the values are still needed.
        X: tensor of size (B, D, L, N) with the per-step inputs.
            Overwritten in place with the scan result.

    Returns:
        None. The result is written into ``X``.

    Note:
        Dimension 2 is split with ``Tensor.view``, so both tensors need a
        memory layout that allows that split without a copy.
    """
    B, D, L, _ = A.size()

    # Zero or one time step: X already equals the scan result. Returning
    # here also avoids math.log2(0), which raises ValueError.
    if L < 2:
        return

    num_steps = int(math.log2(L))

    # Up sweep (reduction): at every level, fold the left element of each
    # adjacent pair into the right one, then continue on the right
    # elements only. A trailing unpaired element is left for the down sweep.
    Aa = A
    Xa = X
    for _ in range(num_steps):
        T = 2 * (Xa.size(2) // 2)  # largest even length at this level

        Aa = Aa[:, :, :T].view(B, D, T // 2, 2, -1)
        Xa = Xa[:, :, :T].view(B, D, T // 2, 2, -1)

        Xa[:, :, :, 1].add_(Aa[:, :, :, 1].mul(Xa[:, :, :, 0]))
        Aa[:, :, :, 1].mul_(Aa[:, :, :, 0])

        Aa = Aa[:, :, :, 1]
        Xa = Xa[:, :, :, 1]

    # Down sweep: walk back from the coarsest level to the finest one and
    # complete the elements the up sweep skipped.
    for k in range(num_steps - 1, -1, -1):
        # Elements that are "right ends" at level k: stride 2**k,
        # starting at index 2**k - 1.
        Aa = A[:, :, 2**k - 1 : L : 2**k]
        Xa = X[:, :, 2**k - 1 : L : 2**k]

        T = 2 * (Xa.size(2) // 2)

        # Odd count at this level: the last element had no partner during
        # the up sweep, so chain it onto its (already final) predecessor.
        if T < Xa.size(2):
            Xa[:, :, -1].add_(Aa[:, :, -1].mul(Xa[:, :, -2]))
            Aa[:, :, -1].mul_(Aa[:, :, -2])

        Aa = Aa[:, :, :T].view(B, D, T // 2, 2, -1)
        Xa = Xa[:, :, :T].view(B, D, T // 2, 2, -1)

        # The left element of every pair except the first receives the
        # contribution of the right element of the preceding pair.
        Xa[:, :, 1:, 0].add_(Aa[:, :, 1:, 0].mul(Xa[:, :, :-1, 1]))
        Aa[:, :, 1:, 0].mul_(Aa[:, :, :-1, 1])
73
74 @staticmethod
75 def forward(ctx, A_in, X_in):

Callers 2

forwardMethod · 0.80
backwardMethod · 0.80

Calls

no outgoing calls

Tested by

no test coverage detected