MCPcopy
hub / github.com/ddbourgin/numpy-ml / extract_grads

Method extract_grads

numpy_ml/tests/nn_torch_models.py:1411–1459  ·  view source on GitHub ↗
(self, X)

Source from the content-addressed store, hash-verified

1409 return self.A, self.C
1410
def extract_grads(self, X):
    """
    Run one forward/backward pass on `X` and export the cell's parameters,
    recorded states, and gradients as NumPy arrays.

    The scalar objective is the sum of every entry of the stacked
    ``self.A``. After ``backward()``, the rows of ``self.layer1.weight_ih``,
    ``weight_hh`` and ``bias_ih`` (and of their ``.grad``) are split into four
    equal chunks, labelled here ``u``, ``f``, ``c``, ``o`` in that order. For
    each gate the hidden-to-hidden and input-to-hidden slices are transposed
    and concatenated along dim 0, hidden rows first.

    Parameters
    ----------
    X : object
        Passed unchanged to ``self.forward``. This method afterwards reads
        ``self.X``, ``self.A`` and ``self.C``, which are presumably populated
        by ``forward`` (not visible from here -- confirm).

    Returns
    -------
    grads : dict
        Keys, in insertion order: ``X``; ``Wu``, ``Wf``, ``Wc``, ``Wo``;
        ``bu``, ``bf``, ``bc``, ``bo`` (reshaped to a single column); ``C``;
        ``y`` (stacked ``self.A`` with axes reordered from ``(0, 1, 2)`` to
        ``(1, 2, 0)``); ``dLdA``; ``dLdWu`` .. ``dLdWo``; ``dLdBu`` ..
        ``dLdBo`` (single column); ``dLdX``.

    Notes
    -----
    * Gradients are not zeroed in this method, so anything already stored in
      ``.grad`` before the call is included in the exported values.
    * Only ``bias_ih`` is exported; ``bias_hh`` is never read here.
    * Side effect: ``self.loss`` is overwritten.
    """
    self.forward(X)

    # Scalar objective: the sum over every entry of the stacked self.A.
    self.loss = torch.stack(self.A).sum()
    self.loss.backward()

    def to_np(tensor):
        # Cut the tensor loose from the autograd graph before conversion.
        return tensor.detach().numpy()

    def as_column(array):
        return array.reshape(-1, 1)

    def fuse_gates(hidden_mat, input_mat):
        # One fused matrix per gate: transposed hidden slice stacked on top
        # of the transposed input slice.
        return [
            torch.cat([torch.t(hid), torch.t(inp)], dim=0)
            for hid, inp in zip(hidden_mat.chunk(4, 0), input_mat.chunk(4, 0))
        ]

    cell = self.layer1

    # Parameters, split per gate.
    Wu, Wf, Wc, Wo = fuse_gates(cell.weight_hh, cell.weight_ih)
    bu, bf, bc, bo = cell.bias_ih.chunk(4, 0)

    # Matching gradients, laid out the same way as the parameters above.
    dWu, dWf, dWc, dWo = fuse_gates(cell.weight_hh.grad, cell.weight_ih.grad)
    dbu, dbf, dbc, dbo = cell.bias_ih.grad.chunk(4, 0)

    stacked_A = to_np(torch.stack(self.A))

    grads = {
        "X": to_np(self.X),
        "Wu": to_np(Wu),
        "Wf": to_np(Wf),
        "Wc": to_np(Wc),
        "Wo": to_np(Wo),
        "bu": as_column(to_np(bu)),
        "bf": as_column(to_np(bf)),
        "bc": as_column(to_np(bc)),
        "bo": as_column(to_np(bo)),
        "C": to_np(torch.stack(self.C)),
        # Two successive axis swaps: (0, 1, 2) -> (1, 0, 2) -> (1, 2, 0).
        "y": stacked_A.swapaxes(1, 0).swapaxes(1, 2),
        "dLdA": np.array([state.grad.numpy() for state in self.A]),
        "dLdWu": dWu.numpy(),
        "dLdWf": dWf.numpy(),
        "dLdWc": dWc.numpy(),
        "dLdWo": dWo.numpy(),
        "dLdBu": as_column(dbu.numpy()),
        "dLdBf": as_column(dbf.numpy()),
        "dLdBc": as_column(dbc.numpy()),
        "dLdBo": as_column(dbo.numpy()),
        "dLdX": self.X.grad.numpy(),
    }
    return grads
1460
1461
1462class TorchRNNCell(nn.Module):

Callers 1

test_LSTMCellFunction · 0.95

Calls 2

forwardMethod · 0.95
backwardMethod · 0.45

Tested by 1

test_LSTMCellFunction · 0.76