```python
import torch.nn as nn


class TPS_SpatialTransformerNetwork(nn.Module):
    """ Rectification Network of RARE, namely TPS based STN """

    def __init__(self, F, I_size, I_r_size, I_channel_num=1):
        """ Based on RARE TPS
        input:
            batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width]
            F : the number of fiducial points
            I_size : (height, width) of the input image I
            I_r_size : (height, width) of the rectified image I_r
            I_channel_num : the number of channels of the input image I
        output:
            batch_I_r: rectified image [batch_size x I_channel_num x I_r_height x I_r_width]
        """
        super(TPS_SpatialTransformerNetwork, self).__init__()
        self.F = F
        self.I_size = I_size
        self.I_r_size = I_r_size  # = (I_r_height, I_r_width)
        self.I_channel_num = I_channel_num
        # LocalizationNetwork and GridGenerator are defined elsewhere (not shown in this snippet)
        self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num)
        self.GridGenerator = GridGenerator(self.F, self.I_r_size)

    def forward(self, batch_I):
        batch_C_prime = self.LocalizationNetwork(batch_I)  # batch_size x F x 2
        # The original snippet is cut off here. The remaining steps are a sketch that
        # assumes GridGenerator.build_P_prime() returns the TPS sampling grid as
        # batch_size x (I_r_height * I_r_width) x 2 normalized (x, y) points.
        batch_P_prime = self.GridGenerator.build_P_prime(batch_C_prime)
        batch_P_prime = batch_P_prime.reshape(
            [batch_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2])
        batch_I_r = nn.functional.grid_sample(
            batch_I, batch_P_prime, padding_mode='border', align_corners=True)
        return batch_I_r
```
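The localization step predicts batch_C_prime, the F target fiducial points. In RARE, a thin-plate spline (TPS) fitted between a fixed set of base fiducial points C and the predicted C' is what maps each pixel of the I_r_size output grid to a location in the input image. The NumPy sketch below shows that mapping for a single image. It is an illustration under stated assumptions, not the repository's GridGenerator: F is assumed even, the base points C are assumed to sit evenly along the top and bottom edges of the normalized [-1, 1] square, and the helper names (build_fiducial_points, tps_kernel, build_grid, tps_sampling_grid) are hypothetical.

```python
import numpy as np


def build_fiducial_points(F):
    """Hypothetical layout: F // 2 points evenly spaced along the top edge (y = -1)
    and F // 2 along the bottom edge (y = +1) in [-1, 1] coordinates; F must be even."""
    xs = np.linspace(-1.0, 1.0, F // 2)
    top = np.stack([xs, -np.ones(F // 2)], axis=1)
    bottom = np.stack([xs, np.ones(F // 2)], axis=1)
    return np.concatenate([top, bottom], axis=0)  # F x 2


def tps_kernel(r):
    """Thin-plate radial basis U(r) = r^2 * log(r), with U(0) = 0."""
    return r ** 2 * np.log(np.where(r > 0, r, 1.0))


def build_grid(h, w):
    """Pixel-center coordinates of an h x w grid in [-1, 1], one (x, y) row per pixel."""
    xs = (np.arange(-w, w, 2) + 1.0) / w
    ys = (np.arange(-h, h, 2) + 1.0) / h
    gx, gy = np.meshgrid(xs, ys)  # each h x w
    return np.stack([gx, gy], axis=-1).reshape(-1, 2)  # (h * w) x 2


def tps_sampling_grid(C, C_prime, h, w):
    """Send every point of the h x w rectified grid through the TPS that maps
    C onto C_prime; returns h x w x 2 normalized (x, y) sampling coordinates."""
    F = C.shape[0]
    # (F + 3) x (F + 3) system: F interpolation rows plus 3 side conditions
    delta_C = np.zeros((F + 3, F + 3))
    delta_C[:F, 0] = 1.0
    delta_C[:F, 1:3] = C
    delta_C[:F, 3:] = tps_kernel(np.linalg.norm(C[:, None] - C[None, :], axis=-1))
    delta_C[F:F + 2, 3:] = C.T  # side condition: sum_i w_i * C_i = 0
    delta_C[F + 2, 3:] = 1.0    # side condition: sum_i w_i = 0
    rhs = np.concatenate([C_prime, np.zeros((3, 2))], axis=0)  # (F + 3) x 2
    T = np.linalg.solve(delta_C, rhs)  # 3 affine rows followed by F kernel weights

    P = build_grid(h, w)
    r = np.linalg.norm(P[:, None] - C[None, :], axis=-1)  # n x F distances
    P_hat = np.concatenate([np.ones((len(P), 1)), P, tps_kernel(r)], axis=1)
    return (P_hat @ T).reshape(h, w, 2)
```

Shifting every fiducial point by the same offset, e.g. `tps_sampling_grid(C, C + np.array([0.1, 0.0]), 32, 100)`, should yield the base grid translated by 0.1 in x, because a pure translation is captured entirely by the affine part of the spline and the kernel weights come out as zero.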
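Per the docstring, the output batch_I_r has the rectified size I_r_size regardless of the input size I_size: each output pixel reads the input image at a predicted location. A minimal NumPy sketch of that bilinear resampling for one image follows. It assumes the grid holds normalized (x, y) coordinates in [-1, 1] with align_corners=True semantics (-1 and +1 hit the centers of the corner pixels) and clamps out-of-range coordinates to the border; this is intended to mirror PyTorch's `torch.nn.functional.grid_sample` in bilinear mode with `padding_mode='border'`. `bilinear_sample` is a hypothetical helper, not part of the module.

```python
import numpy as np


def bilinear_sample(img, grid):
    """Sample img (C x H x W) at grid (h x w x 2, normalized (x, y) in [-1, 1]).
    Returns a C x h x w array."""
    C, H, W = img.shape
    # Normalized -> pixel coordinates (align_corners=True), clamped to the border
    x = np.clip((grid[..., 0] + 1.0) * (W - 1) / 2.0, 0, W - 1)
    y = np.clip((grid[..., 1] + 1.0) * (H - 1) / 2.0, 0, H - 1)
    x0 = np.floor(x).astype(int)
    y0 = np.floor(y).astype(int)
    x1 = np.minimum(x0 + 1, W - 1)
    y1 = np.minimum(y0 + 1, H - 1)
    wx = x - x0  # horizontal interpolation weight, h x w
    wy = y - y0  # vertical interpolation weight, h x w
    # Interpolate along x on the two neighbouring rows, then along y
    top = img[:, y0, x0] * (1 - wx) + img[:, y0, x1] * wx
    bottom = img[:, y1, x0] * (1 - wx) + img[:, y1, x1] * wx
    return top * (1 - wy) + bottom * wy
```

Because the output shape is set by the grid rather than by the input, the same call turns an input of I_size into an output of I_r_size, which is how batch_I and batch_I_r can differ in height and width.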