| 8 | |
| 9 | |
class RemapLayer(nn.Module):
    """Select/reorder the class columns of a 2-D score tensor via an index file.

    The index file contains one integer per line; output column ``j`` is
    taken from input column ``mapping[j]``. Before indexing, one extra
    all-zero column is appended to the input, so an index equal to the
    input's class count (or ``-1``, which addresses the last column) produces
    a column of zeros in the output.
    """

    def __init__(self, fname):
        """Load the column mapping.

        Args:
            fname: path to a text file with one integer column index per line.
                Blank lines (e.g. a trailing newline) are ignored.
        """
        super().__init__()
        with open(fname) as fin:
            indices = [int(line) for line in fin if line.strip()]
        # Registered as a buffer so ``module.to(device)`` / ``.cuda()`` moves it
        # together with the module (a plain tensor attribute would stay on CPU
        # and break indexing of GPU inputs). ``persistent=False`` keeps it out
        # of the state_dict, so checkpoints saved before this change still load
        # with strict=True.
        self.register_buffer(
            'mapping', torch.tensor(indices, dtype=torch.long), persistent=False)

    def forward(self, x):
        '''
        x: [batch_size, class]
        returns: [batch_size, len(mapping)]
        '''
        B = x.shape[0]
        # Zero column appended at index ``class``; mapping entries pointing
        # there (or at -1) yield zeros.
        dummy_cls = x.new_zeros((B, 1))
        expand_x = torch.cat([x, dummy_cls], dim=1)
        return expand_x[:, self.mapping]
no outgoing calls
no test coverage detected