(self, name='resnet50', num_classes=10)
class LinearClassifier(nn.Module):
    """Single fully connected layer on top of pre-extracted features.

    The input width of the layer is read from ``model_dict[name]``. That
    entry is unpacked as a 2-item pair; the first item is ignored here and
    the second is used as the feature dimension.
    NOTE(review): presumably the first item is the backbone constructor;
    confirm against ``model_dict``.
    """

    def __init__(self, name='resnet50', num_classes=10):
        super().__init__()
        # Only the feature width is needed; the other entry is discarded.
        _backbone, in_features = model_dict[name]
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, features):
        """Return ``self.fc`` applied to ``features`` (one score per class)."""
        return self.fc(features)