| 89 | |
| 90 | class MLPHead(nn.Module): |
def __init__(self, hidden_size, out_dim, num_layers=2, bottleneck_dim=256):
    """Build the MLP stack and the weight-normalised output projection.

    Args:
        hidden_size: Input width of every ``nn.Linear`` in ``self.mlp`` and
            output width of all but the final one.
        out_dim: Output width of ``self.last_layer``.
        num_layers: Number of ``nn.Linear`` layers placed in ``self.mlp``.
            Every layer except the final one is followed by ``nn.PReLU``.
        bottleneck_dim: Output width of the final ``self.mlp`` layer and
            input width of ``self.last_layer``.
    """
    super().__init__()
    self._num_layers = num_layers

    # Hidden stages: Linear -> PReLU, width kept at hidden_size.
    self.mlp = nn.ModuleList()
    for _ in range(num_layers - 1):
        self.mlp.extend([nn.Linear(hidden_size, hidden_size), nn.PReLU()])

    # Final stage projects down to the bottleneck and has no activation.
    if num_layers > 0:
        self.mlp.append(nn.Linear(hidden_size, bottleneck_dim))

    # Runs before last_layer is created, so the custom initialiser only
    # visits the modules registered so far (the MLP stack).
    self.apply(self._init_weights)

    # Bias-free output projection wrapped in weight normalisation; the
    # magnitude term is set to 1 and is not frozen here.
    projection = nn.Linear(bottleneck_dim, out_dim, bias=False)
    self.last_layer = nn.utils.weight_norm(projection)
    self.last_layer.weight_g.data.fill_(1)
| 111 | |
| 112 | def _init_weights(self, m): |
| 113 | if isinstance(m, nn.Linear): |