| 21 | return (self.x * self.W).sum() |
| 22 | |
class TinyNet:
    """Tiny test network: relu(x @ W), log-softmaxed along dim 1, then masked and summed.

    ``x`` and ``W`` are created through the module helper ``_param``; ``m`` is
    created directly with ``tensor`` and acts as a fixed multiplier/offset in
    ``forward``. All three start from copies of module-level init arrays.
    """

    def __init__(self, tensor):
        """Create the network tensors using the supplied ``tensor`` constructor.

        Args:
            tensor: constructor/factory used to build the tensors (passed to
                ``_param`` for ``x`` and ``W``, called directly for ``m``).
        """
        # .copy() presumably keeps the shared init arrays from being mutated
        # by training, so every net starts from the same values -- confirm.
        self.x = _param(tensor, x_init.copy())
        self.W = _param(tensor, W_init.copy())
        # NOTE(review): built without _param, presumably so it is not trained.
        self.m = tensor(m_init.copy())

    def forward(self):
        """Compute ``sum(log_softmax(relu(x @ W), dim=1) * m + m)`` and return it."""
        activations = self.x.matmul(self.W).relu()
        log_probs = activations.log_softmax(1)
        return log_probs.mul(self.m).add(self.m).sum()
| 35 | |
| 36 | def step(tensor, optim, steps=1, teeny=False, **kwargs): |
| 37 | net = TeenyNet(tensor) if teeny else TinyNet(tensor) |