hub / github.com/ddbourgin/numpy-ml / backward

Method backward

numpy_ml/neural_nets/layers/layers.py:3048–3091

Compute the gradient of the loss with respect to the layer parameters.

Notes: Relies on :meth:`~numpy_ml.neural_nets.utils.im2col` and :meth:`~numpy_ml.neural_nets.utils.col2im` to vectorize the gradient calculation. See the private method :meth:`_backward_naive` for a more straightforward implementation.

(self, dLdy, retain_grads=True)
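The summary credits im2col and col2im with vectorizing the gradient calculation. As a rough illustration of that idea, and not a copy of numpy-ml's own utilities, the sketch below unrolls patches into a matrix so that the weight gradient becomes a single matrix product. It assumes stride 1, no padding, no dilation, an NHWC input, and a weight layout of `(fr, fc, in_ch, out_ch)`; the function names `im2col`, `col2im`, `conv_forward`, and `conv_backward` are hypothetical stand-ins defined here for the example.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def im2col(X, fr, fc):
    """Unroll every (fr, fc) patch of an NHWC volume into one row."""
    # windows has shape (n_ex, out_rows, out_cols, in_ch, fr, fc)
    windows = sliding_window_view(X, (fr, fc), axis=(1, 2))
    # Reorder to (n_ex, out_rows, out_cols, fr, fc, in_ch) to match W's layout
    windows = windows.transpose(0, 1, 2, 4, 5, 3)
    n_ex, out_rows, out_cols = windows.shape[:3]
    return windows.reshape(n_ex * out_rows * out_cols, -1)


def col2im(dcols, X_shape, fr, fc):
    """Scatter-add per-patch gradients back onto the input volume."""
    n_ex, in_rows, in_cols, in_ch = X_shape
    out_rows, out_cols = in_rows - fr + 1, in_cols - fc + 1
    dcols = dcols.reshape(n_ex, out_rows, out_cols, fr, fc, in_ch)
    dX = np.zeros(X_shape)
    for i in range(fr):
        for j in range(fc):
            # Kernel offset (i, j) of the patch at (r, c) touches X[r + i, c + j]
            dX[:, i:i + out_rows, j:j + out_cols, :] += dcols[:, :, :, i, j, :]
    return dX


def conv_forward(X, W, b):
    fr, fc, in_ch, out_ch = W.shape
    n_ex, in_rows, in_cols, _ = X.shape
    Y = im2col(X, fr, fc) @ W.reshape(-1, out_ch) + b
    return Y.reshape(n_ex, in_rows - fr + 1, in_cols - fc + 1, out_ch)


def conv_backward(dLdy, X, W):
    fr, fc, in_ch, out_ch = W.shape
    cols = im2col(X, fr, fc)
    dy2d = dLdy.reshape(-1, out_ch)
    dW = (cols.T @ dy2d).reshape(W.shape)   # one matmul instead of nested loops
    db = dy2d.sum(axis=0)
    dX = col2im(dy2d @ W.reshape(-1, out_ch).T, X.shape, fr, fc)
    return dX, dW, db


rng = np.random.default_rng(0)
X = rng.standard_normal((2, 5, 5, 3))
W = rng.standard_normal((3, 3, 3, 4))
b = np.zeros(4)
Y = conv_forward(X, W, b)
dX, dW, db = conv_backward(np.ones_like(Y), X, W)
```

The returned `dX` has the shape of the input volume and `dW` the shape of the weights, mirroring the `(n_ex, in_rows, in_cols, in_ch)` return shape documented for `backward`. The library's real helpers also handle padding, stride, and dilation, which this sketch leaves out.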

Source from the content-addressed store, hash-verified

3046         return Y
3047
3048     def backward(self, dLdy, retain_grads=True):
3049         """
3050         Compute the gradient of the loss with respect to the layer parameters.
3051
3052         Notes
3053         -----
3054         Relies on :meth:`~numpy_ml.neural_nets.utils.im2col` and
3055         :meth:`~numpy_ml.neural_nets.utils.col2im` to vectorize the
3056         gradient calculation.
3057
3058         See the private method :meth:`_backward_naive` for a more straightforward
3059         implementation.
3060
3061         Parameters
3062         ----------
3063         dLdy : :py:class:`ndarray <numpy.ndarray>` of shape `(n_ex, out_rows, out_cols, out_ch)` or list of arrays
3064             The gradient(s) of the loss with respect to the layer output(s).
3065         retain_grads : bool
3066             Whether to include the intermediate parameter gradients computed
3067             during the backward pass in the final parameter update. Default is
3068             True.
3069
3070         Returns
3071         -------
3072         dX : :py:class:`ndarray <numpy.ndarray>` of shape `(n_ex, in_rows, in_cols, in_ch)`
3073             The gradient of the loss with respect to the layer input volume.
3074         """ # noqa: E501
3075         assert self.trainable, "Layer is frozen"
3076         if not isinstance(dLdy, list):
3077             dLdy = [dLdy]
3078
3079         dX = []
3080         X = self.X
3081         Z = self.derived_variables["Z"]
3082
3083         for dy, x, z in zip(dLdy, X, Z):
3084             dx, dw, db = self._bwd(dy, x, z)
3085             dX.append(dx)
3086
3087             if retain_grads:
3088                 self.gradients["W"] += dw
3089                 self.gradients["b"] += db
3090
3091         return dX[0] if len(X) == 1 else dX
3092
3093     def _bwd(self, dLdy, X, Z):
3094         """Actual computation of gradient of the loss wrt. X, W, and b"""
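The control flow in the listing above (wrap a single gradient in a list, pair each gradient with a cached input, accumulate parameter gradients only when `retain_grads` is set, and unwrap the result when there was one input) can be tried in isolation. The following is a minimal sketch using a hypothetical dense layer, `ToyLinear`, which is not part of numpy-ml and replaces the convolution with a plain matrix product so the arithmetic is easy to follow.

```python
import numpy as np


class ToyLinear:
    """Hypothetical dense layer used only to illustrate the control flow."""

    def __init__(self, n_in, n_out, seed=0):
        rng = np.random.default_rng(seed)
        self.W = rng.standard_normal((n_in, n_out))
        self.b = np.zeros(n_out)
        self.trainable = True
        self.X = []  # one cached input per forward() call
        self.gradients = {"W": np.zeros_like(self.W), "b": np.zeros_like(self.b)}

    def forward(self, X):
        self.X.append(X)
        return X @ self.W + self.b

    def _bwd(self, dy, x):
        # Gradients of y = xW + b with respect to x, W, and b
        return dy @ self.W.T, x.T @ dy, dy.sum(axis=0)

    def backward(self, dLdy, retain_grads=True):
        assert self.trainable, "Layer is frozen"
        if not isinstance(dLdy, list):
            dLdy = [dLdy]  # treat a single gradient as a one-element list

        dX = []
        for dy, x in zip(dLdy, self.X):
            dx, dw, db = self._bwd(dy, x)
            dX.append(dx)
            if retain_grads:
                # Accumulate across all cached inputs
                self.gradients["W"] += dw
                self.gradients["b"] += db

        return dX[0] if len(self.X) == 1 else dX


X1, X2 = np.ones((4, 3)), 2 * np.ones((4, 3))
dy = np.ones((4, 2))

# Two forward passes, then one backward pass over both cached inputs
layer = ToyLinear(3, 2)
layer.forward(X1)
layer.forward(X2)
dX = layer.backward([dy, dy])  # list of two arrays; gradients are summed

# One forward pass, gradients inspected but not accumulated
probe = ToyLinear(3, 2)
probe.forward(X1)
dX_single = probe.backward(dy, retain_grads=False)  # a bare array
```

With two cached inputs the accumulated `W` gradient is the sum of both contributions, while the `retain_grads=False` call returns the input gradient and leaves `probe.gradients` at zero.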

Callers 1

test_Conv2D · Function · 0.95

Calls 1

_bwd · Method · 0.95

Tested by 1

test_Conv2D · Function · 0.76