| 1553 | self.out1.retain_grad() |
| 1554 | |
def extract_grads(self, X):
    """Run a forward/backward pass on ``X`` and collect the results.

    Calls ``self.forward(X)``, reduces ``self.out1`` to a scalar with
    ``sum()``, stores that scalar on ``self.loss1`` and back-propagates
    from it. The values and gradients recorded on the instance are then
    converted with ``.numpy()`` and returned.

    Returns a dict with the keys:
        "X", "b", "W", "y"  -- detached copies of the input, the bias and
                               weight of ``self.layer1``, and ``self.out1``.
        "dLdy", "dLdZ", "dLdB", "dLdW", "dLdX" -- the ``.grad`` fields of
                               ``self.out1``, ``self.z1``, the bias, the
                               weight and ``self.X`` respectively.
    """
    # The forward pass is expected to populate self.X, self.z1 and
    # self.out1 (presumably with retain_grad() enabled on the non-leaf
    # tensors -- confirm against forward()).
    self.forward(X)

    # A plain sum gives a scalar loss to differentiate.
    self.loss1 = self.out1.sum()
    self.loss1.backward()

    # NOTE(review): parameter gradients are not reset here; if this is
    # called repeatedly on one instance they may accumulate -- confirm.
    inp = self.X
    pre_act = self.z1
    output = self.out1
    bias = self.layer1.bias
    weight = self.layer1.weight

    return {
        "X": inp.detach().numpy(),
        "b": bias.detach().numpy(),
        "W": weight.detach().numpy(),
        "y": output.detach().numpy(),
        "dLdy": output.grad.numpy(),
        "dLdZ": pre_act.grad.numpy(),
        "dLdB": bias.grad.numpy(),
        "dLdW": weight.grad.numpy(),
        "dLdX": inp.grad.numpy(),
    }
| 1571 | |
| 1572 | |
| 1573 | class TorchEmbeddingLayer(nn.Module): |