Args: - convouts (list): A list of convouts for the corresponding layers in in_channels. Returns: - A list of FPN convouts in the same order as convouts, with extra downsampled layers appended if requested.
(self, convouts:List[torch.Tensor])
| 309 | |
@script_method_wrapper
def forward(self, convouts:List[torch.Tensor]):
    """
    Run the FPN top-down pathway over the backbone feature maps.

    Args:
        - convouts (list): A list of convouts for the corresponding layers in in_channels.
          Every map except the last has its size unpacked as (N, C, H, W), so these are
          expected to be 4-D tensors.
    Returns:
        - A list of FPN convouts in the same order as convouts, followed by any extra
          downsampled levels (one per entry of self.downsample_layers when
          self.use_conv_downsample is set, otherwise self.num_downsample of them).
    """
    # Pre-fill the output with placeholders so levels can be written by index while
    # walking the pyramid from the last input to the first.
    out = []
    x = torch.zeros(1, device=convouts[0].device)
    for _ in range(len(convouts)):
        out.append(x)

    # For backward compatibility, the conv layers are stored in reverse but the input and output are
    # given in the correct order. Thus, walk j downward from the last input while iterating the layers.
    j = len(convouts)
    for lat_layer in self.lat_layers:
        j -= 1

        # On the first step x is the zero placeholder and simply broadcasts in the add below.
        # On later steps the running top-down map is resized to this level's spatial size.
        # NOTE(review): F.interpolate rejects align_corners for 'nearest'/'area' modes, so this
        # assumes interpolation_mode is an interpolating mode (e.g. 'bilinear') -- confirm.
        if j < len(convouts) - 1:
            _, _, h, w = convouts[j].size()
            x = F.interpolate(x, size=(h, w), mode=self.interpolation_mode, align_corners=False)

        x = x + lat_layer(convouts[j])
        out[j] = x

    # This janky second loop is here because TorchScript.
    j = len(convouts)
    for pred_layer in self.pred_layers:
        j -= 1
        out[j] = pred_layer(out[j])

        if self.relu_pred_layers:
            # In-place ReLU; the assignment just makes the update explicit.
            out[j] = F.relu(out[j], inplace=True)

    # Index of the first extra (downsampled) level appended below.
    cur_idx = len(out)

    # In the original paper, this takes care of P6
    if self.use_conv_downsample:
        for downsample_layer in self.downsample_layers:
            out.append(downsample_layer(out[-1]))
    else:
        for _ in range(self.num_downsample):
            # max_pool2d with kernel 1 / stride 2 keeps every other row and column, i.e. the
            # same result as out[-1][:, :, ::2, ::2], written this way for TorchScript.
            out.append(nn.functional.max_pool2d(out[-1], 1, stride=2))

    # Apply the ReLU only to the extra downsampled levels (cur_idx onward), writing each result
    # back to its own slot so the FPN levels before cur_idx are left untouched.
    if self.relu_downsample_layers:
        for idx in range(cur_idx, len(out)):
            out[idx] = F.relu(out[idx], inplace=False)

    return out
| 362 | |
| 363 | class FastMaskIoUNet(ScriptModuleWrapper): |
| 364 |