| 28 | class PScan(torch.autograd.Function): |
@staticmethod
def pscan(A, X):
    """Evaluate a first-order linear recurrence in place with a parallel scan.

    After the call, X (along dim 2) holds the values

        H[t] = A[t] * H[t-1] + X[t]   for t >= 0, with H[-1] = 0

    so H[0] = X[0]. They are computed with an up-sweep / down-sweep scheme:
    about 2*floor(log2(L)) sequential rounds of batched tensor ops instead of
    L sequential steps. Any L >= 0 is accepted, not only powers of two.

    A is overwritten as well. On return, A[t] holds the running product
    A[t] * A[t-1] * ... * A[0].

    Args:
        A: tensor of shape (B, D, L, N); modified in place.
        X: tensor of shape (B, D, L, N); modified in place.

    Returns:
        None. Results are written into A and X.
    """
    B, D, L, _ = A.size()
    if L < 2:
        # Nothing to combine: H[0] = X[0], and the running product of a
        # single element is itself. This also avoids math.log2(0) raising.
        return

    num_steps = int(math.log2(L))  # floor(log2(L)) levels below the top

    def pairs(t, T):
        # View the first T (even) entries along dim 2 as T // 2 adjacent
        # pairs. The trailing size is given explicitly because
        # view(..., -1) cannot infer it when the tensor has zero elements.
        return t[:, :, :T].view(B, D, T // 2, 2, t.size(-1))

    # Up-sweep (reduction): fold each pair (2i, 2i+1) into its odd member,
    # then continue on the odd members only. An unpaired trailing element is
    # left untouched here and fixed up during the down-sweep.
    Aa, Xa = A, X
    for _ in range(num_steps):
        T = 2 * (Xa.size(2) // 2)
        Aa = pairs(Aa, T)
        Xa = pairs(Xa, T)

        Xa[:, :, :, 1].add_(Aa[:, :, :, 1].mul(Xa[:, :, :, 0]))
        Aa[:, :, :, 1].mul_(Aa[:, :, :, 0])

        Aa = Aa[:, :, :, 1]
        Xa = Xa[:, :, :, 1]

    # Down-sweep: level k is the strided subsequence at indices
    # 2**k - 1, 2 * 2**k - 1, ... Its odd pair members form level k + 1 and
    # are already final, so propagate them into the even members and into
    # an unpaired trailing element.
    for k in reversed(range(num_steps)):
        stride = 2**k
        Aa = A[:, :, stride - 1 : L : stride]
        Xa = X[:, :, stride - 1 : L : stride]

        T = 2 * (Xa.size(2) // 2)

        if T < Xa.size(2):
            # Odd-length level: the unpaired last element depends only on
            # the (already final) element just before it.
            Xa[:, :, -1].add_(Aa[:, :, -1].mul(Xa[:, :, -2]))
            Aa[:, :, -1].mul_(Aa[:, :, -2])

        Aa = pairs(Aa, T)
        Xa = pairs(Xa, T)

        # The first pair's even member has no predecessor, hence the 1: slice.
        Xa[:, :, 1:, 0].add_(Aa[:, :, 1:, 0].mul(Xa[:, :, :-1, 1]))
        Aa[:, :, 1:, 0].mul_(Aa[:, :, :-1, 1])
| 74 | @staticmethod |
| 75 | def forward(ctx, A_in, X_in): |