MCPcopy
hub / github.com/ddbourgin/numpy-ml / TorchRNNCell

Class TorchRNNCell

numpy_ml/tests/nn_torch_models.py:1462–1526  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

1460
1461
1462class TorchRNNCell(nn.Module):
1463 def __init__(self, n_in, n_hid, params, **kwargs):
1464 super(TorchRNNCell, self).__init__()
1465
1466 self.layer1 = nn.RNNCell(n_in, n_hid, bias=True, nonlinearity="tanh")
1467
1468 # set weights and bias to match those of RNNCell
1469 # NB: we pass the *transpose* of the RNNCell weights and biases to
1470 # pytorch, meaning we need to check against the *transpose* of our
1471 # outputs for any function of the weights
1472 self.layer1.weight_ih = nn.Parameter(torch.FloatTensor(params["Wax"].T))
1473 self.layer1.weight_hh = nn.Parameter(torch.FloatTensor(params["Waa"].T))
1474 self.layer1.bias_ih = nn.Parameter(torch.FloatTensor(params["bx"].T))
1475 self.layer1.bias_hh = nn.Parameter(torch.FloatTensor(params["ba"].T))
1476
1477 def forward(self, X):
1478 self.X = X
1479 if not isinstance(self.X, torch.Tensor):
1480 self.X = torchify(self.X)
1481
1482 self.X.retain_grad()
1483
1484 # initial hidden state is 0
1485 n_ex, n_in, n_timesteps = self.X.shape
1486 n_out, n_out = self.layer1.weight_hh.shape
1487
1488 # initialize hidden states
1489 a0 = torchify(np.zeros((n_ex, n_out)))
1490 a0.retain_grad()
1491
1492 # forward pass
1493 A = []
1494 at = a0
1495 for t in range(n_timesteps):
1496 A += [at]
1497 at1 = self.layer1(self.X[:, :, t], at)
1498 at.retain_grad()
1499 at = at1
1500
1501 at.retain_grad()
1502 A += [at]
1503
1504 # don't inclue a0 in our outputs
1505 self.A = A[1:]
1506 return self.A
1507
1508 def extract_grads(self, X):
1509 self.forward(X)
1510 self.loss = torch.stack(self.A).sum()
1511 self.loss.backward()
1512 grads = {
1513 "X": self.X.detach().numpy(),
1514 "ba": self.layer1.bias_hh.detach().numpy(),
1515 "bx": self.layer1.bias_ih.detach().numpy(),
1516 "Wax": self.layer1.weight_ih.detach().numpy(),
1517 "Waa": self.layer1.weight_hh.detach().numpy(),
1518 "y": torch.stack(self.A).detach().numpy(),
1519 "dLdA": np.array([a.grad.numpy() for a in self.A]),

Callers 1

test_RNNCellFunction · 0.85

Calls

no outgoing calls

Tested by 1

test_RNNCellFunction · 0.68