(self, features)
| 210 | return torch.cat(patches_batch, dim=0) |
| 211 | |
| 212 | def forward(self, features): |
| 213 | if self.training and self.config.out_ref: |
| 214 | outs_gdt_pred = [] |
| 215 | outs_gdt_label = [] |
| 216 | x, x1, x2, x3, x4, gdt_gt = features |
| 217 | else: |
| 218 | x, x1, x2, x3, x4 = features |
| 219 | outs = [] |
| 220 | |
| 221 | if self.config.dec_ipt: |
| 222 | patches_batch = self.get_patches_batch(x, x4) if self.split else x |
| 223 | x4 = torch.cat((x4, self.ipt_blk5(F.interpolate(patches_batch, size=x4.shape[2:], mode='bilinear', align_corners=True))), 1) |
| 224 | p4 = self.decoder_block4(x4) |
| 225 | m4 = self.conv_ms_spvn_4(p4) if self.config.ms_supervision else None |
| 226 | if self.config.out_ref: |
| 227 | p4_gdt = self.gdt_convs_4(p4) |
| 228 | if self.training: |
| 229 | # >> GT: |
| 230 | m4_dia = m4 |
| 231 | gdt_label_main_4 = gdt_gt * F.interpolate(m4_dia, size=gdt_gt.shape[2:], mode='bilinear', align_corners=True) |
| 232 | outs_gdt_label.append(gdt_label_main_4) |
| 233 | # >> Pred: |
| 234 | gdt_pred_4 = self.gdt_convs_pred_4(p4_gdt) |
| 235 | outs_gdt_pred.append(gdt_pred_4) |
| 236 | gdt_attn_4 = self.gdt_convs_attn_4(p4_gdt).sigmoid() |
| 237 | # >> Finally: |
| 238 | p4 = p4 * gdt_attn_4 |
| 239 | _p4 = F.interpolate(p4, size=x3.shape[2:], mode='bilinear', align_corners=True) |
| 240 | _p3 = _p4 + self.lateral_block4(x3) |
| 241 | |
| 242 | if self.config.dec_ipt: |
| 243 | patches_batch = self.get_patches_batch(x, _p3) if self.split else x |
| 244 | _p3 = torch.cat((_p3, self.ipt_blk4(F.interpolate(patches_batch, size=x3.shape[2:], mode='bilinear', align_corners=True))), 1) |
| 245 | p3 = self.decoder_block3(_p3) |
| 246 | m3 = self.conv_ms_spvn_3(p3) if self.config.ms_supervision else None |
| 247 | if self.config.out_ref: |
| 248 | p3_gdt = self.gdt_convs_3(p3) |
| 249 | if self.training: |
| 250 | # >> GT: |
| 251 | # m3 --dilation--> m3_dia |
| 252 | # G_3^gt * m3_dia --> G_3^m, which is the label of gradient |
| 253 | m3_dia = m3 |
| 254 | gdt_label_main_3 = gdt_gt * F.interpolate(m3_dia, size=gdt_gt.shape[2:], mode='bilinear', align_corners=True) |
| 255 | outs_gdt_label.append(gdt_label_main_3) |
| 256 | # >> Pred: |
| 257 | # p3 --conv--BN--> F_3^G, where F_3^G predicts the \hat{G_3} with xx |
| 258 | # F_3^G --sigmoid--> A_3^G |
| 259 | gdt_pred_3 = self.gdt_convs_pred_3(p3_gdt) |
| 260 | outs_gdt_pred.append(gdt_pred_3) |
| 261 | gdt_attn_3 = self.gdt_convs_attn_3(p3_gdt).sigmoid() |
| 262 | # >> Finally: |
| 263 | # p3 = p3 * A_3^G |
| 264 | p3 = p3 * gdt_attn_3 |
| 265 | _p3 = F.interpolate(p3, size=x2.shape[2:], mode='bilinear', align_corners=True) |
| 266 | _p2 = _p3 + self.lateral_block3(x2) |
| 267 | |
| 268 | if self.config.dec_ipt: |
| 269 | patches_batch = self.get_patches_batch(x, _p2) if self.split else x |
nothing calls this directly
no test coverage detected