| 1527 | |
| 1528 | |
class TorchFCLayer(nn.Module):
    """PyTorch fully-connected layer (linear map + activation) with fixed weights.

    The layer's weights and bias are set explicitly from `params` so that its
    outputs and gradients can be compared against another implementation that
    uses the same values.

    Parameters
    ----------
    n_in : int
        Number of input features, forwarded to ``nn.Linear``.
    n_hid : int
        Number of output features, forwarded to ``nn.Linear``.
    act_fn : callable
        Activation applied to the linear output. It is also placed inside an
        ``nn.Sequential``, so it needs to be an ``nn.Module``.
    params : dict
        Must contain ``"W"`` and ``"b"``. ``params["W"]`` is transposed before
        being handed to PyTorch, so it is expected to be laid out as
        ``(n_in, n_hid)`` (``nn.Linear`` stores its weight as
        ``(n_hid, n_in)``) -- TODO confirm against callers.
    **kwargs
        Accepted for interface compatibility and ignored.
    """

    def __init__(self, n_in, n_hid, act_fn, params, **kwargs):
        super().__init__()
        self.layer1 = nn.Linear(n_in, n_hid)

        # Explicitly set weights and bias.
        # NB: we pass the *transpose* of the weights to pytorch, meaning
        # we'll need to check against the *transpose* of our outputs for
        # any function of the weights (e.g. the "W" and "dLdW" entries
        # returned by `extract_grads`).
        self.layer1.weight = nn.Parameter(torch.FloatTensor(params["W"].T))
        self.layer1.bias = nn.Parameter(torch.FloatTensor(params["b"]))

        self.act_fn = act_fn
        # Convenience container; `forward` applies the two stages separately
        # so that the intermediate pre-activation can retain its gradient.
        self.model = nn.Sequential(self.layer1, self.act_fn)

    def forward(self, X):
        """Run the layer on `X`, caching intermediates on the instance.

        Parameters
        ----------
        X : torch.Tensor or array-like
            Layer input. Anything that is not already a ``torch.Tensor`` is
            converted with ``torchify``. A tensor input is used as-is, so it
            must require grad if ``dLdX`` is wanted afterwards.

        Returns
        -------
        torch.Tensor
            The activation output (also stored as ``self.out1``).
        """
        self.X = X if isinstance(X, torch.Tensor) else torchify(X)

        # Pre-activation; retain_grad keeps the gradient of this non-leaf
        # tensor available after backward().
        self.z1 = self.layer1(self.X)
        self.z1.retain_grad()

        self.out1 = self.act_fn(self.z1)
        self.out1.retain_grad()

        # Returning the output makes `model(X)` usable like any nn.Module;
        # previously the result was only reachable through `self.out1`.
        return self.out1

    def extract_grads(self, X):
        """Run a forward/backward pass on `X` and collect values and gradients.

        The loss is the sum of the layer output, so ``dLdy`` is all ones.

        Parameters
        ----------
        X : torch.Tensor or array-like
            Layer input, passed to `forward`.

        Returns
        -------
        dict of numpy.ndarray
            ``"X"``, ``"b"``, ``"W"``, ``"y"`` hold the input, bias, weight and
            output; ``"dLdy"``, ``"dLdZ"``, ``"dLdB"``, ``"dLdW"``, ``"dLdX"``
            hold the gradients of the loss w.r.t. the output, pre-activation,
            bias, weight and input. ``"W"`` and ``"dLdW"`` are in PyTorch's
            layout, i.e. transposed relative to ``params["W"]``.
        """
        # Autograd accumulates into `.grad`; drop any parameter gradients left
        # over from an earlier call so the reported values describe this pass
        # only. Rebinding to None (rather than zeroing in place) also leaves
        # arrays returned by previous calls untouched.
        for param in self.parameters():
            param.grad = None

        self.forward(X)
        self.loss1 = self.out1.sum()
        self.loss1.backward()

        grads = {
            "X": self.X.detach().numpy(),
            "b": self.layer1.bias.detach().numpy(),
            "W": self.layer1.weight.detach().numpy(),
            "y": self.out1.detach().numpy(),
            "dLdy": self.out1.grad.numpy(),
            "dLdZ": self.z1.grad.numpy(),
            "dLdB": self.layer1.bias.grad.numpy(),
            "dLdW": self.layer1.weight.grad.numpy(),
            "dLdX": self.X.grad.numpy(),
        }
        return grads
| 1571 | |
| 1572 | |
| 1573 | class TorchEmbeddingLayer(nn.Module): |
no outgoing calls