| 165 | class SupConResNet(nn.Module): |
| 166 | """backbone + projection head""" |
def __init__(self, name='resnet50', head='mlp', feat_dim=128):
    """Build the encoder backbone and the projection head.

    Args:
        name: key into the module-level ``model_dict``. The entry is
            unpacked as ``(model_fun, dim_in)``: ``model_fun()`` builds
            ``self.encoder`` and ``dim_in`` is the input width of the head
            (presumably the encoder's output feature size -- confirm
            against ``model_dict``).
        head: ``'linear'`` for a single ``nn.Linear(dim_in, feat_dim)``, or
            ``'mlp'`` for ``Linear(dim_in, dim_in) -> ReLU -> Linear(dim_in,
            feat_dim)``.
        feat_dim: output width of the projection head.

    Raises:
        NotImplementedError: if ``head`` is not ``'linear'`` or ``'mlp'``.
            This is checked before the encoder is built.
        KeyError: if ``name`` is not a key of ``model_dict``.
    """
    super().__init__()

    # Validate the cheap string argument up front so a typo in `head` fails
    # before the (possibly large) encoder is instantiated.
    if head not in ('linear', 'mlp'):
        raise NotImplementedError(
            'head not supported: {}'.format(head))

    try:
        model_fun, dim_in = model_dict[name]
    except KeyError:
        # Same exception type as before, but with a readable message that
        # mirrors the head error above.
        raise KeyError('model not supported: {}'.format(name)) from None

    # Keep encoder-then-head construction order so parameter initialisation
    # consumes the RNG in the same sequence as before.
    self.encoder = model_fun()
    if head == 'linear':
        self.head = nn.Linear(dim_in, feat_dim)
    else:  # head == 'mlp' (validated above)
        self.head = nn.Sequential(
            nn.Linear(dim_in, dim_in),
            nn.ReLU(inplace=True),
            nn.Linear(dim_in, feat_dim),
        )
| 182 | |
| 183 | def forward(self, x): |
| 184 | feat = self.encoder(x) |