Compute second derivatives, i.e. the derivatives of the loss with respect to dX, dW, and db
(self, dLdy, X, dLdy_bwd)
| 2542 | return dX, dW, dB |
| 2543 | |
def _bwd2(self, dLdy, X, dLdy_bwd):
    """
    Compute second derivatives / deriv. of loss wrt. dX, dW, and db.

    The layer's effective weight matrix is the element-wise product of
    ``self.parameters["W"]`` and ``self.parameters["W_mask"]``. Every term
    below is evaluated with those masked weights, so entries zeroed by the
    mask never contribute to any of the returned quantities.

    Parameters
    ----------
    dLdy : array
        Left operand of the matrix products with the masked weights and
        (transposed) with ``dLdy_bwd * dZ``.
        NOTE(review): the products require ``dLdy.shape[-1]`` to equal the
        number of rows of ``W`` -- confirm against the caller.
    X : array
        Layer input, used to rebuild the pre-activation
        ``X @ W_sparse + b``.
    dLdy_bwd : array
        Multiplied element-wise with the activation derivatives.
        NOTE(review): presumably the gradient from the ordinary backward
        pass -- confirm against the caller.

    Returns
    -------
    ddX, ddW, ddB : arrays
        ``ddB`` is summed over axis 0 with ``keepdims=True``.
    """
    W_sparse = self.parameters["W"] * self.parameters["W_mask"]
    b = self.parameters["b"]

    # Pre-activation is shared by the first and second activation derivative.
    Z = X @ W_sparse + b
    dZ = self.act_fn.grad(Z)
    ddZ = self.act_fn.grad2(Z)

    # Use the masked weights here as well: the dense W would let pruned
    # connections leak into ddX, disagreeing with Z and ddB.
    projected = dLdy @ W_sparse

    ddX = projected * dZ
    ddW = dLdy.T @ (dLdy_bwd * dZ)
    ddB = np.sum(projected * dLdy_bwd * ddZ, axis=0, keepdims=True)
    return ddX, ddW, ddB
| 2557 | |
| 2558 | def update(self): |
| 2559 | """ |