Repository: github.com/kyegomez/BitNet

Method forward

bitnet/bit_mamba.py:75–100

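Applies the parallel scan operation (as defined earlier in bitnet/bit_mamba.py) and returns a new tensor. The inputs are cloned first because PScan.pscan modifies its arguments in place.

Args:
  A_in : (B, L, D, N)
  X_in : (B, L, D, N)

Returns:
  H : (B, L, D, N)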
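The recurrence itself is defined elsewhere in the source file, not on this page. Assuming it is the usual Mamba-style scan h[t] = a[t] * h[t-1] + x[t] with h[-1] = 0, applied independently to every (b, d, n) slice along the L axis, a sequential reference for one slice could look like the sketch below. sequential_scan is a hypothetical helper for illustration, not part of the repository.

```python
from typing import List


def sequential_scan(a: List[float], x: List[float]) -> List[float]:
    """Reference scan for ONE (b, d, n) slice of length L.

    Assumed recurrence (not confirmed on this page):
        h[t] = a[t] * h[t-1] + x[t], with h[-1] = 0
    """
    if len(a) != len(x):
        raise ValueError("a and x must have the same length")
    h: List[float] = []
    prev = 0.0  # h[-1] = 0
    for a_t, x_t in zip(a, x):
        prev = a_t * prev + x_t  # one step of the linear recurrence
        h.append(prev)
    return h


print(sequential_scan([0.5, 0.5, 0.5], [1.0, 1.0, 1.0]))  # [1.0, 1.5, 1.75]
```

This loop takes L dependent steps; a parallel scan computes the same outputs in O(log L) rounds by exploiting the associativity of the recurrence.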

Signature: forward(ctx, A_in, X_in)

Source

@staticmethod
def forward(ctx, A_in, X_in):
    """
    Applies the parallel scan operation, as defined above. Returns a new tensor.

    Args:
        A_in : (B, L, D, N)
        X_in : (B, L, D, N)

    Returns:
        H : (B, L, D, N)
    """

    # clone tensor (in-place ops)
    A = A_in.clone()  # (B, L, D, N)
    X = X_in.clone()  # (B, L, D, N)

    # prepare tensors
    A = A.transpose(2, 1)  # (B, D, L, N)
    X = X.transpose(2, 1)  # (B, D, L, N)

    # parallel scan
    PScan.pscan(A, X)

    ctx.save_for_backward(A_in, X)

    return X.transpose(2, 1)
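The forward above clones A_in and X_in, moves the sequence axis L to dimension 2, and lets PScan.pscan overwrite the clones in place, so each (b, d, n) slice is scanned along L. The pure-Python sketch below mirrors that copy-then-scan-in-place structure on a single slice. It assumes the recurrence h[t] = a[t] * h[t-1] + x[t] with h[-1] = 0; the helper names are hypothetical, and the log-depth Hillis-Steele scan is swapped in for readability. It is not claimed to be the algorithm inside PScan.pscan.

```python
import math
from typing import List


def sequential_scan(a: List[float], x: List[float]) -> List[float]:
    """Reference (assumed recurrence): h[t] = a[t] * h[t-1] + x[t], h[-1] = 0."""
    h: List[float] = []
    prev = 0.0
    for a_t, x_t in zip(a, x):
        prev = a_t * prev + x_t
        h.append(prev)
    return h


def parallel_scan_inplace(a: List[float], x: List[float]) -> None:
    """Hillis-Steele inclusive scan over pairs (a[t], x[t]), done in place.

    Each pair stands for the affine map h -> a*h + x. Applying the map at an
    earlier position s and then the one at t gives the pair
    (a[s] * a[t], a[t] * x[s] + x[t]), an associative combine, so
    ceil(log2(L)) rounds suffice. Afterwards x[t] holds h[t].
    """
    n = len(x)
    offset = 1
    while offset < n:
        # Snapshot the previous round so every t reads only old values and
        # could be updated independently (in parallel on a GPU).
        a_prev, x_prev = a[:], x[:]
        for t in range(offset, n):
            x[t] = a_prev[t] * x_prev[t - offset] + x_prev[t]
            a[t] = a_prev[t - offset] * a_prev[t]
        offset *= 2


def scan_forward(a_in: List[float], x_in: List[float]) -> List[float]:
    """Mirrors forward's structure: copy inputs, scan in place, return result."""
    a = list(a_in)  # plays the role of A_in.clone(): the scan mutates a
    x = list(x_in)  # plays the role of X_in.clone(): the scan mutates x
    parallel_scan_inplace(a, x)
    return x


a = [0.9, -0.3, 1.1, 0.7, 0.2, -1.5, 0.4]
x = [0.1, 2.0, -0.5, 0.3, 1.2, 0.0, -0.7]
h = scan_forward(a, x)
print(all(math.isclose(p, q, rel_tol=1e-9, abs_tol=1e-12)
          for p, q in zip(h, sequential_scan(a, x))))  # True
```

Copying first is what keeps the caller's inputs intact: parallel_scan_inplace overwrites both lists, just as PScan.pscan overwrites A and X, which is why forward clones A_in and X_in before scanning.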

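Callers: nothing calls this directly (the ctx / save_for_backward pattern is that of a torch.autograd.Function, whose forward runs via apply())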

Calls (1): pscan (method) · 0.80

Tested by: no test coverage detected