(self, name='resnet50', num_classes=10)
class SupCEResNet(nn.Module):
    """Encoder backbone followed by a single linear classification head.

    Args:
        name: Key into the module-level ``model_dict``. The entry must be a
            pair ``(constructor, feature_dim)``: the constructor is called
            with no arguments to build the encoder, and ``feature_dim`` is
            used as the input width of the linear head.
        num_classes: Output width of the linear head.
    """

    def __init__(self, name='resnet50', num_classes=10):
        super().__init__()
        encoder_factory, feat_dim = model_dict[name]
        # Build the encoder before the head so parameter initialization
        # consumes the RNG in the same order as before.
        self.encoder = encoder_factory()
        self.fc = nn.Linear(feat_dim, num_classes)

    def forward(self, x):
        """Encode ``x`` and return the linear head's output.

        NOTE(review): assumes the encoder's output has ``feat_dim`` as its
        last dimension (required by ``nn.Linear``) — confirm against the
        entries in ``model_dict``.
        """
        features = self.encoder(x)
        logits = self.fc(features)
        return logits