| 1571 | |
| 1572 | |
class TorchEmbeddingLayer(nn.Module):
    """PyTorch embedding layer that exposes forward values and gradients.

    Parameters
    ----------
    vocab_size : int
        Number of rows of the ``nn.Embedding`` table.
    n_out : int
        Embedding dimension (number of columns of the table).
    params : dict
        Must contain ``"W"``, the embedding table that replaces the default
        initialisation of ``nn.Embedding``. It is converted with
        ``torch.FloatTensor``. Presumably shaped ``(vocab_size, n_out)``;
        this is not checked here.
    **kwargs
        Accepted for signature compatibility; ignored.
    """

    def __init__(self, vocab_size, n_out, params, **kwargs):
        super().__init__()
        self.layer1 = nn.Embedding(vocab_size, n_out)

        # Explicitly set the embedding weights from the supplied table.
        self.layer1.weight = nn.Parameter(torch.FloatTensor(params["W"]))
        self.model = nn.Sequential(self.layer1)

    def forward(self, X):
        """Look up the embeddings for the indices in ``X``.

        ``X`` may be a ``torch.Tensor`` or a NumPy array. A NumPy array is
        wrapped with ``torch.from_numpy``. The input and output are cached on
        ``self.X`` and ``self.out1`` so that ``extract_grads`` can read them.

        Returns
        -------
        torch.Tensor
            The embedding output. Its gradient is retained after backward.
        """
        self.X = X if isinstance(X, torch.Tensor) else torch.from_numpy(X)
        self.out1 = self.layer1(self.X)
        # out1 is a non-leaf tensor; retain_grad() keeps its .grad populated
        # after backward so dL/dy can be reported.
        self.out1.retain_grad()
        return self.out1

    def extract_grads(self, X):
        """Run forward + backward with loss ``sum(y)`` and collect arrays.

        Returns
        -------
        dict
            NumPy arrays under the following keys:

            - ``"X"``: input indices.
            - ``"W"``: embedding table.
            - ``"y"``: layer output.
            - ``"dLdy"``: gradient of the loss w.r.t. the output.
            - ``"dLdW"``: gradient of the loss w.r.t. the table.
        """
        self.forward(X)
        # Clear any gradient left from a previous call. Without this,
        # backward() accumulates into weight.grad, so dLdW would grow on
        # every call, and arrays returned earlier (which share memory with
        # the grad tensor) would be mutated in place.
        self.layer1.weight.grad = None
        self.loss1 = self.out1.sum()
        self.loss1.backward()
        grads = {
            "X": self.X.detach().numpy(),
            "W": self.layer1.weight.detach().numpy(),
            "y": self.out1.detach().numpy(),
            "dLdy": self.out1.grad.numpy(),
            "dLdW": self.layer1.weight.grad.numpy(),
        }
        return grads
| 1602 | |
| 1603 | |
| 1604 | class TorchSDPAttentionLayer(nn.Module): |
no outgoing calls