MCP · Copy · Index your code
hub / github.com/ZhengPeng7/BiRefNet / forward

Method forward

models/birefnet.py:212–305  ·  view source on GitHub ↗
(self, features)

Source loaded from the content-addressed store (hash-verified).

210 return torch.cat(patches_batch, dim=0)
211
212 def forward(self, features):
213 if self.training and self.config.out_ref:
214 outs_gdt_pred = []
215 outs_gdt_label = []
216 x, x1, x2, x3, x4, gdt_gt = features
217 else:
218 x, x1, x2, x3, x4 = features
219 outs = []
220
221 if self.config.dec_ipt:
222 patches_batch = self.get_patches_batch(x, x4) if self.split else x
223 x4 = torch.cat((x4, self.ipt_blk5(F.interpolate(patches_batch, size=x4.shape[2:], mode='bilinear', align_corners=True))), 1)
224 p4 = self.decoder_block4(x4)
225 m4 = self.conv_ms_spvn_4(p4) if self.config.ms_supervision else None
226 if self.config.out_ref:
227 p4_gdt = self.gdt_convs_4(p4)
228 if self.training:
229 # >> GT:
230 m4_dia = m4
231 gdt_label_main_4 = gdt_gt * F.interpolate(m4_dia, size=gdt_gt.shape[2:], mode='bilinear', align_corners=True)
232 outs_gdt_label.append(gdt_label_main_4)
233 # >> Pred:
234 gdt_pred_4 = self.gdt_convs_pred_4(p4_gdt)
235 outs_gdt_pred.append(gdt_pred_4)
236 gdt_attn_4 = self.gdt_convs_attn_4(p4_gdt).sigmoid()
237 # >> Finally:
238 p4 = p4 * gdt_attn_4
239 _p4 = F.interpolate(p4, size=x3.shape[2:], mode='bilinear', align_corners=True)
240 _p3 = _p4 + self.lateral_block4(x3)
241
242 if self.config.dec_ipt:
243 patches_batch = self.get_patches_batch(x, _p3) if self.split else x
244 _p3 = torch.cat((_p3, self.ipt_blk4(F.interpolate(patches_batch, size=x3.shape[2:], mode='bilinear', align_corners=True))), 1)
245 p3 = self.decoder_block3(_p3)
246 m3 = self.conv_ms_spvn_3(p3) if self.config.ms_supervision else None
247 if self.config.out_ref:
248 p3_gdt = self.gdt_convs_3(p3)
249 if self.training:
250 # >> GT:
251 # m3 --dilation--> m3_dia
252 # G_3^gt * m3_dia --> G_3^m, which is the label of gradient
253 m3_dia = m3
254 gdt_label_main_3 = gdt_gt * F.interpolate(m3_dia, size=gdt_gt.shape[2:], mode='bilinear', align_corners=True)
255 outs_gdt_label.append(gdt_label_main_3)
256 # >> Pred:
257 # p3 --conv--BN--> F_3^G, where F_3^G predicts the \hat{G_3} with xx
258 # F_3^G --sigmoid--> A_3^G
259 gdt_pred_3 = self.gdt_convs_pred_3(p3_gdt)
260 outs_gdt_pred.append(gdt_pred_3)
261 gdt_attn_3 = self.gdt_convs_attn_3(p3_gdt).sigmoid()
262 # >> Finally:
263 # p3 = p3 * A_3^G
264 p3 = p3 * gdt_attn_3
265 _p3 = F.interpolate(p3, size=x2.shape[2:], mode='bilinear', align_corners=True)
266 _p2 = _p3 + self.lateral_block3(x2)
267
268 if self.config.dec_ipt:
269 patches_batch = self.get_patches_batch(x, _p2) if self.split else x

Callers

nothing calls this directly

Calls 1

get_patches_batch · Method · 0.95

Tested by

no test coverage detected