Applies the parallel scan operation, as defined above, and returns a new tensor. Args: A_in: (B, L, D, N); X_in: (B, L, D, N). Returns: H: (B, L, D, N).
(ctx, A_in, X_in)
| 73 | |
@staticmethod
def forward(ctx, A_in, X_in):
    """
    Run the parallel scan and return the result as a new tensor.

    Args:
        ctx: context object; ``save_for_backward`` is called on it so the
            backward pass can retrieve ``A_in`` and the scanned states.
        A_in : (B, L, D, N)
        X_in : (B, L, D, N)

    Returns:
        H : (B, L, D, N), a transposed view of the scanned copy of ``X_in``.

    Neither ``A_in`` nor ``X_in`` is modified: the scan only ever sees clones.
    """
    # PScan.pscan operates in place, so give it private copies. Each copy is
    # laid out as (B, D, L, N), i.e. dims 1 and 2 are swapped.
    # NOTE(review): PScan.pscan is not visible here; if it is an up/down-sweep
    # scan it may require L to be a power of two -- confirm.
    coeffs = A_in.clone().transpose(2, 1)
    states = X_in.clone().transpose(2, 1)

    # The scan result lands in `states`; pscan's return value is not used.
    PScan.pscan(coeffs, states)

    # Save the untouched input A_in and the scanned states, still in
    # (B, D, L, N) layout, for the backward pass.
    ctx.save_for_backward(A_in, states)

    # Swap dims back to (B, L, D, N). The returned view shares storage with
    # the tensor saved above.
    return states.transpose(2, 1)
| 101 | |
| 102 | @staticmethod |
| 103 | def backward(ctx, grad_output_in): |