| 43 | self.bo_diff = np.zeros(mem_cell_ct) |
| 44 | |
def apply_diff(self, lr=1):
    """Apply one update step to every weight and bias, then clear the diffs.

    Each parameter ``p`` in ``wg, wi, wf, wo, bg, bi, bf, bo`` is updated as
    ``p -= lr * p_diff``. Once all eight are updated, every ``p_diff`` is
    replaced with a freshly allocated ``np.zeros_like(p)`` (same shape and
    dtype as the updated parameter).

    Args:
        lr: factor applied to each diff before it is subtracted (default 1).
            NOTE(review): presumably a learning rate, with the ``*_diff``
            attributes holding accumulated gradients -- confirm against callers.
    """
    param_names = ("wg", "wi", "wf", "wo", "bg", "bi", "bf", "bo")

    # Pass 1: update every parameter. Augmented assignment on a local followed
    # by setattr matches ``self.<name> -= ...`` exactly, so NumPy arrays are
    # still modified in place and the attribute is rebound to the result.
    for param_name in param_names:
        value = getattr(self, param_name)
        value -= lr * getattr(self, param_name + "_diff")
        setattr(self, param_name, value)

    # Pass 2: reset diffs to zero. New arrays are allocated rather than the
    # old diff arrays being zeroed in place.
    for param_name in param_names:
        fresh = np.zeros_like(getattr(self, param_name))
        setattr(self, param_name + "_diff", fresh)
| 63 | |
| 64 | class LstmState: |
| 65 | def __init__(self, mem_cell_ct, x_dim): |