| 115 | |
| 116 | |
class LogisticRegression(nn.Module):
    """Multinomial logistic regression: one linear layer scored with cross-entropy.

    Args:
        num_dim: number of input features (``in_features`` of the linear layer).
        num_class: number of output classes (``out_features`` of the linear layer).
    """

    def __init__(self, num_dim, num_class):
        super().__init__()
        self.linear = nn.Linear(num_dim, num_class)
        # Initialise through the Parameters themselves rather than ``.data``:
        # the nn.init helpers already run under torch.no_grad(), and touching
        # ``.data`` bypasses autograd's version tracking.
        torch.nn.init.xavier_uniform_(self.linear.weight)
        torch.nn.init.zeros_(self.linear.bias)
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, x, y=None):
        """Compute class logits and, when labels are given, the mean cross-entropy loss.

        Args:
            x: input features whose last dimension is ``num_dim``.
            y: target class indices accepted by ``nn.CrossEntropyLoss``, or
                ``None`` to skip the loss (e.g. at inference time).

        Returns:
            A ``(logits, loss)`` tuple; ``loss`` is ``None`` when ``y`` is ``None``.
        """
        logits = self.linear(x)
        loss = None if y is None else self.cross_entropy(logits, y)
        return logits, loss
no outgoing calls