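Backprop for a single timestep. Parameters ---------- dLdAt : :py:class:`ndarray <numpy.ndarray>` of shape `(n_ex, n_out)` The gradient of the loss wrt. the layer outputs (i.e., hidden states) at timestep `t`. Returns ------- dLdXt : :py:class:`ndarray <numpy.ndarray>` of shape `(n_ex, n_in)` The gradient of the loss wrt. the layer inputs at timestep `t`.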
(self, dLdAt)
| 3715 | return At |
| 3716 | |
def backward(self, dLdAt):
    """
    Backprop for a single timestep.

    Parameters
    ----------
    dLdAt : :py:class:`ndarray <numpy.ndarray>` of shape `(n_ex, n_out)`
        The gradient of the loss wrt. the layer outputs (i.e., hidden
        states) at timestep `t`.

    Returns
    -------
    dLdXt : :py:class:`ndarray <numpy.ndarray>` of shape `(n_ex, n_in)`
        The gradient of the loss wrt. the layer inputs at timestep `t`.

    Notes
    -----
    Each call decrements ``current_step``, so successive calls walk
    backwards through the stored timesteps. The gradient flowing into the
    previous hidden state is stored in ``dLdA_accumulator`` for the next
    call. Parameter gradients are accumulated (``+=``) into
    ``self.gradients``.
    """
    assert self.trainable, "Layer is frozen"

    # step back one timestep; t then indexes the step being differentiated
    self.derived_variables["current_step"] -= 1

    # extract context variables
    Zs = self.derived_variables["Z"]
    As = self.derived_variables["A"]
    t = self.derived_variables["current_step"]
    dA_acc = self.derived_variables["dLdA_accumulator"]

    # on the first backward call nothing has flowed back from a later step yet
    if dA_acc is None:
        dA_acc = np.zeros_like(As[0])

    # get network weights for gradient calcs
    Wax = self.parameters["Wax"]
    Waa = self.parameters["Waa"]

    # total gradient wrt. the hidden state at t: the direct loss signal plus
    # the signal propagated back from step t + 1 via the accumulator
    dA = dLdAt + dA_acc
    dZ = self.act_fn.grad(Zs[t]) * dA
    dXt = dZ @ Wax.T

    # update parameter gradients with signal from current step.
    # NOTE(review): As[t] is used as the hidden state that fed into Zs[t],
    # which presumes As holds an initial state at index 0 -- confirm vs forward
    self.gradients["Waa"] += As[t].T @ dZ
    self.gradients["Wax"] += self.X[t].T @ dZ

    # Bias gradients keep dZ's trailing `n_out` axis, i.e. shape (1, n_out),
    # matching the (..., n_out) layout of the weight gradients above.
    # Assumes ba/bx are stored as (1, n_out) row vectors. Transposing to
    # (n_out, 1) cannot be accumulated in place into such an array (numpy
    # raises a broadcasting error whenever n_out > 1).
    dB = dZ.sum(axis=0, keepdims=True)
    self.gradients["ba"] += dB
    self.gradients["bx"] += dB

    # gradient wrt. the previous hidden state, consumed by the next call
    self.derived_variables["dLdA_accumulator"] = dZ @ Waa.T
    return dXt
| 3765 | |
| 3766 | def flush_gradients(self): |
| 3767 | """Erase all the layer's derived variables and gradients.""" |