| 1016 | return self.A |
| 1017 | |
| 1018 | def extract_grads(self, X): |
| 1019 | self.forward(X) |
| 1020 | self.loss = self.A.sum() |
| 1021 | self.loss.backward() |
| 1022 | |
| 1023 | # forward |
| 1024 | w_ii, w_if, w_ic, w_io = self.layer1.weight_ih_l0.chunk(4, 0) |
| 1025 | w_hi, w_hf, w_hc, w_ho = self.layer1.weight_hh_l0.chunk(4, 0) |
| 1026 | bu_f, bf_f, bc_f, bo_f = self.layer1.bias_ih_l0.chunk(4, 0) |
| 1027 | |
| 1028 | Wu_f = torch.cat([torch.t(w_hi), torch.t(w_ii)], dim=0) |
| 1029 | Wf_f = torch.cat([torch.t(w_hf), torch.t(w_if)], dim=0) |
| 1030 | Wc_f = torch.cat([torch.t(w_hc), torch.t(w_ic)], dim=0) |
| 1031 | Wo_f = torch.cat([torch.t(w_ho), torch.t(w_io)], dim=0) |
| 1032 | |
| 1033 | dw_ii, dw_if, dw_ic, dw_io = self.layer1.weight_ih_l0.grad.chunk(4, 0) |
| 1034 | dw_hi, dw_hf, dw_hc, dw_ho = self.layer1.weight_hh_l0.grad.chunk(4, 0) |
| 1035 | dbu_f, dbf_f, dbc_f, dbo_f = self.layer1.bias_ih_l0.grad.chunk(4, 0) |
| 1036 | |
| 1037 | dWu_f = torch.cat([torch.t(dw_hi), torch.t(dw_ii)], dim=0) |
| 1038 | dWf_f = torch.cat([torch.t(dw_hf), torch.t(dw_if)], dim=0) |
| 1039 | dWc_f = torch.cat([torch.t(dw_hc), torch.t(dw_ic)], dim=0) |
| 1040 | dWo_f = torch.cat([torch.t(dw_ho), torch.t(dw_io)], dim=0) |
| 1041 | |
| 1042 | # backward |
| 1043 | w_ii, w_if, w_ic, w_io = self.layer1.weight_ih_l0_reverse.chunk(4, 0) |
| 1044 | w_hi, w_hf, w_hc, w_ho = self.layer1.weight_hh_l0_reverse.chunk(4, 0) |
| 1045 | bu_b, bf_b, bc_b, bo_b = self.layer1.bias_ih_l0_reverse.chunk(4, 0) |
| 1046 | |
| 1047 | Wu_b = torch.cat([torch.t(w_hi), torch.t(w_ii)], dim=0) |
| 1048 | Wf_b = torch.cat([torch.t(w_hf), torch.t(w_if)], dim=0) |
| 1049 | Wc_b = torch.cat([torch.t(w_hc), torch.t(w_ic)], dim=0) |
| 1050 | Wo_b = torch.cat([torch.t(w_ho), torch.t(w_io)], dim=0) |
| 1051 | |
| 1052 | dw_ii, dw_if, dw_ic, dw_io = self.layer1.weight_ih_l0_reverse.grad.chunk(4, 0) |
| 1053 | dw_hi, dw_hf, dw_hc, dw_ho = self.layer1.weight_hh_l0_reverse.grad.chunk(4, 0) |
| 1054 | dbu_b, dbf_b, dbc_b, dbo_b = self.layer1.bias_ih_l0_reverse.grad.chunk(4, 0) |
| 1055 | |
| 1056 | dWu_b = torch.cat([torch.t(dw_hi), torch.t(dw_ii)], dim=0) |
| 1057 | dWf_b = torch.cat([torch.t(dw_hf), torch.t(dw_if)], dim=0) |
| 1058 | dWc_b = torch.cat([torch.t(dw_hc), torch.t(dw_ic)], dim=0) |
| 1059 | dWo_b = torch.cat([torch.t(dw_ho), torch.t(dw_io)], dim=0) |
| 1060 | |
| 1061 | orig, X_swap = [0, 1, 2], [-1, -3, -2] |
| 1062 | grads = { |
| 1063 | "X": np.moveaxis(self.X.detach().numpy(), orig, X_swap), |
| 1064 | "Wu_f": Wu_f.detach().numpy(), |
| 1065 | "Wf_f": Wf_f.detach().numpy(), |
| 1066 | "Wc_f": Wc_f.detach().numpy(), |
| 1067 | "Wo_f": Wo_f.detach().numpy(), |
| 1068 | "bu_f": bu_f.detach().numpy().reshape(-1, 1), |
| 1069 | "bf_f": bf_f.detach().numpy().reshape(-1, 1), |
| 1070 | "bc_f": bc_f.detach().numpy().reshape(-1, 1), |
| 1071 | "bo_f": bo_f.detach().numpy().reshape(-1, 1), |
| 1072 | "Wu_b": Wu_b.detach().numpy(), |
| 1073 | "Wf_b": Wf_b.detach().numpy(), |
| 1074 | "Wc_b": Wc_b.detach().numpy(), |
| 1075 | "Wo_b": Wo_b.detach().numpy(), |