Actual computation of gradient of the loss wrt. X, W, and b
(self, dLdy, X)
| 2528 | return dX[0] if len(X) == 1 else dX |
| 2529 | |
| 2530 | def _bwd(self, dLdy, X): |
| 2531 | """Actual computation of gradient of the loss wrt. X, W, and b""" |
| 2532 | W = self.parameters["W"] |
| 2533 | b = self.parameters["b"] |
| 2534 | W_sparse = W * self.parameters["W_mask"] |
| 2535 | |
| 2536 | Z = X @ W_sparse + b |
| 2537 | dZ = dLdy * self.act_fn.grad(Z) |
| 2538 | |
| 2539 | dX = dZ @ W_sparse.T |
| 2540 | dW = X.T @ dZ |
| 2541 | dB = dZ.sum(axis=0, keepdims=True) |
| 2542 | return dX, dW, dB |
| 2543 | |
| 2544 | def _bwd2(self, dLdy, X, dLdy_bwd): |
| 2545 | """Compute second derivatives / deriv. of loss wrt. dX, dW, and db""" |