Localization Network of RARE, which predicts C' (K x 2) from I (I_width x I_height)
| 36 | |
| 37 | |
class LocalizationNetwork(nn.Module):
    """Localization Network of RARE: predicts C' (F x 2) from an input image I.

    A small CNN reduces the image to a 512-d feature vector, and two fully
    connected layers regress the 2-D coordinates of F fiducial points.

    Args:
        F: number of fiducial points. Must be a positive even integer, since
            the initial bias places F / 2 points on each of two rows.
        I_channel_num: number of channels of the input image.

    Raises:
        ValueError: if ``F`` is not a positive even number.
    """

    def __init__(self, F, I_channel_num):
        super().__init__()
        # The bias initialisation below creates 2 * (F // 2) control points.
        # An odd F would produce a bias shorter than fc2's F * 2 outputs, so
        # reject it up front instead of building an inconsistent layer.
        if F <= 0 or F % 2 != 0:
            raise ValueError(
                "F must be a positive even number of fiducial points, got {!r}".format(F)
            )
        self.F = F
        self.I_channel_num = I_channel_num
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=self.I_channel_num, out_channels=64, kernel_size=3, stride=1, padding=1,
                      bias=False), nn.BatchNorm2d(64), nn.ReLU(True),
            nn.MaxPool2d(2, 2),  # batch_size x 64 x I_height/2 x I_width/2
            nn.Conv2d(64, 128, 3, 1, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(True),
            nn.MaxPool2d(2, 2),  # batch_size x 128 x I_height/4 x I_width/4
            nn.Conv2d(128, 256, 3, 1, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True),
            nn.MaxPool2d(2, 2),  # batch_size x 256 x I_height/8 x I_width/8
            nn.Conv2d(256, 512, 3, 1, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(True),
            nn.AdaptiveAvgPool2d(1)  # batch_size x 512 x 1 x 1 (flattened in forward)
        )

        self.localization_fc1 = nn.Sequential(nn.Linear(512, 256), nn.ReLU(True))
        self.localization_fc2 = nn.Linear(256, self.F * 2)

        # Init fc2 in LocalizationNetwork: zero weights plus a fixed bias, so
        # the freshly constructed network outputs the same control points for
        # every input. See RARE paper Fig. 6 (a).
        half = F // 2
        ctrl_pts_x = np.linspace(-1.0, 1.0, half)
        ctrl_pts_y_top = np.linspace(0.0, -1.0, num=half)
        ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=half)
        ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
        ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
        initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)  # F x 2

        # In-place initialisation under no_grad keeps the registered
        # parameters intact; copy_ casts float64 -> parameter dtype and checks
        # that the shapes agree.
        with torch.no_grad():
            self.localization_fc2.weight.zero_()
            self.localization_fc2.bias.copy_(torch.from_numpy(initial_bias).view(-1))

    def forward(self, batch_I):
        """
        input:  batch_I : Batch Input Image [batch_size x I_channel_num x I_height x I_width]
        output: batch_C_prime : Predicted coordinates of fiducial points for input batch [batch_size x F x 2]
        """
        batch_size = batch_I.size(0)
        # Collapse the pooled 512 x 1 x 1 map into one feature vector per image.
        features = self.conv(batch_I).view(batch_size, -1)
        hidden = self.localization_fc1(features)
        batch_C_prime = self.localization_fc2(hidden).view(batch_size, self.F, 2)
        return batch_C_prime
| 80 | |
| 81 | |
| 82 | class GridGenerator(nn.Module): |
no outgoing calls
no test coverage detected
searching dependent graphs…