hub / github.com/ddbourgin/numpy-ml / backward

Method backward

numpy_ml/neural_nets/layers/layers.py:2756–2797

Compute the gradient of the loss with respect to the layer parameters. Relies on :meth:`~numpy_ml.neural_nets.utils.im2col` and :meth:`~numpy_ml.neural_nets.utils.col2im` to vectorize the gradient calculation.

(self, dLdy, retain_grads=True)

Source from the content-addressed store, hash-verified

2754         return Y
2755
2756     def backward(self, dLdy, retain_grads=True):
2757         """
2758         Compute the gradient of the loss with respect to the layer parameters.
2759
2760         Notes
2761         -----
2762         Relies on :meth:`~numpy_ml.neural_nets.utils.im2col` and
2763         :meth:`~numpy_ml.neural_nets.utils.col2im` to vectorize the
2764         gradient calculation. See the private method :meth:`_backward_naive`
2765         for a more straightforward implementation.
2766
2767         Parameters
2768         ----------
2769         dLdy : :py:class:`ndarray <numpy.ndarray>` of shape `(n_ex, l_out, out_ch)` or list of arrays
2770             The gradient(s) of the loss with respect to the layer output(s).
2771         retain_grads : bool
2772             Whether to include the intermediate parameter gradients computed
2773             during the backward pass in the final parameter update. Default is
2774             True.
2775
2776         Returns
2777         -------
2778         dX : :py:class:`ndarray <numpy.ndarray>` of shape `(n_ex, l_in, in_ch)`
2779             The gradient of the loss with respect to the layer input volume.
2780         """  # noqa: E501
2781         assert self.trainable, "Layer is frozen"
2782         if not isinstance(dLdy, list):
2783             dLdy = [dLdy]
2784
2785         X = self.X
2786         Z = self.derived_variables["Z"]
2787
2788         dX = []
2789         for dy, x, z in zip(dLdy, X, Z):
2790             dx, dw, db = self._bwd(dy, x, z)
2791             dX.append(dx)
2792
2793             if retain_grads:
2794                 self.gradients["W"] += dw
2795                 self.gradients["b"] += db
2796
2797         return dX[0] if len(X) == 1 else dX
2798
2799     def _bwd(self, dLdy, X, Z):
2800         """Actual computation of gradient of the loss wrt. X, W, and b"""
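
The listing above only dispatches to `_bwd`, whose body is not shown here. As a rough illustration of the same pattern (wrap a single array in a list, compute per-input gradients, accumulate into `gradients` only when `retain_grads` is set), here is a minimal loop-based sketch. It is not the library's implementation: `ToyConv1D` and `conv1d_backward_naive` are hypothetical names, and the sketch assumes stride 1, no padding, no dilation, an identity activation, a weight shape of `(kernel_width, in_ch, out_ch)`, and a bias of shape `(out_ch,)`.

```python
import numpy as np


def conv1d_backward_naive(dLdy, X, W):
    """Loop-based gradients for a stride-1, unpadded 1D convolution.

    Assumed shapes (illustrative only): X is (n_ex, l_in, in_ch),
    W is (kernel_width, in_ch, out_ch), dLdy is (n_ex, l_out, out_ch).
    """
    k = W.shape[0]
    l_out = X.shape[1] - k + 1
    dX, dW = np.zeros_like(X), np.zeros_like(W)
    for i in range(l_out):
        window = X[:, i:i + k, :]            # (n_ex, k, in_ch)
        dy_i = dLdy[:, i, :]                 # (n_ex, out_ch)
        # Output position i contributes to every kernel weight ...
        dW += np.tensordot(window, dy_i, axes=([0], [0]))
        # ... and to the k input positions it was computed from.
        dX[:, i:i + k, :] += np.tensordot(dy_i, W, axes=([1], [2]))
    db = dLdy.sum(axis=(0, 1))               # (out_ch,)
    return dX, dW, db


class ToyConv1D:
    """Hypothetical stand-in layer; not the numpy-ml class."""

    def __init__(self, W, b):
        self.W, self.b = W, b
        self.X = []                          # inputs cached by forward()
        self.gradients = {"W": np.zeros_like(W), "b": np.zeros_like(b)}

    def forward(self, X):
        k = self.W.shape[0]
        l_out = X.shape[1] - k + 1
        Y = np.stack(
            [np.tensordot(X[:, i:i + k, :], self.W, axes=([1, 2], [0, 1]))
             for i in range(l_out)],
            axis=1,
        ) + self.b
        self.X.append(X)
        return Y

    def backward(self, dLdy, retain_grads=True):
        # Same dispatch shape as the listing: accept one array or a list.
        if not isinstance(dLdy, list):
            dLdy = [dLdy]
        dX = []
        for dy, x in zip(dLdy, self.X):
            dx, dw, db = conv1d_backward_naive(dy, x, self.W)
            dX.append(dx)
            if retain_grads:
                self.gradients["W"] += dw
                self.gradients["b"] += db
        return dX[0] if len(self.X) == 1 else dX


rng = np.random.default_rng(0)
layer = ToyConv1D(rng.normal(size=(3, 2, 4)), np.zeros(4))
X = rng.normal(size=(5, 8, 2))
Y = layer.forward(X)                         # shape (5, 6, 4)
dX = layer.backward(np.ones_like(Y))         # shape (5, 8, 2)
```

Because the parameter gradients are accumulated with `+=`, calling `backward` repeatedly sums them across calls, so a caller would presumably need to reset them between updates. Passing `retain_grads=False` still returns `dX` but leaves `gradients` untouched.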

Callers 1

test_Conv1D (Function) · 0.95

Calls 1

_bwd (Method) · 0.95

Tested by 1

test_Conv1D (Function) · 0.76