| 328 | return Y_main, Y_skip |
| 329 | |
def backward(self, dY_skip, dY_main=None):
    """
    Backpropagate through the block, from its two outputs to its two inputs.

    Parameters
    ----------
    dY_skip : array-like
        Gradient of the loss with respect to the skip-path output.
        NOTE(review): expected shape is not visible here -- presumably
        matches the skip output produced by the forward pass; confirm.
    dY_main : array-like or None
        Gradient of the loss with respect to the main (residual) output.
        ``None`` for the last wavenet block, in which case only the skip
        path contributes to the gradient. Default is None.

    Returns
    -------
    dX_main : array-like
        Gradient of the loss with respect to the main-path input.
    dX_skip : array-like
        Gradient of the loss with respect to the skip-path input.

    Notes
    -----
    The intermediate gradients are cached in ``self._dv`` under the keys
    ``"dLdTanh"``, ``"dLdSigmoid"``, ``"dLdConv_1x1"``, ``"dLdMultiply"``
    and ``"dLdConv_dilation"``.
    """
    # The skip-path addition splits its gradient between the skip input and
    # the output of the 1x1 convolution.
    grad_X_skip, grad_conv_1x1 = self.add_skip.backward(dY_skip)

    # Start from a zero gradient on the main input. When a main-path
    # gradient is supplied (i.e. this is not the last wavenet block), the
    # residual addition provides both the main-input gradient and a second
    # contribution to the 1x1 convolution output, which is accumulated in
    # place on top of the skip-path contribution.
    grad_X_main = np.zeros_like(self.X_main)
    if dY_main is not None:
        grad_X_main, grad_conv_1x1_residual = self.add_residual.backward(dY_main)
        grad_conv_1x1 += grad_conv_1x1_residual

    # Walk back through the 1x1 convolution and the gating product.
    grad_multiply = self.conv_1x1.backward(grad_conv_1x1)
    grad_tanh, grad_sigm = self.multiply_gate.backward(grad_multiply)

    # Both gate activations were fed by the dilated convolution output, so
    # their input gradients are summed before continuing backwards.
    dilation_out = self.derived_variables["conv_dilation_out"]
    grad_dilation = (
        grad_tanh * self.tanh.grad(dilation_out)
        + grad_sigm * self.sigm.grad(dilation_out)
    )

    # Add the dilated convolution's input gradient to the main-input
    # gradient (in place).
    grad_X_main += self.conv_dilation.backward(grad_dilation)

    # Cache the intermediate gradients, in the same key order as before.
    cached_grads = (
        ("dLdTanh", grad_tanh),
        ("dLdSigmoid", grad_sigm),
        ("dLdConv_1x1", grad_conv_1x1),
        ("dLdMultiply", grad_multiply),
        ("dLdConv_dilation", grad_dilation),
    )
    for key, grad in cached_grads:
        self._dv[key] = grad

    return grad_X_main, grad_X_skip
| 358 | |
| 359 | |
| 360 | class SkipConnectionIdentityModule(ModuleBase): |