hub / github.com/kyegomez/BitNet / backward

Method backward

bitnet/bit_mamba.py:103–134

Flows the gradient from the output to the input. Returns two new tensors.

Args:
    ctx : A_in : (B, L, D, N), X : (B, D, L, N)
    grad_output_in : (B, L, D, N)

Returns:
    gradA : (B, L, D, N), gradX : (B, L, D, N)
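
The gradient flow described above can be written as a short recurrence. This is a sketch under an assumption that the page does not state explicitly: that the forward pass computes the linear scan h_t = a_t h_{t-1} + x_t with h_{-1} = 0, which is inferred from the call to PScan.pscan and the index pattern in the source. Here g_t stands for grad_output_in at time step t, \delta_t for gradX, J for the loss, and L for the sequence length; these symbols are introduced for illustration only.

```latex
\begin{aligned}
h_t &= a_t\,h_{t-1} + x_t, \qquad h_{-1} = 0, \qquad t = 0,\dots,L-1 \\
\delta_t \;=\; \frac{\partial J}{\partial x_t} &= g_t + a_{t+1}\,\delta_{t+1}, \qquad \delta_{L} = 0 \\
\frac{\partial J}{\partial a_t} &= h_{t-1}\,\delta_t \quad (t \ge 1), \qquad \frac{\partial J}{\partial a_0} = 0
\end{aligned}
```

Under this reading, the second line is the same kind of scan as the forward pass, run backwards in time, and the third line corresponds to the shifted product stored in Q, provided the saved tensor X holds the scan output h.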

(ctx, grad_output_in)

Source from the content-addressed store, hash-verified

    @staticmethod
    def backward(ctx, grad_output_in):
        """
        Flows the gradient from the output to the input. Returns two new tensors.

        Args:
            ctx : A_in : (B, L, D, N), X : (B, D, L, N)
            grad_output_in : (B, L, D, N)

        Returns:
            gradA : (B, L, D, N), gradX : (B, L, D, N)
        """

        A_in, X = ctx.saved_tensors

        # clone tensors
        A = A_in.clone()
        # grad_output_in will be cloned with flip()

        # prepare tensors
        A = A.transpose(2, 1)  # (B, D, L, N)
        A = torch.cat((A[:, :, :1], A[:, :, 1:].flip(2)), dim=2)
        grad_output_b = grad_output_in.transpose(2, 1)

        # reverse parallel scan
        grad_output_b = grad_output_b.flip(2)
        PScan.pscan(A, grad_output_b)
        grad_output_b = grad_output_b.flip(2)

        Q = torch.zeros_like(X)
        Q[:, :, 1:].add_(X[:, :, :-1] * grad_output_b[:, :, 1:])

        return Q.transpose(2, 1), grad_output_b.transpose(2, 1)


pscan = PScan.apply
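
The flip-and-rearrange steps in the source can be hard to follow, so here is a minimal pure-Python sketch of the same idea on a single scalar sequence. It is not the repository's implementation: it assumes the forward pass is the sequential scan h[t] = a[t] * h[t-1] + x[t] with h[-1] = 0, and that the saved tensor X holds the scan output h. The helper names forward_scan, backward_direct, and backward_flipped are hypothetical and exist only for this illustration.

```python
def forward_scan(a, x):
    """Sequential reference scan: h[t] = a[t] * h[t-1] + x[t], with h[-1] = 0."""
    h, prev = [], 0.0
    for a_t, x_t in zip(a, x):
        prev = a_t * prev + x_t
        h.append(prev)
    return h


def backward_direct(a, h, grad_out):
    """Gradients of the scan via an explicit reverse-time loop."""
    n = len(a)
    grad_x = [0.0] * n
    carry = 0.0  # holds a[t+1] * grad_x[t+1]; zero past the last step
    for t in reversed(range(n)):
        grad_x[t] = grad_out[t] + carry
        carry = a[t] * grad_x[t]
    # grad_a[t] = h[t-1] * grad_x[t]; step 0 has no predecessor, so it stays 0
    grad_a = [0.0] + [h[t - 1] * grad_x[t] for t in range(1, n)]
    return grad_a, grad_x


def backward_flipped(a, h, grad_out):
    """Same gradients, reusing the forward scan on time-reversed inputs."""
    # Keep a[0] in front and reverse the rest, mirroring
    # torch.cat((A[:, :, :1], A[:, :, 1:].flip(2)), dim=2) in the source.
    # After reversal, position s >= 1 holds a[n - s], i.e. a[t + 1] for t = n - 1 - s.
    a_rev = a[:1] + a[1:][::-1]
    grad_x = forward_scan(a_rev, grad_out[::-1])[::-1]
    grad_a = [0.0] + [h[t - 1] * grad_x[t] for t in range(1, len(a))]
    return grad_a, grad_x


# Illustrative values only (assumed, not taken from the repository).
a = [0.5, 0.9, -0.3, 0.7]
x = [1.0, -2.0, 0.5, 3.0]
g = [0.2, -1.0, 0.4, 1.5]
h = forward_scan(a, x)
print(backward_direct(a, h, g))
print(backward_flipped(a, h, g))
```

Both functions should print matching gradients. The flipped variant shows why the coefficient at index 0 stays in place: it multiplies a zero initial state in the reversed scan, so only the coefficients at indices 1 and above need to be reversed to line up a[t+1] with time step t.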

Callers 1

train.py (File · 0.80)

Calls 1

pscan (Method · 0.80)

Tested by

no test coverage detected