Compute the module output on a single minibatch. Parameters ---------- X_main : :py:class:`ndarray <numpy.ndarray>` of shape `(n_ex, in_rows, in_cols, in_ch)` The input volume consisting of `n_ex` examples, each with dimension (`in_rows`, `in_cols`, `in_ch`).
(self, X_main, X_skip=None)
| 285 | } |
| 286 | |
def forward(self, X_main, X_skip=None):
    """
    Compute the module output on a single minibatch.

    Parameters
    ----------
    X_main : :py:class:`ndarray <numpy.ndarray>` of shape `(n_ex, in_rows, in_cols, in_ch)`
        The input volume consisting of `n_ex` examples, each with dimension
        (`in_rows`, `in_cols`, `in_ch`).
    X_skip : :py:class:`ndarray <numpy.ndarray>` of shape `(n_ex, in_rows, in_cols, in_ch)`, or None
        The output of the preceding skip-connection if this is not the
        first module in the network. If None (first module), an array of
        zeros shaped like the 1x1 convolution output is used as the
        incoming skip-connection sum.

    Returns
    -------
    Y_main : :py:class:`ndarray <numpy.ndarray>` of shape `(n_ex, out_rows, out_cols, out_ch)`
        The output of the main pathway, produced by ``self.add_residual``
        from `X_main` and the 1x1 convolution output.
    Y_skip : :py:class:`ndarray <numpy.ndarray>` of shape `(n_ex, out_rows, out_cols, out_ch)`
        The output of the skip-connection pathway, produced by
        ``self.add_skip`` from the incoming skip sum and the 1x1
        convolution output.
    """
    self.X_main = X_main
    conv_dilation_out = self.conv_dilation.forward(X_main)

    # Gated activation: the same dilated-conv output feeds both gates.
    tanh_gate = self.tanh.fn(conv_dilation_out)
    sigm_gate = self.sigm.fn(conv_dilation_out)

    multiply_gate_out = self.multiply_gate.forward([tanh_gate, sigm_gate])
    conv_1x1_out = self.conv_1x1.forward(multiply_gate_out)

    # If this is the first wavenet block, initialize the "previous" skip
    # connection sum to 0. Rebind the local as well as the attribute so the
    # zero array (not the raw `None` argument) is what reaches `add_skip`.
    if X_skip is None:
        X_skip = np.zeros_like(conv_1x1_out)
    self.X_skip = X_skip

    Y_skip = self.add_skip.forward([X_skip, conv_1x1_out])
    Y_main = self.add_residual.forward([X_main, conv_1x1_out])

    # Cache intermediate activations; presumably consumed by `backward`.
    self._dv["tanh_out"] = tanh_gate
    self._dv["sigm_out"] = sigm_gate
    self._dv["conv_dilation_out"] = conv_dilation_out
    self._dv["multiply_gate_out"] = multiply_gate_out
    self._dv["conv_1x1_out"] = conv_1x1_out
    return Y_main, Y_skip
| 329 | |
| 330 | def backward(self, dY_skip, dY_main=None): |
| 331 | dX_skip, dConv_1x1_out = self.add_skip.backward(dY_skip) |