MCPcopy
hub / github.com/ddbourgin/numpy-ml / forward

Method forward

numpy_ml/neural_nets/layers/layers.py:3006–3046  ·  view source on GitHub ↗

Compute the layer output given input volume `X`. Parameters ---------- X : :py:class:`ndarray <numpy.ndarray>` of shape `(n_ex, in_rows, in_cols, in_ch)` The input volume consisting of `n_ex` examples, each with dimension (`in_rows`, `in_cols`, `in_ch`).

(self, X, retain_derived=True)

Source from the content-addressed store, hash-verified

3004 }
3005
def forward(self, X, retain_derived=True):
    """
    Run the forward pass of the layer on the input volume `X`.

    Parameters
    ----------
    X : :py:class:`ndarray <numpy.ndarray>` of shape `(n_ex, in_rows, in_cols, in_ch)`
        Input volume holding `n_ex` examples, each of dimension
        (`in_rows`, `in_cols`, `in_ch`).
    retain_derived : bool
        If True (the default), cache the input and the intermediate values
        computed here so they are available to a later backward pass. Pass
        False when the layer is not expected to backprop with respect to
        this input.

    Returns
    -------
    Y : :py:class:`ndarray <numpy.ndarray>` of shape `(n_ex, out_rows, out_cols, out_ch)`
        The layer output: the activation function applied to the
        convolution result plus the bias.
    """  # noqa: E501
    # On first use, record the channel count from axis 3 of `X` and then
    # call `_init_params()` before any parameter is read.
    if not self.is_initialized:
        self.in_ch = X.shape[3]
        self._init_params()

    kernel, bias = self.parameters["W"], self.parameters["b"]

    # Unpacking into exactly four names fails unless `X.shape` has four
    # entries; the unpacked values themselves are not used below.
    _n_ex, _in_rows, _in_cols, _in_ch = X.shape

    stride, pad, dilation = self.stride, self.pad, self.dilation

    # Padding, stride and dilation are all handed to `conv2D`; the bias is
    # added to its result before the activation is applied.
    pre_activation = conv2D(X, kernel, stride, pad, dilation) + bias
    output = self.act_fn(pre_activation)

    if not retain_derived:
        return output

    # Keep the input, the pre-activation values and the output's spatial
    # size (axes 1 and 2 of the pre-activation) for later use.
    self.X.append(X)
    derived = self.derived_variables
    derived["Z"].append(pre_activation)
    derived["out_rows"].append(pre_activation.shape[1])
    derived["out_cols"].append(pre_activation.shape[2])

    return output
3047
3048 def backward(self, dLdy, retain_grads=True):
3049 """

Callers 1

test_Conv2DFunction · 0.95

Calls 3

_init_paramsMethod · 0.95
conv2DFunction · 0.85
act_fnMethod · 0.80

Tested by 1

test_Conv2DFunction · 0.76