ResNets without fully connected layer
| 16 | |
| 17 | |
class ResNet(models.ResNet):
    """ResNet backbone that stops before the classification head.

    The inherited ``fc`` layer is kept on the module so it can be cloned
    with :meth:`copy_head`, but :meth:`forward` never calls it. Global
    average pooling and flattening are skipped as well, so :meth:`forward`
    returns the raw output of ``layer4``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Record the input width of the original head. For the stock
        # torchvision ResNet this should equal the channel count of the
        # ``layer4`` output (assumption; confirm for custom block types).
        self._out_features = self.fc.in_features

    def forward(self, x):
        """Run the convolutional trunk and return the ``layer4`` feature map.

        Args:
            x: Input batch, presumably an image tensor of shape
                (N, 3, H, W). TODO: confirm against callers.

        Returns:
            The output of ``layer4``. Unlike the parent class, no average
            pooling, flattening, or ``fc`` projection is applied, so the
            result keeps its spatial dimensions (expected (N, C, H', W')
            for the standard torchvision backbone).
        """
        # Stem: conv -> batch norm -> ReLU -> max pool.
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # The four residual stages.
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Pooling/flattening is intentionally left to the caller, which
        # receives the unpooled feature map.
        return x

    @property
    def out_features(self) -> int:
        """Input width of the original ``fc`` head.

        For the standard torchvision ResNet this is presumably the channel
        dimension of the tensor returned by :meth:`forward`.
        """
        return self._out_features

    def copy_head(self) -> nn.Module:
        """Return an independent deep copy of the original ``fc`` layer."""
        return copy.deepcopy(self.fc)
| 50 | |
| 51 | |
| 52 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): |