| 1409 | return self.A, self.C |
| 1410 | |
def extract_grads(self, X):
    """Run one forward/backward pass and collect LSTM parameters and gradients.

    The scalar loss is the sum of every hidden-state output that
    ``self.forward`` stores in ``self.A``. PyTorch's fused ``weight_ih`` /
    ``weight_hh`` / ``bias_ih`` tensors are split into four per-gate blocks
    (chunk order: input/update, forget, cell, output). The blocks are then
    repacked so that each ``W*`` matrix is the transposed hidden-to-hidden
    block stacked above the transposed input-to-hidden block. That gives a
    shape of ``(n_out + n_in, n_out)``.

    Gradients are reset before the backward pass. Repeated calls on the same
    module therefore report the gradient of a single pass rather than a
    running sum.

    Only ``bias_ih`` (and its gradient) is reported; ``bias_hh`` is not read.
    NOTE(review): this presumably relies on ``bias_hh`` being held at zero by
    the constructor, so that ``bias_ih`` alone equals the effective bias.
    Confirm against ``__init__``.

    Parameters
    ----------
    X : array-like or torch.Tensor
        Input handed unchanged to ``self.forward``. ``self.X`` must end up
        requiring grad, otherwise ``dLdX`` cannot be computed.

    Returns
    -------
    grads : dict[str, numpy.ndarray]
        Inputs (``X``), repacked parameters (``Wu``, ``Wf``, ``Wc``, ``Wo``,
        and the biases ``bu``..``bo`` reshaped to column vectors), forward
        outputs, and the gradients of the summed loss with respect to each
        of them. The forward outputs are ``C``, stacked along a leading time
        axis, and ``y``, which is the stacked ``A`` with axes permuted from
        (time, a, b) to (a, b, time).
    """
    # Clear accumulated parameter gradients so the values below come from
    # exactly one backward pass, even if this method is called repeatedly.
    self.zero_grad()

    self.forward(X)

    # A leaf input tensor reused across calls would otherwise accumulate its
    # gradient as well; drop any stale value before back-propagating.
    if self.X.is_leaf:
        self.X.grad = None

    self.loss = torch.stack(self.A).sum()
    self.loss.backward()

    cell = self.layer1

    def _repack(w_ih, w_hh):
        # Split the fused (4 * n_out, .) tensors into per-gate blocks and
        # stack the transposed hidden block above the transposed input block.
        return [
            torch.cat([torch.t(hh), torch.t(ih)], dim=0)
            for ih, hh in zip(w_ih.chunk(4, 0), w_hh.chunk(4, 0))
        ]

    Wu, Wf, Wc, Wo = _repack(cell.weight_ih, cell.weight_hh)
    dWu, dWf, dWc, dWo = _repack(cell.weight_ih.grad, cell.weight_hh.grad)
    bu, bf, bc, bo = cell.bias_ih.chunk(4, 0)
    dbu, dbf, dbc, dbo = cell.bias_ih.grad.chunk(4, 0)

    grads = {
        "X": self.X.detach().numpy(),
        "Wu": Wu.detach().numpy(),
        "Wf": Wf.detach().numpy(),
        "Wc": Wc.detach().numpy(),
        "Wo": Wo.detach().numpy(),
        "bu": bu.detach().numpy().reshape(-1, 1),
        "bf": bf.detach().numpy().reshape(-1, 1),
        "bc": bc.detach().numpy().reshape(-1, 1),
        "bo": bo.detach().numpy().reshape(-1, 1),
        "C": torch.stack(self.C).detach().numpy(),
        "y": np.swapaxes(
            np.swapaxes(torch.stack(self.A).detach().numpy(), 1, 0), 1, 2
        ),
        "dLdA": np.array([a.grad.numpy() for a in self.A]),
        "dLdWu": dWu.numpy(),
        "dLdWf": dWf.numpy(),
        "dLdWc": dWc.numpy(),
        "dLdWo": dWo.numpy(),
        "dLdBu": dbu.numpy().reshape(-1, 1),
        "dLdBf": dbf.numpy().reshape(-1, 1),
        "dLdBc": dbc.numpy().reshape(-1, 1),
        "dLdBo": dbo.numpy().reshape(-1, 1),
        "dLdX": self.X.grad.numpy(),
    }
    return grads
| 1460 | |
| 1461 | |
| 1462 | class TorchRNNCell(nn.Module): |